Redis: Check server version when connecting

This commit is contained in:
Tim Wojtulewicz 2025-04-30 10:45:06 -07:00
parent 58d71d2fa3
commit ecd603516f
2 changed files with 69 additions and 9 deletions

View file

@ -2,6 +2,8 @@
#include "zeek/storage/backend/redis/Redis.h"
#include <algorithm>
#include "zeek/DebugLogger.h"
#include "zeek/Func.h"
#include "zeek/RunState.h"
@ -57,7 +59,7 @@ void redisErase(redisAsyncContext* ctx, void* reply, void* privdata) {
}
void redisZADD(redisAsyncContext* ctx, void* reply, void* privdata) {
auto t = Tracer("generic");
auto t = Tracer("zadd");
auto backend = static_cast<zeek::storage::backend::redis::Redis*>(ctx->data);
// We don't care about the reply from the ZADD, mostly because blocking to poll
@ -73,6 +75,12 @@ void redisGeneric(redisAsyncContext* ctx, void* reply, void* privdata) {
backend->HandleGeneric(static_cast<redisReply*>(reply));
}
void redisINFO(redisAsyncContext* ctx, void* reply, void* privdata) {
auto t = Tracer("generic");
auto backend = static_cast<zeek::storage::backend::redis::Redis*>(ctx->data);
backend->HandleInfoResult(static_cast<redisReply*>(reply));
}
// Because we called redisPollAttach in DoOpen(), privdata here is a
// redisPollEvents object. We can go through that object to get the context's
// data, which contains the backend. Because we overrode these callbacks in
@ -137,6 +145,8 @@ std::unique_lock<std::mutex> conditionally_lock(bool condition, std::mutex& mute
namespace zeek::storage::backend::redis {
constexpr char REQUIRED_VERSION[] = "6.2.0";
storage::BackendPtr Redis::Instantiate() { return make_intrusive<Redis>(); }
/**
@ -493,19 +503,67 @@ void Redis::HandleGeneric(redisReply* reply) {
reply_queue.push_back(reply);
}
void Redis::HandleInfoResult(redisReply* reply) {
DBG_LOG(DBG_STORAGE, "Redis backend: info event");
--active_ops;
auto lines = util::split(std::string{reply->str}, "\r\n");
OperationResult res = {ReturnCode::CONNECTION_FAILED};
if ( lines.empty() )
res.err_str = "INFO command return zero entries";
else {
std::string_view version_sv{REQUIRED_VERSION};
for ( const auto& e : lines ) {
// Skip empty lines and comments
if ( e.empty() || e[0] == '#' )
continue;
// We only care about the redis_version entry. Skip anything else.
if ( ! util::starts_with(e, "redis_version:") )
continue;
auto splits = util::split(e, ':');
DBG_LOG(DBG_STORAGE, "Redis backend: found server version %s", splits[1].c_str());
if ( std::lexicographical_compare(splits[1].begin(), splits[1].end(), version_sv.begin(),
version_sv.end()) )
res.err_str = util::fmt("Redis server version is too low: Found %s, need %s", splits[1].c_str(),
REQUIRED_VERSION);
else {
connected = true;
res.code = ReturnCode::SUCCESS;
}
}
}
if ( ! connected && res.err_str.empty() )
res.err_str = "INFO command did not return server version";
freeReplyObject(reply);
CompleteCallback(open_cb, res);
}
void Redis::OnConnect(int status) {
DBG_LOG(DBG_STORAGE, "Redis backend: connection event");
--active_ops;
if ( status == REDIS_OK ) {
connected = true;
CompleteCallback(open_cb, {ReturnCode::SUCCESS});
// The connection_established event is sent via the open callback handler.
return;
}
connected = false;
CompleteCallback(open_cb, {ReturnCode::CONNECTION_FAILED});
if ( status == REDIS_OK ) {
// Request the INFO block from the server that should contain the version information.
status = redisAsyncCommand(async_ctx, redisINFO, NULL, "INFO server");
if ( status == REDIS_ERR ) {
// TODO: do something with the error?
DBG_LOG(DBG_STORAGE, "INFO command failed: %s", async_ctx->errstr);
CompleteCallback(open_cb,
{ReturnCode::OPERATION_FAILED,
util::fmt("INFO command failed to retrieve server info: %s", async_ctx->errstr)});
return;
}
++active_ops;
}
// TODO: we could attempt to reconnect here
}

View file

@ -41,6 +41,7 @@ public:
void HandleGetResult(redisReply* reply, ResultCallback* callback);
void HandleEraseResult(redisReply* reply, ResultCallback* callback);
void HandleGeneric(redisReply* reply);
void HandleInfoResult(redisReply* reply);
/**
* Returns whether the backend is opened.
@ -60,6 +61,7 @@ private:
void DoPoll() override;
OperationResult ParseReplyError(std::string_view op_str, std::string_view reply_err_str) const;
OperationResult CheckServerVersion();
redisAsyncContext* async_ctx = nullptr;