Store all mappings in a single map instead of split by type

This opens up the possibility of storing other request types outside
of T_A, T_PTR and T_TXT without requiring redoing the caching. It
also fixes the caching code in DNS_Mapping, adding a version number
to the start of the cache file so the cache structure can be modified
and old caches invalidated more easily.
This commit is contained in:
Tim Wojtulewicz 2022-02-07 12:41:03 -07:00
parent fb59239f41
commit e8f833b8a6
5 changed files with 261 additions and 154 deletions

View file

@ -1,16 +1,20 @@
#include "zeek/DNS_Mapping.h" #include "zeek/DNS_Mapping.h"
#include <ares_nameser.h>
#include "zeek/3rdparty/doctest.h" #include "zeek/3rdparty/doctest.h"
#include "zeek/DNS_Mgr.h" #include "zeek/DNS_Mgr.h"
#include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail
{ {
DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl) DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type)
{ {
Init(h); Init(h);
req_host = host; req_host = host;
req_ttl = ttl; req_ttl = ttl;
req_type = type;
if ( names.empty() ) if ( names.empty() )
names.push_back(std::move(host)); names.push_back(std::move(host));
@ -21,6 +25,7 @@ DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl)
Init(h); Init(h);
req_addr = addr; req_addr = addr;
req_ttl = ttl; req_ttl = ttl;
req_type = T_PTR;
} }
DNS_Mapping::DNS_Mapping(FILE* f) DNS_Mapping::DNS_Mapping(FILE* f)
@ -45,7 +50,7 @@ DNS_Mapping::DNS_Mapping(FILE* f)
int num_addrs; int num_addrs;
if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf,
&failed_local, name_buf, &map_type, &num_addrs, &req_ttl) != 8 ) &failed_local, name_buf, &req_type, &num_addrs, &req_ttl) != 8 )
{ {
no_mapping = true; no_mapping = true;
return; return;
@ -126,7 +131,6 @@ void DNS_Mapping::Init(struct hostent* h)
return; return;
} }
map_type = h->h_addrtype;
if ( h->h_name ) if ( h->h_name )
// for now, just use the official name // for now, just use the official name
// TODO: this could easily be expanded to include all of the aliases as well // TODO: this could easily be expanded to include all of the aliases as well
@ -153,7 +157,7 @@ void DNS_Mapping::Clear()
addrs.clear(); addrs.clear();
addrs_val = nullptr; addrs_val = nullptr;
no_mapping = false; no_mapping = false;
map_type = 0; req_type = 0;
failed = true; failed = true;
} }
@ -161,7 +165,7 @@ void DNS_Mapping::Save(FILE* f) const
{ {
fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(),
req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed,
names.empty() ? "*" : names[0].c_str(), map_type, addrs.size(), req_ttl); names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl);
for ( const auto& addr : addrs ) for ( const auto& addr : addrs )
fprintf(f, "%s\n", addr.AsString().c_str()); fprintf(f, "%s\n", addr.AsString().c_str());
@ -173,13 +177,38 @@ void DNS_Mapping::Merge(DNS_Mapping* other)
std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs)); std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs));
} }
// This value needs to be incremented if something changes in the data stored by Save(). This
// allows us to change the structure of the cache without breaking something in DNS_Mgr.
constexpr int FILE_VERSION = 1;
void DNS_Mapping::InitializeCache(FILE* f)
{
fprintf(f, "%d\n", FILE_VERSION);
}
bool DNS_Mapping::ValidateCacheVersion(FILE* f)
{
char buf[512];
if ( ! fgets(buf, sizeof(buf), f) )
return false;
int version;
if ( sscanf(buf, "%d", &version) != 1 )
{
reporter->Warning("Existing DNS cache did not have correct version, ignoring");
return false;
}
return FILE_VERSION == version;
}
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
TEST_CASE("dns_mapping init null hostent") TEST_CASE("dns_mapping init null hostent")
{ {
DNS_Mapping mapping(std::string("www.apple.com"), nullptr, 123); DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A);
CHECK(! mapping.Valid()); CHECK(! mapping.Valid());
CHECK(mapping.Addrs() == nullptr); CHECK(mapping.Addrs() == nullptr);
@ -202,7 +231,7 @@ TEST_CASE("dns_mapping init host")
std::vector<in_addr*> addrs = {&in4, NULL}; std::vector<in_addr*> addrs = {&in4, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data()); he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping(std::string("testing.home"), &he, 123); DNS_Mapping mapping("testing.home", &he, 123, T_A);
CHECK(mapping.Valid()); CHECK(mapping.Valid());
CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified); CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified);
CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0); CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0);
@ -347,7 +376,7 @@ TEST_CASE("dns_mapping multiple addresses")
std::vector<in_addr*> addrs = {&in4_1, &in4_2, NULL}; std::vector<in_addr*> addrs = {&in4_1, &in4_2, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data()); he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping(std::string("testing.home"), &he, 123); DNS_Mapping mapping("testing.home", &he, 123, T_A);
CHECK(mapping.Valid()); CHECK(mapping.Valid());
auto lva = mapping.Addrs(); auto lva = mapping.Addrs();

