diff --git a/src/DNS_Mapping.cc b/src/DNS_Mapping.cc index d0de26b016..45a8b9cf68 100644 --- a/src/DNS_Mapping.cc +++ b/src/DNS_Mapping.cc @@ -1,16 +1,20 @@ #include "zeek/DNS_Mapping.h" +#include + #include "zeek/3rdparty/doctest.h" #include "zeek/DNS_Mgr.h" +#include "zeek/Reporter.h" namespace zeek::detail { -DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl) +DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) { Init(h); req_host = host; req_ttl = ttl; + req_type = type; if ( names.empty() ) names.push_back(std::move(host)); @@ -21,6 +25,7 @@ DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) Init(h); req_addr = addr; req_ttl = ttl; + req_type = T_PTR; } DNS_Mapping::DNS_Mapping(FILE* f) @@ -45,7 +50,7 @@ DNS_Mapping::DNS_Mapping(FILE* f) int num_addrs; if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, - &failed_local, name_buf, &map_type, &num_addrs, &req_ttl) != 8 ) + &failed_local, name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) { no_mapping = true; return; @@ -126,7 +131,6 @@ void DNS_Mapping::Init(struct hostent* h) return; } - map_type = h->h_addrtype; if ( h->h_name ) // for now, just use the official name // TODO: this could easily be expanded to include all of the aliases as well @@ -153,7 +157,7 @@ void DNS_Mapping::Clear() addrs.clear(); addrs_val = nullptr; no_mapping = false; - map_type = 0; + req_type = 0; failed = true; } @@ -161,7 +165,7 @@ void DNS_Mapping::Save(FILE* f) const { fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, - names.empty() ? "*" : names[0].c_str(), map_type, addrs.size(), req_ttl); + names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl); for ( const auto& addr : addrs ) fprintf(f, "%s\n", addr.AsString().c_str()); @@ -173,13 +177,38 @@ void DNS_Mapping::Merge(DNS_Mapping* other) std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs)); } +// This value needs to be incremented if something changes in the data stored by Save(). This +// allows us to change the structure of the cache without breaking something in DNS_Mgr. +constexpr int FILE_VERSION = 1; + +void DNS_Mapping::InitializeCache(FILE* f) + { + fprintf(f, "%d\n", FILE_VERSION); + } + +bool DNS_Mapping::ValidateCacheVersion(FILE* f) + { + char buf[512]; + if ( ! fgets(buf, sizeof(buf), f) ) + return false; + + int version; + if ( sscanf(buf, "%d", &version) != 1 ) + { + reporter->Warning("Existing DNS cache did not have correct version, ignoring"); + return false; + } + + return FILE_VERSION == version; + } + ////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////// TEST_CASE("dns_mapping init null hostent") { - DNS_Mapping mapping(std::string("www.apple.com"), nullptr, 123); + DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A); CHECK(! mapping.Valid()); CHECK(mapping.Addrs() == nullptr); @@ -202,7 +231,7 @@ TEST_CASE("dns_mapping init host") std::vector addrs = {&in4, NULL}; he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping(std::string("testing.home"), &he, 123); + DNS_Mapping mapping("testing.home", &he, 123, T_A); CHECK(mapping.Valid()); CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified); CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0); @@ -347,7 +376,7 @@ TEST_CASE("dns_mapping multiple addresses") std::vector addrs = {&in4_1, &in4_2, NULL}; he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping(std::string("testing.home"), &he, 123); + DNS_Mapping mapping("testing.home", &he, 123, T_A); CHECK(mapping.Valid()); auto lva = mapping.Addrs(); diff --git a/src/DNS_Mapping.h b/src/DNS_Mapping.h index c9b3ea6c03..c7cb698297 100644 --- a/src/DNS_Mapping.h +++ b/src/DNS_Mapping.h @@ -15,7 +15,7 @@ class DNS_Mapping { public: DNS_Mapping() = delete; - DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl); + DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type); DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); DNS_Mapping(FILE* f); @@ -29,6 +29,7 @@ public: const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); } const IPAddr& ReqAddr() const { return req_addr; } std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; } + int ReqType() const { return req_type; } ListValPtr Addrs(); TableValPtr AddrsSet(); // addresses returned as a set @@ -50,10 +51,11 @@ public: return util::current_time() > (creation_time + req_ttl); } - int Type() const { return map_type; } - void Merge(DNS_Mapping* other); + static void InitializeCache(FILE* f); + static bool ValidateCacheVersion(FILE* f); + protected: friend class DNS_Mgr; @@ -63,6 +65,7 @@ protected: std::string req_host; IPAddr req_addr; uint32_t req_ttl = 0; + int req_type = 0; // This class supports multiple names per address, but we only store one of them. std::vector names; @@ -72,7 +75,6 @@ protected: ListValPtr addrs_val; double creation_time = 0.0; - int map_type = AF_UNSPEC; bool no_mapping = false; // when initializing from a file, immediately hit EOF bool init_failed = false; bool failed = false; diff --git a/src/DNS_Mgr.cc b/src/DNS_Mgr.cc index 609ffde677..c0a4c507f9 100644 --- a/src/DNS_Mgr.cc +++ b/src/DNS_Mgr.cc @@ -49,6 +49,130 @@ constexpr int DNS_TIMEOUT = 5; // The maximum allowed number of pending asynchronous requests. constexpr int MAX_PENDING_REQUESTS = 20; +// This unfortunately doesn't exist in c-ares, even though it seems rather useful. +static const char* request_type_string(int request_type) + { + switch ( request_type ) + { + case T_A: + return "T_A"; + case T_NS: + return "T_NS"; + case T_MD: + return "T_MD"; + case T_MF: + return "T_MF"; + case T_CNAME: + return "T_CNAME"; + case T_SOA: + return "T_SOA"; + case T_MB: + return "T_MB"; + case T_MG: + return "T_MG"; + case T_MR: + return "T_MR"; + case T_NULL: + return "T_NULL"; + case T_WKS: + return "T_WKS"; + case T_PTR: + return "T_PTR"; + case T_HINFO: + return "T_HINFO"; + case T_MINFO: + return "T_MINFO"; + case T_MX: + return "T_MX"; + case T_TXT: + return "T_TXT"; + case T_RP: + return "T_RP"; + case T_AFSDB: + return "T_AFSDB"; + case T_X25: + return "T_X25"; + case T_ISDN: + return "T_ISDN"; + case T_RT: + return "T_RT"; + case T_NSAP: + return "T_NSAP"; + case T_NSAP_PTR: + return "T_NSAP_PTR"; + case T_SIG: + return "T_SIG"; + case T_KEY: + return "T_KEY"; + case T_PX: + return "T_PX"; + case T_GPOS: + return "T_GPOS"; + case T_AAAA: + return "T_AAAA"; + case T_LOC: + return "T_LOC"; + case T_NXT: + return "T_NXT"; + case T_EID: + return "T_EID"; + case T_NIMLOC: + return "T_NIMLOC"; + case T_SRV: + return "T_SRV"; + case T_ATMA: + return "T_ATMA"; + case T_NAPTR: + return "T_NAPTR"; + case T_KX: + return "T_KX"; + case T_CERT: + return "T_CERT"; + case T_A6: + return "T_A6"; + case T_DNAME: + return "T_DNAME"; + case T_SINK: + return "T_SINK"; + case T_OPT: + return "T_OPT"; + case T_APL: + return "T_APL"; + case T_DS: + return "T_DS"; + case T_SSHFP: + return "T_SSHFP"; + case T_RRSIG: + return "T_RRSIG"; + case T_NSEC: + return "T_NSEC"; + case T_DNSKEY: + return "T_DNSKEY"; + case T_TKEY: + return "T_TKEY"; + case T_TSIG: + return "T_TSIG"; + case T_IXFR: + return "T_IXFR"; + case T_AXFR: + return "T_AXFR"; + case T_MAILB: + return "T_MAILB"; + case T_MAILA: + return "T_MAILA"; + case T_ANY: + return "T_ANY"; + case T_URI: + return "T_URI"; + case T_CAA: + return "T_CAA"; + case T_MAX: + return "T_MAX"; + default: + return ""; + } + } + namespace zeek::detail { static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result); @@ -90,6 +214,10 @@ uint16_t DNS_Request::request_id = 0; DNS_Request::DNS_Request(std::string host, int request_type, bool async) : host(std::move(host)), request_type(request_type), async(async) { + // We combine the T_A and T_AAAA requests together in one request, so set the type + // to T_A to make things easier in other parts of the code (mostly around lookups). + if ( request_type == T_AAAA ) + request_type = T_A; } DNS_Request::DNS_Request(const IPAddr& addr, bool async) : addr(addr), async(async) @@ -113,7 +241,7 @@ void DNS_Request::MakeRequest(ares_channel channel, DNS_Mgr* mgr) // all of them would be in flight at the same time. DNS_Request::request_id++; - if ( request_type == T_A || request_type == T_AAAA ) + if ( request_type == T_A ) { // For A/AAAA requests, we use a different method than the other requests. Since // we're using the AF_UNSPEC family, we get both the ipv4 and ipv6 responses @@ -149,7 +277,7 @@ void DNS_Request::ProcessAsyncResult(bool timed_out, DNS_Mgr* mgr) if ( ! async ) return; - if ( request_type == T_A || request_type == T_AAAA ) + if ( request_type == T_A ) mgr->CheckAsyncHostRequest(host, timed_out); else if ( request_type == T_PTR ) mgr->CheckAsyncAddrRequest(addr, timed_out); @@ -385,7 +513,8 @@ static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, in } default: - reporter->Error("Requests of type %d are unsupported", req->RequestType()); + reporter->Error("Requests of type %d (%s) are unsupported", req->RequestType(), + request_type_string(req->RequestType())); break; } } @@ -531,9 +660,9 @@ static TableValPtr fake_name_lookup_result(const std::string& name) return hv->ToSetVal(); } -static std::string fake_text_lookup_result(const std::string name) +static std::string fake_lookup_result(const std::string& name, int request_type) { - return util::fmt("fake_text_lookup_result_%s", name.c_str()); + return util::fmt("fake_lookup_result_%s_%s", request_type_string(request_type), name.c_str()); } static std::string fake_addr_lookup_result(const IPAddr& addr) @@ -558,14 +687,14 @@ ValPtr DNS_Mgr::Lookup(const std::string& name, int request_type) if ( request_type == T_A || request_type == T_AAAA ) return LookupHost(name); - if ( mode == DNS_FAKE && request_type == T_TXT ) - return make_intrusive(fake_text_lookup_result(name)); + if ( mode == DNS_FAKE ) + return make_intrusive(fake_lookup_result(name, request_type)); InitSource(); - if ( mode != DNS_PRIME && request_type == T_TXT ) + if ( mode != DNS_PRIME ) { - if ( auto val = LookupTextInCache(name, false) ) + if ( auto val = LookupOtherInCache(name, request_type, false) ) return val; } @@ -579,8 +708,8 @@ ValPtr DNS_Mgr::Lookup(const std::string& name, int request_type) } case DNS_FORCE: - reporter->FatalError("can't find DNS entry for %s (req type %d) in cache", name.c_str(), - request_type); + reporter->FatalError("can't find DNS entry for %s (req type %d / %s) in cache", + name.c_str(), request_type, request_type_string(request_type)); return nullptr; case DNS_DEFAULT: @@ -783,12 +912,12 @@ void DNS_Mgr::Lookup(const std::string& name, int request_type, LookupCallback* if ( mode == DNS_FAKE ) { - resolve_lookup_cb(callback, fake_text_lookup_result(name)); + resolve_lookup_cb(callback, fake_lookup_result(name, request_type)); return; } // Do we already know the answer? - if ( auto txt = LookupTextInCache(name, true) ) + if ( auto txt = LookupOtherInCache(name, request_type, true) ) { resolve_lookup_cb(callback, txt->CheckString()); return; @@ -885,98 +1014,50 @@ void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool m DNS_Mapping* prev_mapping = nullptr; bool keep_prev = true; - if ( ! dr->Host().empty() ) - { - new_mapping = new DNS_Mapping(dr->Host(), h, ttl); - - if ( dr->IsTxt() ) - { - TextMap::iterator it = text_mappings.find(dr->Host()); - - if ( it == text_mappings.end() ) - { - auto result = text_mappings.emplace(dr->Host(), new_mapping); - it = result.first; - } - else - prev_mapping = it->second; - - if ( prev_mapping && prev_mapping->Valid() ) - { - if ( new_mapping->Valid() ) - { - if ( merge ) - new_mapping->Merge(prev_mapping); - - it->second = new_mapping; - keep_prev = false; - } - } - else - { - it->second = new_mapping; - keep_prev = false; - } - } - else - { - HostMap::iterator it = host_mappings.find(dr->Host()); - if ( it == host_mappings.end() ) - { - auto result = host_mappings.emplace(dr->Host(), new_mapping); - it = result.first; - } - else - prev_mapping = it->second; - - if ( prev_mapping && prev_mapping->Valid() ) - { - if ( new_mapping->Valid() ) - { - if ( merge ) - new_mapping->Merge(prev_mapping); - - it->second = new_mapping; - keep_prev = false; - } - } - else - { - it->second = new_mapping; - keep_prev = false; - } - } - } - else + MappingMap::iterator it; + if ( dr->RequestType() == T_PTR ) { new_mapping = new DNS_Mapping(dr->Addr(), h, ttl); - - AddrMap::iterator it = addr_mappings.find(dr->Addr()); - if ( it == addr_mappings.end() ) + it = all_mappings.find(dr->Addr()); + if ( it == all_mappings.end() ) { - auto result = addr_mappings.emplace(dr->Addr(), new_mapping); + auto result = all_mappings.emplace(dr->Addr(), new_mapping); it = result.first; } else prev_mapping = it->second; + } + else + { + new_mapping = new DNS_Mapping(dr->Host(), h, ttl, dr->RequestType()); + auto key = std::make_pair(dr->RequestType(), dr->Host()); - if ( prev_mapping && prev_mapping->Valid() ) + it = all_mappings.find(key); + if ( it == all_mappings.end() ) { - if ( new_mapping->Valid() ) - { - if ( merge ) - new_mapping->Merge(prev_mapping); - - it->second = new_mapping; - keep_prev = false; - } + auto result = all_mappings.emplace(std::move(key), new_mapping); + it = result.first; } else + prev_mapping = it->second; + } + + if ( prev_mapping && prev_mapping->Valid() ) + { + if ( new_mapping->Valid() ) { + if ( merge ) + new_mapping->Merge(prev_mapping); + it->second = new_mapping; keep_prev = false; } } + else + { + it->second = new_mapping; + keep_prev = false; + } if ( prev_mapping && ! dr->IsTxt() ) CompareMappings(prev_mapping, new_mapping); @@ -1065,14 +1146,17 @@ void DNS_Mgr::LoadCache(const std::string& path) if ( ! f ) return; + if ( ! DNS_Mapping::ValidateCacheVersion(f) ) + return; + // Loop until we find a mapping that doesn't initialize correctly. DNS_Mapping* m = new DNS_Mapping(f); for ( ; ! m->NoMapping() && ! m->InitFailed(); m = new DNS_Mapping(f) ) { if ( m->ReqHost() ) - host_mappings.insert_or_assign(m->ReqHost(), m); + all_mappings.insert_or_assign(std::make_pair(m->ReqType(), m->ReqHost()), m); else - addr_mappings.insert_or_assign(m->ReqAddr(), m); + all_mappings.insert_or_assign(m->ReqAddr(), m); } if ( ! m->NoMapping() ) @@ -1092,38 +1176,28 @@ bool DNS_Mgr::Save() if ( ! f ) return false; - Save(f, host_mappings); - Save(f, addr_mappings); - // Save(f, text_mappings); // We don't save the TXT mappings (yet?). + DNS_Mapping::InitializeCache(f); + Save(f, all_mappings); fclose(f); return true; } -void DNS_Mgr::Save(FILE* f, const AddrMap& m) +void DNS_Mgr::Save(FILE* f, const MappingMap& m) { - for ( AddrMap::const_iterator it = m.begin(); it != m.end(); ++it ) + for ( const auto& [key, mapping] : m ) { - if ( it->second ) - it->second->Save(f); - } - } - -void DNS_Mgr::Save(FILE* f, const HostMap& m) - { - for ( HostMap::const_iterator it = m.begin(); it != m.end(); ++it ) - { - if ( it->second ) - it->second->Save(f); + if ( mapping ) + mapping->Save(f); } } TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_expired, bool check_failed) { - HostMap::iterator it = host_mappings.find(name); - if ( it == host_mappings.end() ) + auto it = all_mappings.find(std::make_pair(T_A, name)); + if ( it == all_mappings.end() ) return nullptr; DNS_Mapping* d = it->second; @@ -1133,7 +1207,7 @@ TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_exp if ( cleanup_expired && (d && d->Expired()) ) { - host_mappings.erase(it); + all_mappings.erase(it); delete d; return nullptr; } @@ -1149,15 +1223,15 @@ TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_exp StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired, bool check_failed) { - AddrMap::iterator it = addr_mappings.find(addr); - if ( it == addr_mappings.end() ) + auto it = all_mappings.find(addr); + if ( it == all_mappings.end() ) return nullptr; DNS_Mapping* d = it->second; if ( cleanup_expired && d->Expired() ) { - addr_mappings.erase(it); + all_mappings.erase(it); delete d; return nullptr; } @@ -1174,17 +1248,18 @@ StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired return make_intrusive("<\?\?\?>"); } -StringValPtr DNS_Mgr::LookupTextInCache(const std::string& name, bool cleanup_expired) +StringValPtr DNS_Mgr::LookupOtherInCache(const std::string& name, int request_type, + bool cleanup_expired) { - TextMap::iterator it = text_mappings.find(name); - if ( it == text_mappings.end() ) + auto it = all_mappings.find(std::make_pair(request_type, name)); + if ( it == all_mappings.end() ) return nullptr; DNS_Mapping* d = it->second; if ( cleanup_expired && d->Expired() ) { - text_mappings.erase(it); + all_mappings.erase(it); delete d; return nullptr; } @@ -1291,7 +1366,7 @@ void DNS_Mgr::CheckAsyncTextRequest(const std::string& host, bool timeout) ++failed; i->second->Timeout(); } - else if ( auto name = LookupTextInCache(host, true) ) + else if ( auto name = LookupOtherInCache(host, T_TXT, true) ) { ++successful; i->second->Resolved(name->CheckString()); @@ -1309,18 +1384,10 @@ void DNS_Mgr::Flush() { Resolve(); - for ( HostMap::iterator it = host_mappings.begin(); it != host_mappings.end(); ++it ) - delete it->second; + for ( auto& [key, mapping] : all_mappings ) + delete mapping; - for ( AddrMap::iterator it2 = addr_mappings.begin(); it2 != addr_mappings.end(); ++it2 ) - delete it2->second; - - for ( TextMap::iterator it3 = text_mappings.begin(); it3 != text_mappings.end(); ++it3 ) - delete it3->second; - - host_mappings.clear(); - addr_mappings.clear(); - text_mappings.clear(); + all_mappings.clear(); } double DNS_Mgr::GetNextTimeout() @@ -1357,9 +1424,20 @@ void DNS_Mgr::GetStats(Stats* stats) stats->successful = successful; stats->failed = failed; stats->pending = asyncs_pending; - stats->cached_hosts = host_mappings.size(); - stats->cached_addresses = addr_mappings.size(); - stats->cached_texts = text_mappings.size(); + + stats->cached_hosts = 0; + stats->cached_addresses = 0; + stats->cached_texts = 0; + + for ( const auto& [key, mapping] : all_mappings ) + { + if ( mapping->ReqType() == T_PTR ) + stats->cached_addresses++; + else if ( mapping->ReqType() == T_A ) + stats->cached_hosts++; + else + stats->cached_texts++; + } } void DNS_Mgr::AsyncRequest::Resolved(const std::string& name) diff --git a/src/DNS_Mgr.h b/src/DNS_Mgr.h index e3472218c8..f2f7d26e6d 100644 --- a/src/DNS_Mgr.h +++ b/src/DNS_Mgr.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "zeek/EventHandler.h" #include "zeek/IPAddr.h" @@ -248,7 +249,8 @@ protected: bool check_failed = false); TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, bool check_failed = false); - StringValPtr LookupTextInCache(const std::string& name, bool cleanup_expired = false); + StringValPtr LookupOtherInCache(const std::string& name, int request_type, + bool cleanup_expired = false); // Finish the request if we have a result. If not, time it out if // requested. @@ -265,12 +267,10 @@ protected: void CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm); ListValPtr AddrListDelta(ListVal* al1, ListVal* al2); - using HostMap = std::map; - using AddrMap = std::map; - using TextMap = std::map; + using MappingKey = std::variant>; + using MappingMap = std::map; void LoadCache(const std::string& path); - void Save(FILE* f, const AddrMap& m); - void Save(FILE* f, const HostMap& m); + void Save(FILE* f, const MappingMap& m); // Issue as many queued async requests as slots are available. void IssueAsyncRequests(); @@ -283,9 +283,7 @@ protected: DNS_MgrMode mode; - HostMap host_mappings; - AddrMap addr_mappings; - TextMap text_mappings; + MappingMap all_mappings; std::string cache_name; std::string dir; // directory in which cache_name resides diff --git a/testing/btest/Baseline/core.fake_dns/out b/testing/btest/Baseline/core.fake_dns/out index bc04a9ec86..45539d8098 100644 --- a/testing/btest/Baseline/core.fake_dns/out +++ b/testing/btest/Baseline/core.fake_dns/out @@ -4,7 +4,7 @@ 7a5f:b783:9808:380e:b1a2:ce20:b58e:2a4a, 4cc7:de52:d869:b2f9:f215:19b8:c828:3bdd } -lookup_hostname_txt, fake_text_lookup_result_bro.wp.dg.cx +lookup_hostname_txt, fake_lookup_result_T_TXT_bro.wp.dg.cx lookup_hostname, { ce06:236:f21f:587:8c10:121d:c47d:b412 }