View file

@ -15,7 +15,7 @@ class DNS_Mapping
{ {
public: public:
DNS_Mapping() = delete; DNS_Mapping() = delete;
DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl); DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type);
DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl);
DNS_Mapping(FILE* f); DNS_Mapping(FILE* f);
@ -29,6 +29,7 @@ public:
const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); } const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); }
const IPAddr& ReqAddr() const { return req_addr; } const IPAddr& ReqAddr() const { return req_addr; }
std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; } std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; }
int ReqType() const { return req_type; }
ListValPtr Addrs(); ListValPtr Addrs();
TableValPtr AddrsSet(); // addresses returned as a set TableValPtr AddrsSet(); // addresses returned as a set
@ -50,10 +51,11 @@ public:
return util::current_time() > (creation_time + req_ttl); return util::current_time() > (creation_time + req_ttl);
} }
int Type() const { return map_type; }
void Merge(DNS_Mapping* other); void Merge(DNS_Mapping* other);
static void InitializeCache(FILE* f);
static bool ValidateCacheVersion(FILE* f);
protected: protected:
friend class DNS_Mgr; friend class DNS_Mgr;
@ -63,6 +65,7 @@ protected:
std::string req_host; std::string req_host;
IPAddr req_addr; IPAddr req_addr;
uint32_t req_ttl = 0; uint32_t req_ttl = 0;
int req_type = 0;
// This class supports multiple names per address, but we only store one of them. // This class supports multiple names per address, but we only store one of them.
std::vector<std::string> names; std::vector<std::string> names;
@ -72,7 +75,6 @@ protected:
ListValPtr addrs_val; ListValPtr addrs_val;
double creation_time = 0.0; double creation_time = 0.0;
int map_type = AF_UNSPEC;
bool no_mapping = false; // when initializing from a file, immediately hit EOF bool no_mapping = false; // when initializing from a file, immediately hit EOF
bool init_failed = false; bool init_failed = false;
bool failed = false; bool failed = false;

View file

@ -49,6 +49,130 @@ constexpr int DNS_TIMEOUT = 5;
// The maximum allowed number of pending asynchronous requests. // The maximum allowed number of pending asynchronous requests.
constexpr int MAX_PENDING_REQUESTS = 20; constexpr int MAX_PENDING_REQUESTS = 20;
// This unfortunately doesn't exist in c-ares, even though it seems rather useful.
static const char* request_type_string(int request_type)
{
switch ( request_type )
{
case T_A:
return "T_A";
case T_NS:
return "T_NS";
case T_MD:
return "T_MD";
case T_MF:
return "T_MF";
case T_CNAME:
return "T_CNAME";
case T_SOA:
return "T_SOA";
case T_MB:
return "T_MB";
case T_MG:
return "T_MG";
case T_MR:
return "T_MR";
case T_NULL:
return "T_NULL";
case T_WKS:
return "T_WKS";
case T_PTR:
return "T_PTR";
case T_HINFO:
return "T_HINFO";
case T_MINFO:
return "T_MINFO";
case T_MX:
return "T_MX";
case T_TXT:
return "T_TXT";
case T_RP:
return "T_RP";
case T_AFSDB:
return "T_AFSDB";
case T_X25:
return "T_X25";
case T_ISDN:
return "T_ISDN";
case T_RT:
return "T_RT";
case T_NSAP:
return "T_NSAP";
case T_NSAP_PTR:
return "T_NSAP_PTR";
case T_SIG:
return "T_SIG";
case T_KEY:
return "T_KEY";
case T_PX:
return "T_PX";
case T_GPOS:
return "T_GPOS";
case T_AAAA:
return "T_AAAA";
case T_LOC:
return "T_LOC";
case T_NXT:
return "T_NXT";
case T_EID:
return "T_EID";
case T_NIMLOC:
return "T_NIMLOC";
case T_SRV:
return "T_SRV";
case T_ATMA:
return "T_ATMA";
case T_NAPTR:
return "T_NAPTR";
case T_KX:
return "T_KX";
case T_CERT:
return "T_CERT";
case T_A6:
return "T_A6";
case T_DNAME:
return "T_DNAME";
case T_SINK:
return "T_SINK";
case T_OPT:
return "T_OPT";
case T_APL:
return "T_APL";
case T_DS:
return "T_DS";
case T_SSHFP:
return "T_SSHFP";
case T_RRSIG:
return "T_RRSIG";
case T_NSEC:
return "T_NSEC";
case T_DNSKEY:
return "T_DNSKEY";
case T_TKEY:
return "T_TKEY";
case T_TSIG:
return "T_TSIG";
case T_IXFR:
return "T_IXFR";
case T_AXFR:
return "T_AXFR";
case T_MAILB:
return "T_MAILB";
case T_MAILA:
return "T_MAILA";
case T_ANY:
return "T_ANY";
case T_URI:
return "T_URI";
case T_CAA:
return "T_CAA";
case T_MAX:
return "T_MAX";
default:
return "";
}
}
namespace zeek::detail namespace zeek::detail
{ {
static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result); static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result);
@ -90,6 +214,10 @@ uint16_t DNS_Request::request_id = 0;
DNS_Request::DNS_Request(std::string host, int request_type, bool async) DNS_Request::DNS_Request(std::string host, int request_type, bool async)
: host(std::move(host)), request_type(request_type), async(async) : host(std::move(host)), request_type(request_type), async(async)
{ {
// We combine the T_A and T_AAAA requests together in one request, so set the type
// to T_A to make things easier in other parts of the code (mostly around lookups).
if ( request_type == T_AAAA )
request_type = T_A;
} }
DNS_Request::DNS_Request(const IPAddr& addr, bool async) : addr(addr), async(async) DNS_Request::DNS_Request(const IPAddr& addr, bool async) : addr(addr), async(async)
@ -113,7 +241,7 @@ void DNS_Request::MakeRequest(ares_channel channel, DNS_Mgr* mgr)
// all of them would be in flight at the same time. // all of them would be in flight at the same time.
DNS_Request::request_id++; DNS_Request::request_id++;
if ( request_type == T_A || request_type == T_AAAA ) if ( request_type == T_A )
{ {
// For A/AAAA requests, we use a different method than the other requests. Since // For A/AAAA requests, we use a different method than the other requests. Since
// we're using the AF_UNSPEC family, we get both the ipv4 and ipv6 responses // we're using the AF_UNSPEC family, we get both the ipv4 and ipv6 responses
@ -149,7 +277,7 @@ void DNS_Request::ProcessAsyncResult(bool timed_out, DNS_Mgr* mgr)
if ( ! async ) if ( ! async )
return; return;
if ( request_type == T_A || request_type == T_AAAA ) if ( request_type == T_A )
mgr->CheckAsyncHostRequest(host, timed_out); mgr->CheckAsyncHostRequest(host, timed_out);
else if ( request_type == T_PTR ) else if ( request_type == T_PTR )
mgr->CheckAsyncAddrRequest(addr, timed_out); mgr->CheckAsyncAddrRequest(addr, timed_out);
@ -385,7 +513,8 @@ static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, in
} }
default: default:
reporter->Error("Requests of type %d are unsupported", req->RequestType()); reporter->Error("Requests of type %d (%s) are unsupported", req->RequestType(),
request_type_string(req->RequestType()));
break; break;
} }
} }
@ -531,9 +660,9 @@ static TableValPtr fake_name_lookup_result(const std::string& name)
return hv->ToSetVal(); return hv->ToSetVal();
} }
static std::string fake_text_lookup_result(const std::string name) static std::string fake_lookup_result(const std::string& name, int request_type)
{ {
return util::fmt("fake_text_lookup_result_%s", name.c_str()); return util::fmt("fake_lookup_result_%s_%s", request_type_string(request_type), name.c_str());
} }
static std::string fake_addr_lookup_result(const IPAddr& addr) static std::string fake_addr_lookup_result(const IPAddr& addr)
@ -558,14 +687,14 @@ ValPtr DNS_Mgr::Lookup(const std::string& name, int request_type)
if ( request_type == T_A || request_type == T_AAAA ) if ( request_type == T_A || request_type == T_AAAA )
return LookupHost(name); return LookupHost(name);
if ( mode == DNS_FAKE && request_type == T_TXT ) if ( mode == DNS_FAKE )
return make_intrusive<StringVal>(fake_text_lookup_result(name)); return make_intrusive<StringVal>(fake_lookup_result(name, request_type));
InitSource(); InitSource();
if ( mode != DNS_PRIME && request_type == T_TXT ) if ( mode != DNS_PRIME )
{ {
if ( auto val = LookupTextInCache(name, false) ) if ( auto val = LookupOtherInCache(name, request_type, false) )
return val; return val;
} }
@ -579,8 +708,8 @@ ValPtr DNS_Mgr::Lookup(const std::string& name, int request_type)
} }
case DNS_FORCE: case DNS_FORCE:
reporter->FatalError("can't find DNS entry for %s (req type %d) in cache", name.c_str(), reporter->FatalError("can't find DNS entry for %s (req type %d / %s) in cache",
request_type); name.c_str(), request_type, request_type_string(request_type));
return nullptr; return nullptr;
case DNS_DEFAULT: case DNS_DEFAULT:
@ -783,12 +912,12 @@ void DNS_Mgr::Lookup(const std::string& name, int request_type, LookupCallback*
if ( mode == DNS_FAKE ) if ( mode == DNS_FAKE )
{ {
resolve_lookup_cb(callback, fake_text_lookup_result(name)); resolve_lookup_cb(callback, fake_lookup_result(name, request_type));
return; return;
} }
// Do we already know the answer? // Do we already know the answer?
if ( auto txt = LookupTextInCache(name, true) ) if ( auto txt = LookupOtherInCache(name, request_type, true) )
{ {
resolve_lookup_cb(callback, txt->CheckString()); resolve_lookup_cb(callback, txt->CheckString());
return; return;
@ -885,98 +1014,50 @@ void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool m
DNS_Mapping* prev_mapping = nullptr; DNS_Mapping* prev_mapping = nullptr;
bool keep_prev = true; bool keep_prev = true;
if ( ! dr->Host().empty() ) MappingMap::iterator it;
{ if ( dr->RequestType() == T_PTR )
new_mapping = new DNS_Mapping(dr->Host(), h, ttl);
if ( dr->IsTxt() )
{
TextMap::iterator it = text_mappings.find(dr->Host());
if ( it == text_mappings.end() )
{
auto result = text_mappings.emplace(dr->Host(), new_mapping);
it = result.first;
}
else
prev_mapping = it->second;
if ( prev_mapping && prev_mapping->Valid() )
{
if ( new_mapping->Valid() )
{
if ( merge )
new_mapping->Merge(prev_mapping);
it->second = new_mapping;
keep_prev = false;
}
}
else
{
it->second = new_mapping;
keep_prev = false;
}
}
else
{
HostMap::iterator it = host_mappings.find(dr->Host());
if ( it == host_mappings.end() )
{
auto result = host_mappings.emplace(dr->Host(), new_mapping);
it = result.first;
}
else
prev_mapping = it->second;
if ( prev_mapping && prev_mapping->Valid() )
{
if ( new_mapping->Valid() )
{
if ( merge )
new_mapping->Merge(prev_mapping);
it->second = new_mapping;
keep_prev = false;
}
}
else
{
it->second = new_mapping;
keep_prev = false;
}
}
}
else
{ {
new_mapping = new DNS_Mapping(dr->Addr(), h, ttl); new_mapping = new DNS_Mapping(dr->Addr(), h, ttl);
it = all_mappings.find(dr->Addr());
AddrMap::iterator it = addr_mappings.find(dr->Addr()); if ( it == all_mappings.end() )
if ( it == addr_mappings.end() )
{ {
auto result = addr_mappings.emplace(dr->Addr(), new_mapping); auto result = all_mappings.emplace(dr->Addr(), new_mapping);
it = result.first; it = result.first;
} }
else else
prev_mapping = it->second; prev_mapping = it->second;
}
else
{
new_mapping = new DNS_Mapping(dr->Host(), h, ttl, dr->RequestType());
auto key = std::make_pair(dr->RequestType(), dr->Host());
if ( prev_mapping && prev_mapping->Valid() ) it = all_mappings.find(key);
if ( it == all_mappings.end() )
{ {
if ( new_mapping->Valid() ) auto result = all_mappings.emplace(std::move(key), new_mapping);
{ it = result.first;
if ( merge )
new_mapping->Merge(prev_mapping);
it->second = new_mapping;
keep_prev = false;
}
} }
else else
prev_mapping = it->second;
}
if ( prev_mapping && prev_mapping->Valid() )
{
if ( new_mapping->Valid() )
{ {
if ( merge )
new_mapping->Merge(prev_mapping);
it->second = new_mapping; it->second = new_mapping;
keep_prev = false; keep_prev = false;
} }
} }
else
{
it->second = new_mapping;
keep_prev = false;
}
if ( prev_mapping && ! dr->IsTxt() ) if ( prev_mapping && ! dr->IsTxt() )
CompareMappings(prev_mapping, new_mapping); CompareMappings(prev_mapping, new_mapping);
@ -1065,14 +1146,17 @@ void DNS_Mgr::LoadCache(const std::string& path)
if ( ! f ) if ( ! f )
return; return;
if ( ! DNS_Mapping::ValidateCacheVersion(f) )
return;
// Loop until we find a mapping that doesn't initialize correctly. // Loop until we find a mapping that doesn't initialize correctly.
DNS_Mapping* m = new DNS_Mapping(f); DNS_Mapping* m = new DNS_Mapping(f);
for ( ; ! m->NoMapping() && ! m->InitFailed(); m = new DNS_Mapping(f) ) for ( ; ! m->NoMapping() && ! m->InitFailed(); m = new DNS_Mapping(f) )
{ {
if ( m->ReqHost() ) if ( m->ReqHost() )
host_mappings.insert_or_assign(m->ReqHost(), m); all_mappings.insert_or_assign(std::make_pair(m->ReqType(), m->ReqHost()), m);
else else
addr_mappings.insert_or_assign(m->ReqAddr(), m); all_mappings.insert_or_assign(m->ReqAddr(), m);
} }
if ( ! m->NoMapping() ) if ( ! m->NoMapping() )
@ -1092,38 +1176,28 @@ bool DNS_Mgr::Save()
if ( ! f ) if ( ! f )
return false; return false;
Save(f, host_mappings); DNS_Mapping::InitializeCache(f);
Save(f, addr_mappings); Save(f, all_mappings);
// Save(f, text_mappings); // We don't save the TXT mappings (yet?).
fclose(f); fclose(f);
return true; return true;
} }
void DNS_Mgr::Save(FILE* f, const AddrMap& m) void DNS_Mgr::Save(FILE* f, const MappingMap& m)
{ {
for ( AddrMap::const_iterator it = m.begin(); it != m.end(); ++it ) for ( const auto& [key, mapping] : m )
{ {
if ( it->second ) if ( mapping )
it->second->Save(f); mapping->Save(f);
}
}
void DNS_Mgr::Save(FILE* f, const HostMap& m)
{
for ( HostMap::const_iterator it = m.begin(); it != m.end(); ++it )
{
if ( it->second )
it->second->Save(f);
} }
} }
TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_expired, TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_expired,
bool check_failed) bool check_failed)
{ {
HostMap::iterator it = host_mappings.find(name); auto it = all_mappings.find(std::make_pair(T_A, name));
if ( it == host_mappings.end() ) if ( it == all_mappings.end() )
return nullptr; return nullptr;
DNS_Mapping* d = it->second; DNS_Mapping* d = it->second;
@ -1133,7 +1207,7 @@ TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_exp
if ( cleanup_expired && (d && d->Expired()) ) if ( cleanup_expired && (d && d->Expired()) )
{ {
host_mappings.erase(it); all_mappings.erase(it);
delete d; delete d;
return nullptr; return nullptr;
} }
@ -1149,15 +1223,15 @@ TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_exp
StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired, bool check_failed) StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired, bool check_failed)
{ {
AddrMap::iterator it = addr_mappings.find(addr); auto it = all_mappings.find(addr);
if ( it == addr_mappings.end() ) if ( it == all_mappings.end() )
return nullptr; return nullptr;
DNS_Mapping* d = it->second; DNS_Mapping* d = it->second;
if ( cleanup_expired && d->Expired() ) if ( cleanup_expired && d->Expired() )
{ {
addr_mappings.erase(it); all_mappings.erase(it);
delete d; delete d;
return nullptr; return nullptr;
} }
@ -1174,17 +1248,18 @@ StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired
return make_intrusive<StringVal>("<\?\?\?>"); return make_intrusive<StringVal>("<\?\?\?>");
} }
StringValPtr DNS_Mgr::LookupTextInCache(const std::string& name, bool cleanup_expired) StringValPtr DNS_Mgr::LookupOtherInCache(const std::string& name, int request_type,
bool cleanup_expired)
{ {
TextMap::iterator it = text_mappings.find(name); auto it = all_mappings.find(std::make_pair(request_type, name));
if ( it == text_mappings.end() ) if ( it == all_mappings.end() )
return nullptr; return nullptr;
DNS_Mapping* d = it->second; DNS_Mapping* d = it->second;
if ( cleanup_expired && d->Expired() ) if ( cleanup_expired && d->Expired() )
{ {
text_mappings.erase(it); all_mappings.erase(it);
delete d; delete d;
return nullptr; return nullptr;
} }
@ -1291,7 +1366,7 @@ void DNS_Mgr::CheckAsyncTextRequest(const std::string& host, bool timeout)
++failed; ++failed;
i->second->Timeout(); i->second->Timeout();
} }
else if ( auto name = LookupTextInCache(host, true) ) else if ( auto name = LookupOtherInCache(host, T_TXT, true) )
{ {
++successful; ++successful;
i->second->Resolved(name->CheckString()); i->second->Resolved(name->CheckString());
@ -1309,18 +1384,10 @@ void DNS_Mgr::Flush()
{ {
Resolve(); Resolve();
for ( HostMap::iterator it = host_mappings.begin(); it != host_mappings.end(); ++it ) for ( auto& [key, mapping] : all_mappings )
delete it->second; delete mapping;
for ( AddrMap::iterator it2 = addr_mappings.begin(); it2 != addr_mappings.end(); ++it2 ) all_mappings.clear();
delete it2->second;
for ( TextMap::iterator it3 = text_mappings.begin(); it3 != text_mappings.end(); ++it3 )
delete it3->second;
host_mappings.clear();
addr_mappings.clear();
text_mappings.clear();
} }
double DNS_Mgr::GetNextTimeout() double DNS_Mgr::GetNextTimeout()
@ -1357,9 +1424,20 @@ void DNS_Mgr::GetStats(Stats* stats)
stats->successful = successful; stats->successful = successful;
stats->failed = failed; stats->failed = failed;
stats->pending = asyncs_pending; stats->pending = asyncs_pending;
stats->cached_hosts = host_mappings.size();
stats->cached_addresses = addr_mappings.size(); stats->cached_hosts = 0;
stats->cached_texts = text_mappings.size(); stats->cached_addresses = 0;
stats->cached_texts = 0;
for ( const auto& [key, mapping] : all_mappings )
{
if ( mapping->ReqType() == T_PTR )
stats->cached_addresses++;
else if ( mapping->ReqType() == T_A )
stats->cached_hosts++;
else
stats->cached_texts++;
}
} }
void DNS_Mgr::AsyncRequest::Resolved(const std::string& name) void DNS_Mgr::AsyncRequest::Resolved(const std::string& name)

View file

@ -7,6 +7,7 @@
#include <map> #include <map>
#include <queue> #include <queue>
#include <utility> #include <utility>
#include <variant>
#include "zeek/EventHandler.h" #include "zeek/EventHandler.h"
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
@ -248,7 +249,8 @@ protected:
bool check_failed = false); bool check_failed = false);
TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false,
bool check_failed = false); bool check_failed = false);
StringValPtr LookupTextInCache(const std::string& name, bool cleanup_expired = false); StringValPtr LookupOtherInCache(const std::string& name, int request_type,
bool cleanup_expired = false);
// Finish the request if we have a result. If not, time it out if // Finish the request if we have a result. If not, time it out if
// requested. // requested.
@ -265,12 +267,10 @@ protected:
void CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm); void CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm);
ListValPtr AddrListDelta(ListVal* al1, ListVal* al2); ListValPtr AddrListDelta(ListVal* al1, ListVal* al2);
using HostMap = std::map<std::string, DNS_Mapping*>; using MappingKey = std::variant<IPAddr, std::pair<int, std::string>>;
using AddrMap = std::map<IPAddr, DNS_Mapping*>; using MappingMap = std::map<MappingKey, DNS_Mapping*>;
using TextMap = std::map<std::string, DNS_Mapping*>;
void LoadCache(const std::string& path); void LoadCache(const std::string& path);
void Save(FILE* f, const AddrMap& m); void Save(FILE* f, const MappingMap& m);
void Save(FILE* f, const HostMap& m);
// Issue as many queued async requests as slots are available. // Issue as many queued async requests as slots are available.
void IssueAsyncRequests(); void IssueAsyncRequests();
@ -283,9 +283,7 @@ protected:
DNS_MgrMode mode; DNS_MgrMode mode;
HostMap host_mappings; MappingMap all_mappings;
AddrMap addr_mappings;
TextMap text_mappings;
std::string cache_name; std::string cache_name;
std::string dir; // directory in which cache_name resides std::string dir; // directory in which cache_name resides

View file

@ -4,7 +4,7 @@
7a5f:b783:9808:380e:b1a2:ce20:b58e:2a4a, 7a5f:b783:9808:380e:b1a2:ce20:b58e:2a4a,
4cc7:de52:d869:b2f9:f215:19b8:c828:3bdd 4cc7:de52:d869:b2f9:f215:19b8:c828:3bdd
} }
lookup_hostname_txt, fake_text_lookup_result_bro.wp.dg.cx lookup_hostname_txt, fake_lookup_result_T_TXT_bro.wp.dg.cx
lookup_hostname, { lookup_hostname, {
ce06:236:f21f:587:8c10:121d:c47d:b412 ce06:236:f21f:587:8c10:121d:c47d:b412
} }