diff --git a/src/DNS_Mapping.cc b/src/DNS_Mapping.cc index 1e98788904..73c804bae1 100644 --- a/src/DNS_Mapping.cc +++ b/src/DNS_Mapping.cc @@ -6,14 +6,14 @@ namespace zeek::detail { -DNS_Mapping::DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl) +DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl) { Init(h); req_host = host; req_ttl = ttl; if ( names.empty() ) - names.push_back(host); + names.push_back(std::move(host)); } DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) @@ -46,7 +46,10 @@ DNS_Mapping::DNS_Mapping(FILE* f) if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, &failed_local, name_buf, &map_type, &num_addrs, &req_ttl) != 8 ) + { + no_mapping = true; return; + } failed = static_cast(failed_local); @@ -129,12 +132,15 @@ void DNS_Mapping::Init(struct hostent* h) // TODO: this could easily be expanded to include all of the aliases as well names.push_back(h->h_name); - for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) + if ( h->h_addr_list ) { - if ( h->h_addrtype == AF_INET ) - addrs.push_back(IPAddr(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network)); - else if ( h->h_addrtype == AF_INET6 ) - addrs.push_back(IPAddr(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network)); + for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) + { + if ( h->h_addrtype == AF_INET ) + addrs.push_back(IPAddr(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network)); + else if ( h->h_addrtype == AF_INET6 ) + addrs.push_back(IPAddr(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network)); + } } failed = false; @@ -167,7 +173,7 @@ void DNS_Mapping::Save(FILE* f) const TEST_CASE("dns_mapping init null hostent") { - DNS_Mapping mapping("www.apple.com", nullptr, 123); + DNS_Mapping mapping(std::string("www.apple.com"), nullptr, 123); CHECK(! mapping.Valid()); CHECK(mapping.Addrs() == nullptr); @@ -190,7 +196,7 @@ TEST_CASE("dns_mapping init host") std::vector addrs = {&in4, NULL}; he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping("testing.home", &he, 123); + DNS_Mapping mapping(std::string("testing.home"), &he, 123); CHECK(mapping.Valid()); CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified); CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0); @@ -335,7 +341,7 @@ TEST_CASE("dns_mapping multiple addresses") std::vector addrs = {&in4_1, &in4_2, NULL}; he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping("testing.home", &he, 123); + DNS_Mapping mapping(std::string("testing.home"), &he, 123); CHECK(mapping.Valid()); auto lva = mapping.Addrs(); diff --git a/src/DNS_Mapping.h b/src/DNS_Mapping.h index ef56ade2b3..34a3aff7d5 100644 --- a/src/DNS_Mapping.h +++ b/src/DNS_Mapping.h @@ -15,7 +15,7 @@ class DNS_Mapping { public: DNS_Mapping() = delete; - DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl); + DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl); DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); DNS_Mapping(FILE* f); @@ -35,6 +35,7 @@ public: StringValPtr Host(); double CreationTime() const { return creation_time; } + uint32_t TTL() const { return req_ttl; } void Save(FILE* f) const; diff --git a/src/DNS_Mgr.cc b/src/DNS_Mgr.cc index 2e240a3bf2..0165808a3a 100644 --- a/src/DNS_Mgr.cc +++ b/src/DNS_Mgr.cc @@ -5,6 +5,7 @@ #include "zeek/zeek-config.h" #include +#include #include #include #include @@ -26,6 +27,7 @@ #include #include +#include #include "zeek/3rdparty/doctest.h" #include "zeek/DNS_Mapping.h" @@ -44,128 +46,306 @@ // Number of seconds we'll wait for a reply. constexpr int DNS_TIMEOUT = 5; +// The maximum allowed number of pending asynchronous requests. +constexpr int MAX_PENDING_REQUESTS = 20; + namespace zeek::detail { -static void hostbyaddr_callback(void* arg, int status, int timeouts, struct hostent* hostent) +static void hostbyaddr_cb(void* arg, int status, int timeouts, struct hostent* hostent); +static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result); +static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, int len); +static void sock_cb(void* data, int s, int read, int write); + +class DNS_Request + { +public: + DNS_Request(std::string host, int af, int request_type, bool async = false); + explicit DNS_Request(const IPAddr& addr, bool async = false); + ~DNS_Request(); + + std::string Host() const { return host; } + const IPAddr& Addr() const { return addr; } + int Family() const { return family; } + int RequestType() const { return request_type; } + bool IsTxt() const { return request_type == 16; } + + void MakeRequest(ares_channel channel); + void ProcessAsyncResult(bool timed_out); + +private: + std::string host; + IPAddr addr; + int family = 0; // address family query type for host requests + int request_type = 0; // Query type + bool async = false; + unsigned char* query = nullptr; + static uint16_t request_id; + }; + +uint16_t DNS_Request::request_id = 0; + +DNS_Request::DNS_Request(std::string host, int af, int request_type, bool async) + : host(std::move(host)), family(af), request_type(request_type), async(async) { - printf("host callback\n"); - // TODO: implement this - // TODO: figure out how to get TTL info here } -static void addrinfo_callback(void* arg, int status, int timeouts, struct ares_addrinfo* result) +DNS_Request::DNS_Request(const IPAddr& addr, bool async) : addr(addr), async(async) { - printf("addrinfo callback\n"); + // TODO: AF_UNSPEC for T_PTR requests? + family = addr.GetFamily() == IPv4 ? AF_INET : AF_INET6; + request_type = T_PTR; + } + +DNS_Request::~DNS_Request() + { + if ( query ) + ares_free_string(query); + } + +void DNS_Request::MakeRequest(ares_channel channel) + { + // It's completely fine if this rolls over. It's just to keep the query ID different + // from one query to the next, and it's unlikely we'd do 2^16 queries so fast that + // all of them would be in flight at the same time. + DNS_Request::request_id++; + + // TODO: how the heck do file lookups work? gethostbyname_file exists but gethostbyaddr_file + // doesn't. But then the code in ares_gethostbyaddr.c can switch on the setting in the channel + // for whether we should look at the file or not. If we don't care about file lookups at all, + // the T_PTR case below can be simplified and moved down into the else block. + + // We do normal host and address lookups via the specialized methods for them + // because those will attempt to do file lookups as well internally before + // reaching out to the DNS server. The remaining lookup types all use + // ares_create_query() and ares_send() for more genericness. + if ( request_type == T_A || request_type == T_AAAA ) + { + // TODO: gethostbyname_file? + // Use getaddrinfo here because it gives us the ttl information. If we don't + // care about TTL, we could use gethostbyname instead. + ares_addrinfo_hints hints = {ARES_AI_CANONNAME, family, 0, 0}; + ares_getaddrinfo(channel, host.c_str(), NULL, &hints, addrinfo_cb, this); + } + else if ( request_type == T_PTR ) + { + if ( addr.GetFamily() == IPv4 ) + { + struct sockaddr_in sa; + inet_pton(AF_INET, addr.AsString().c_str(), &(sa.sin_addr)); + ares_gethostbyaddr(channel, &sa.sin_addr, sizeof(sa.sin_addr), AF_INET, hostbyaddr_cb, + this); + } + else + { + struct sockaddr_in6 sa; + inet_pton(AF_INET6, addr.AsString().c_str(), &(sa.sin6_addr)); + ares_gethostbyaddr(channel, &sa.sin6_addr, sizeof(sa.sin6_addr), AF_INET6, + hostbyaddr_cb, this); + } + } + else + { + unsigned char* query = NULL; + int len = 0; + int status = ares_create_query(host.c_str(), C_IN, request_type, DNS_Request::request_id, 1, + &query, &len, 0); + if ( status != ARES_SUCCESS ) + return; + + // Store this so it can be destroyed when the request is destroyed. + this->query = query; + ares_send(channel, query, len, query_cb, this); + } + } + +void DNS_Request::ProcessAsyncResult(bool timed_out) + { + if ( ! async ) + return; + + if ( request_type == T_A || request_type == T_AAAA ) + dns_mgr->CheckAsyncHostRequest(host, timed_out); + else if ( request_type == T_PTR ) + dns_mgr->CheckAsyncAddrRequest(addr, timed_out); + else if ( request_type == T_TXT ) + dns_mgr->CheckAsyncTextRequest(host, timed_out); + } + +/** + * Called in response to ares_gethostbyaddr requests. Sends the hostent data to the + * DNS manager via AddResult(). + */ +static void hostbyaddr_cb(void* arg, int status, int timeouts, struct hostent* host) + { + auto req = reinterpret_cast(arg); + + if ( ! host || status != ARES_SUCCESS ) + { + printf("Failed hostbyaddr request: %s\n", ares_strerror(status)); + // TODO: pass DNS_TIMEOUT for the TTL here just so things work for testing. This + // will absolutely need to get the data from the request somehow instead. See + // https://github.com/c-ares/c-ares/issues/387. + dns_mgr->AddResult(req, nullptr, DNS_TIMEOUT); + } + else + { + // TODO: pass DNS_TIMEOUT for the TTL here just so things work for testing. This + // will absolutely need to get the data from the request somehow instead. See + // https://github.com/c-ares/c-ares/issues/387. + dns_mgr->AddResult(req, host, DNS_TIMEOUT); + } + + req->ProcessAsyncResult(timeouts > 0); + } + +/** + * Called in response to ares_getaddrinfo requests. Builds a hostent structure from + * the result data and sends it to the DNS manager via Addresult(). + */ +static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result) + { + auto req = reinterpret_cast(arg); if ( status != ARES_SUCCESS ) { - // TODO: error or something here, or just give up on it? - ares_freeaddrinfo(result); - return; + // TODO: reporter warning or something here, or just give up on it? + printf("Failed addrinfo request: %s", ares_strerror(status)); + dns_mgr->AddResult(req, nullptr, 0); + } + else + { + std::vector addrs; + std::vector addrs6; + for ( ares_addrinfo_node* entry = result->nodes; entry != NULL; entry = entry->ai_next ) + { + if ( entry->ai_family == AF_INET ) + { + struct sockaddr_in* addr = reinterpret_cast(entry->ai_addr); + addrs.push_back(&addr->sin_addr); + } + else if ( entry->ai_family == AF_INET6 ) + { + struct sockaddr_in6* addr = (struct sockaddr_in6*)(entry->ai_addr); + addrs6.push_back(&addr->sin6_addr); + } + } + + if ( ! addrs.empty() ) + { + // Push a null on the end so the addr list has a final point during later parsing. + addrs.push_back(NULL); + + struct hostent he; + memset(&he, 0, sizeof(struct hostent)); + he.h_name = util::copy_string(result->name); + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); + he.h_addr_list = reinterpret_cast(addrs.data()); + + dns_mgr->AddResult(req, &he, result->nodes[0].ai_ttl); + + delete[] he.h_name; + } + + // TODO: We can't do this here because we blow up the mapping added above by doing so. + // We need some sort of "merge mapping" mode in AddResult for this to work to add new + // IPs to an existing mapping. + /* + if ( ! addrs6.empty() ) + { + // Push a null on the end so the addr list has a final point during later parsing. + addrs6.push_back(NULL); + + struct hostent he; + memset(&he, 0, sizeof(struct hostent)); + he.h_name = util::copy_string(result->name); + he.h_addrtype = AF_INET6; + he.h_length = sizeof(in6_addr); + he.h_addr_list = reinterpret_cast(addrs6.data()); + + dns_mgr->AddResult(req, &he, result->nodes[0].ai_ttl); + + delete[] he.h_name; + } + */ } - // TODO: the existing code doesn't handle hostname aliases at all. Should we? - // TODO: handle IPv6 mode - - std::vector addrs; - for ( ares_addrinfo_node* entry = result->nodes; entry != NULL; entry = entry->ai_next ) - addrs.push_back(&reinterpret_cast(entry->ai_addr)->sin_addr); - - // Push a null on the end so the addr list has a final point during later parsing. - addrs.push_back(NULL); - - struct hostent he; - he.h_name = util::copy_string(result->name); - he.h_aliases = NULL; - he.h_addrtype = AF_INET; - he.h_length = sizeof(in_addr); - he.h_addr_list = reinterpret_cast(addrs.data()); - - auto req = reinterpret_cast(arg); - dns_mgr->AddResult(req, &he, result->nodes[0].ai_ttl); - - delete[] he.h_name; + req->ProcessAsyncResult(timeouts > 0); ares_freeaddrinfo(result); } -static void ares_sock_cb(void* data, int s, int read, int write) +/** + * Called in response to all other query types. + */ +static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, int len) { - printf("Change state fd %d read:%d write:%d\n", s, read, write); - if ( read == 1 ) - iosource_mgr->RegisterFd(s, reinterpret_cast(data)); - else - iosource_mgr->UnregisterFd(s, reinterpret_cast(data)); - } + auto req = reinterpret_cast(arg); -class DNS_Mgr_Request - { -public: - DNS_Mgr_Request(const char* h, int af, bool is_txt) - : host(util::copy_string(h)), fam(af), qtype(is_txt ? 16 : 0), addr() + if ( status != ARES_SUCCESS ) { - } - - DNS_Mgr_Request(const IPAddr& a) : addr(a) { } - - ~DNS_Mgr_Request() { delete[] host; } - - // Returns nil if this was an address request. - const char* ReqHost() const { return host; } - const IPAddr& ReqAddr() const { return addr; } - int Family() const { return fam; } - bool ReqIsTxt() const { return qtype == 16; } - - void MakeRequest(ares_channel channel); - - bool RequestPending() const { return request_pending; } - void RequestDone() { request_pending = false; } - -protected: - char* host = nullptr; // if non-nil, this is a host request - int fam = 0; // address family query type for host requests - int qtype = 0; // Query type - IPAddr addr; - bool request_pending = false; - }; - -void DNS_Mgr_Request::MakeRequest(ares_channel channel) - { - request_pending = true; - - // TODO: TXT requests? - // TODO: could this use ares_create_query/ares_query instead of the - // ares_get* methods to make it more generic? I think we might need - // to do that for TXT requests. - - if ( host ) - { - ares_addrinfo_hints hints = {ARES_AI_CANONNAME, fam, 0, 0}; - ares_getaddrinfo(channel, host, NULL, &hints, addrinfo_callback, this); + // TODO: reporter warning or something here, or just give up on it? + // TODO: what should we send to AddResult if we didn't get an answer back? + // struct hostent he; + // memset(&he, 0, sizeof(struct hostent)); + // dns_mgr->AddResult(req, &he, 0); } else { - const uint32_t* bytes; - int len = addr.GetBytes(&bytes); + switch ( req->RequestType() ) + { + case T_TXT: + { + struct ares_txt_reply* reply; + int r = ares_parse_txt_reply(buf, len, &reply); + if ( r == ARES_SUCCESS ) + { + // Use a hostent to send the data into AddResult(). We only care about + // setting the host field, but everything else should be zero just for + // safety. - ares_gethostbyaddr(channel, bytes, len, addr.GetFamily() == IPv4 ? AF_INET : AF_INET6, - hostbyaddr_callback, this); + // We don't currently handle more than the first response, and throw the + // rest away. There really isn't a good reason for this, we just haven't + // ever done so. It would likely require some changes to the output from + // Lookup(), since right now it only returns one value. + struct hostent he; + memset(&he, 0, sizeof(struct hostent)); + he.h_name = util::copy_string(reinterpret_cast(reply->txt)); + + // TODO: pass DNS_TIMEOUT for the TTL here just so things work for + // testing. This will absolutely need to get the data from the request + // somehow instead. See https://github.com/c-ares/c-ares/issues/387. + dns_mgr->AddResult(req, &he, DNS_TIMEOUT); + + ares_free_data(reply); + } + + break; + } + + default: + reporter->Error("Requests of type %d are unsupported", req->RequestType()); + break; + } } + + req->ProcessAsyncResult(timeouts > 0); } -DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) +/** + * Called when the c-ares socket changes state, whcih indicates that it's connected to + * some source of data (either a host file or a DNS server). This indicates that we're + * able to do lookups against c-ares now and should activate the IOSource. + */ +static void sock_cb(void* data, int s, int read, int write) { - did_init = false; - - mode = arg_mode; - - asyncs_pending = 0; - num_requests = 0; - successful = 0; - failed = 0; - ipv6_resolver = false; + auto mgr = reinterpret_cast(data); + mgr->RegisterSocket(s, read == 1); + } +DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) : mode(arg_mode) + { ares_library_init(ARES_LIB_INIT_ALL); } @@ -178,6 +358,20 @@ DNS_Mgr::~DNS_Mgr() ares_library_cleanup(); } +void DNS_Mgr::RegisterSocket(int fd, bool active) + { + if ( active && socket_fds.count(fd) == 0 ) + { + socket_fds.insert(fd); + iosource_mgr->RegisterFd(fd, this); + } + else if ( ! active && socket_fds.count(fd) != 0 ) + { + socket_fds.erase(fd); + iosource_mgr->UnregisterFd(fd, this); + } + } + void DNS_Mgr::InitSource() { if ( did_init ) @@ -186,13 +380,22 @@ void DNS_Mgr::InitSource() ares_options options; int optmask = 0; + // Don't close the socket for the server even if we have no active + // requests. options.flags = ARES_FLAG_STAYOPEN; optmask |= ARES_OPT_FLAGS; - options.timeout = DNS_TIMEOUT; - optmask |= ARES_OPT_TIMEOUT; + // This option is in milliseconds. + options.timeout = DNS_TIMEOUT * 1000; + optmask |= ARES_OPT_TIMEOUTMS; - options.sock_state_cb = ares_sock_cb; + // This causes c-ares to only attempt each server twice before + // giving up. + options.tries = 2; + optmask |= ARES_OPT_TRIES; + + // See the comment on sock_cb for how this gets used. + options.sock_state_cb = sock_cb; options.sock_state_cb_data = this; optmask |= ARES_OPT_SOCK_STATE_CB; @@ -210,19 +413,15 @@ void DNS_Mgr::InitSource() if ( dns_resolver ) { ares_addr_node servers; - servers.next = nullptr; + servers.next = NULL; auto dns_resolver_addr = IPAddr(dns_resolver); struct sockaddr_storage ss = {0}; if ( dns_resolver_addr.GetFamily() == IPv4 ) { - struct sockaddr_in* sa = (struct sockaddr_in*)&ss; - sa->sin_family = AF_INET; - dns_resolver_addr.CopyIPv4(&sa->sin_addr); - servers.family = AF_INET; - memcpy(&(servers.addr.addr4), &sa->sin_addr, sizeof(struct in_addr)); + dns_resolver_addr.CopyIPv4(&(servers.addr.addr4)); } else { @@ -256,102 +455,126 @@ void DNS_Mgr::InitPostScript() } // Load the DNS cache from disk, if it exists. - std::string cache_dir = dir.empty() ? dir : "."; + std::string cache_dir = dir.empty() ? "." : dir; cache_name = util::fmt("%s/%s", cache_dir.c_str(), ".zeek-dns-cache"); LoadCache(cache_name); } -static TableValPtr fake_name_lookup_result(const char* name) +static TableValPtr fake_name_lookup_result(const std::string& name) { hash128_t hash; - KeyedHash::StaticHash128(name, strlen(name), &hash); + KeyedHash::StaticHash128(name.c_str(), name.size(), &hash); auto hv = make_intrusive(TYPE_ADDR); hv->Append(make_intrusive(reinterpret_cast(&hash))); return hv->ToSetVal(); } -static const char* fake_text_lookup_result(const char* name) +static std::string fake_text_lookup_result(const std::string name) { - static char tmp[32 + 256]; - snprintf(tmp, sizeof(tmp), "fake_text_lookup_result_%s", name); - return tmp; + return util::fmt("fake_text_lookup_result_%s", name.c_str()); } -static const char* fake_addr_lookup_result(const IPAddr& addr) +static std::string fake_addr_lookup_result(const IPAddr& addr) { - static char tmp[128]; - snprintf(tmp, sizeof(tmp), "fake_addr_lookup_result_%s", addr.AsString().c_str()); - return tmp; + return util::fmt("fake_addr_lookup_result_%s", addr.AsString().c_str()); } -TableValPtr DNS_Mgr::LookupHost(const char* name) +static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, TableValPtr result) + { + callback->Resolved(std::move(result)); + delete callback; + } + +static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, const std::string& result) + { + callback->Resolved(result); + delete callback; + } + +ValPtr DNS_Mgr::Lookup(const std::string& name, int request_type) + { + if ( request_type == T_A || request_type == T_AAAA ) + return LookupHost(name); + + if ( mode == DNS_FAKE && request_type == T_TXT ) + return make_intrusive(fake_text_lookup_result(name)); + + InitSource(); + + if ( mode != DNS_PRIME && request_type == T_TXT ) + { + if ( auto val = LookupTextInCache(name, false) ) + return val; + } + + switch ( mode ) + { + case DNS_PRIME: + { + auto req = new DNS_Request(name, AF_UNSPEC, request_type); + req->MakeRequest(channel); + return empty_addr_set(); + } + + case DNS_FORCE: + reporter->FatalError("can't find DNS entry for %s (req type %d) in cache", name.c_str(), + request_type); + return nullptr; + + case DNS_DEFAULT: + { + auto req = new DNS_Request(name, AF_UNSPEC, request_type); + req->MakeRequest(channel); + Resolve(); + + // Call LookupHost() a second time to get the newly stored value out of the cache. + return Lookup(name, request_type); + } + + default: + reporter->InternalError("bad mode %d in DNS_Mgr::Lookup", mode); + return nullptr; + } + + return nullptr; + } + +TableValPtr DNS_Mgr::LookupHost(const std::string& name) { if ( mode == DNS_FAKE ) return fake_name_lookup_result(name); - // This should have been run already from InitPostScript(), but just run it again just - // in case it hadn't. InitSource(); // Check the cache before attempting to look up the name remotely. if ( mode != DNS_PRIME ) { - HostMap::iterator it = host_mappings.find(name); - - if ( it != host_mappings.end() ) - { - DNS_Mapping* d4 = it->second.first; - DNS_Mapping* d6 = it->second.second; - - if ( (d4 && d4->Failed()) || (d6 && d6->Failed()) ) - { - reporter->Warning("no such host: %s", name); - return empty_addr_set(); - } - else if ( d4 && d6 ) - { - auto tv4 = d4->AddrsSet(); - auto tv6 = d6->AddrsSet(); - tv4->AddTo(tv6.get(), false); - return tv6; - } - } + if ( auto val = LookupNameInCache(name, false, true) ) + return val; } - // Not found, or priming. We use ares_getaddrinfo here because we want the TTL value + // Not found, or priming. switch ( mode ) { case DNS_PRIME: { - // TODO: not sure we need to do these split like this if we can pass AF_UNSPEC - // in the hints structure. Do we really need the two different request objects? - auto v4 = new DNS_Mgr_Request(name, AF_INET, false); - ares_addrinfo_hints v4_hints = {ARES_AI_CANONNAME, AF_INET, 0, 0}; - ares_getaddrinfo(channel, name, NULL, &v4_hints, addrinfo_callback, v4); - - // TODO: check if ipv6 support is needed if we use AF_UNSPEC above - // auto v6 = new DNS_Mgr_Request(name, AF_INET6, false); - // ares_addrinfo_hints v6_hints = { 0, AF_INET6, 0, 0 }; - // ares_getaddrinfo(channel, name, NULL, &v6_hints, addrinfo_callback, v6); - + // We pass T_A here, but because we're passing AF_UNSPEC MakeRequest() will + // have c-ares attempt to lookup both ipv4 and ipv6 at the same time. + auto req = new DNS_Request(name, AF_UNSPEC, T_A); + req->MakeRequest(channel); return empty_addr_set(); } case DNS_FORCE: - reporter->FatalError("can't find DNS entry for %s in cache", name); + reporter->FatalError("can't find DNS entry for %s in cache", name.c_str()); return nullptr; case DNS_DEFAULT: { - auto v4 = new DNS_Mgr_Request(name, AF_INET, false); - ares_addrinfo_hints v4_hints = {ARES_AI_CANONNAME, AF_INET, 0, 0}; - ares_getaddrinfo(channel, name, NULL, &v4_hints, addrinfo_callback, v4); - - // TODO: check if ipv6 support is needed if we use AF_UNSPEC above - // auto v6 = new DNS_Mgr_Request(name, AF_INET6, false); - // ares_addrinfo_hints v6_hints = { 0, AF_INET6, 0, 0 }; - // ares_getaddrinfo(channel, name, NULL, &v6_hints, addrinfo_callback, v6); - + // We pass T_A here, but because we're passing AF_UNSPEC MakeRequest() will + // have c-ares attempt to lookup both ipv4 and ipv6 at the same time. + auto req = new DNS_Request(name, AF_UNSPEC, T_A); + req->MakeRequest(channel); Resolve(); // Call LookupHost() a second time to get the newly stored value out of the cache. @@ -369,40 +592,22 @@ StringValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) if ( mode == DNS_FAKE ) return make_intrusive(fake_addr_lookup_result(addr)); - // This should have been run already from InitPostScript(), but just run it again just - // in case it hadn't. InitSource(); // Check the cache before attempting to look up the name remotely. if ( mode != DNS_PRIME ) { - AddrMap::iterator it = addr_mappings.find(addr); - - if ( it != addr_mappings.end() ) - { - DNS_Mapping* d = it->second; - if ( d->Valid() ) - return d->Host(); - else - { - std::string s(addr); - reporter->Warning("can't resolve IP address: %s", s.c_str()); - return make_intrusive(s.c_str()); - } - } + if ( auto val = LookupAddrInCache(addr, false, true) ) + return val; } - const uint32_t* bytes; - int len = addr.GetBytes(&bytes); - // Not found, or priming. switch ( mode ) { case DNS_PRIME: { - auto req = new DNS_Mgr_Request(addr); - ares_gethostbyaddr(channel, bytes, len, addr.GetFamily() == IPv4 ? AF_INET : AF_INET6, - hostbyaddr_callback, req); + auto req = new DNS_Request(addr); + req->MakeRequest(channel); return make_intrusive(""); } @@ -412,9 +617,8 @@ StringValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) case DNS_DEFAULT: { - auto req = new DNS_Mgr_Request(addr); - ares_gethostbyaddr(channel, bytes, len, addr.GetFamily() == IPv4 ? AF_INET : AF_INET6, - hostbyaddr_callback, req); + auto req = new DNS_Request(addr); + req->MakeRequest(channel); Resolve(); // Call LookupAddr() a second time to get the newly stored value out of the cache. @@ -427,7 +631,129 @@ StringValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) } } -constexpr int MAX_PENDING_REQUESTS = 20; +void DNS_Mgr::LookupHost(const std::string& name, LookupCallback* callback) + { + if ( mode == DNS_FAKE ) + { + resolve_lookup_cb(callback, fake_name_lookup_result(name)); + return; + } + + // Do we already know the answer? + if ( auto addrs = LookupNameInCache(name, true, false) ) + { + resolve_lookup_cb(callback, std::move(addrs)); + return; + } + + AsyncRequest* req = nullptr; + + // If we already have a request waiting for this host, we don't need to make + // another one. We can just add the callback to it and it'll get handled + // when the first request comes back. + AsyncRequestNameMap::iterator i = asyncs_names.find(name); + if ( i != asyncs_names.end() ) + req = i->second; + else + { + // A new one. + req = new AsyncRequest{}; + req->host = name; + asyncs_queued.push_back(req); + asyncs_names.emplace_hint(i, name, req); + } + + req->callbacks.push_back(callback); + + // There may be requests in the queue that haven't been processed yet + // so go ahead and reissue them, even if this method didn't change + // anything. + IssueAsyncRequests(); + } + +void DNS_Mgr::LookupAddr(const IPAddr& host, LookupCallback* callback) + { + if ( mode == DNS_FAKE ) + { + resolve_lookup_cb(callback, fake_addr_lookup_result(host)); + return; + } + + // Do we already know the answer? + if ( auto name = LookupAddrInCache(host, true, false) ) + { + resolve_lookup_cb(callback, name->CheckString()); + return; + } + + AsyncRequest* req = nullptr; + + // If we already have a request waiting for this host, we don't need to make + // another one. We can just add the callback to it and it'll get handled + // when the first request comes back. + AsyncRequestAddrMap::iterator i = asyncs_addrs.find(host); + if ( i != asyncs_addrs.end() ) + req = i->second; + else + { + // A new one. + req = new AsyncRequest{}; + req->addr = host; + asyncs_queued.push_back(req); + asyncs_addrs.emplace_hint(i, host, req); + } + + req->callbacks.push_back(callback); + + // There may be requests in the queue that haven't been processed yet + // so go ahead and reissue them, even if this method didn't change + // anything. + IssueAsyncRequests(); + } + +void DNS_Mgr::Lookup(const std::string& name, int request_type, LookupCallback* callback) + { + if ( request_type == T_A || request_type == T_AAAA ) + { + LookupHost(name, callback); + return; + } + + if ( mode == DNS_FAKE ) + { + resolve_lookup_cb(callback, fake_text_lookup_result(name)); + return; + } + + // Do we already know the answer? + if ( auto txt = LookupTextInCache(name, true) ) + { + resolve_lookup_cb(callback, txt->CheckString()); + return; + } + + AsyncRequest* req = nullptr; + + // If we already have a request waiting for this host, we don't need to make + // another one. We can just add the callback to it and it'll get handled + // when the first request comes back. + AsyncRequestTextMap::iterator i = asyncs_texts.find(name); + if ( i != asyncs_texts.end() ) + req = i->second; + else + { + // A new one. + req = new AsyncRequest{}; + req->host = name; + req->is_txt = true; + asyncs_queued.push_back(req); + asyncs_texts.emplace_hint(i, name, req); + } + + req->callbacks.push_back(callback); + + IssueAsyncRequests(); + } void DNS_Mgr::Resolve() { @@ -454,25 +780,20 @@ void DNS_Mgr::Resolve() void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* dm) { - if ( ! e ) - return; - event_mgr.Enqueue(e, BuildMappingVal(dm)); + if ( e ) + event_mgr.Enqueue(e, BuildMappingVal(dm)); } void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* dm, ListValPtr l1, ListValPtr l2) { - if ( ! e ) - return; - - event_mgr.Enqueue(e, BuildMappingVal(dm), l1->ToSetVal(), l2->ToSetVal()); + if ( e ) + event_mgr.Enqueue(e, BuildMappingVal(dm), l1->ToSetVal(), l2->ToSetVal()); } void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* old_dm, DNS_Mapping* new_dm) { - if ( ! e ) - return; - - event_mgr.Enqueue(e, BuildMappingVal(old_dm), BuildMappingVal(new_dm)); + if ( e ) + event_mgr.Enqueue(e, BuildMappingVal(old_dm), BuildMappingVal(new_dm)); } ValPtr DNS_Mgr::BuildMappingVal(DNS_Mapping* dm) @@ -494,66 +815,70 @@ ValPtr DNS_Mgr::BuildMappingVal(DNS_Mapping* dm) return r; } -void DNS_Mgr::AddResult(DNS_Mgr_Request* dr, struct hostent* h, uint32_t ttl) +void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl) { - DNS_Mapping* new_dm; - DNS_Mapping* prev_dm; + // TODO: the existing code doesn't handle hostname aliases at all. Should we? + + DNS_Mapping* new_mapping; + DNS_Mapping* prev_mapping; bool keep_prev = false; - if ( dr->ReqHost() ) + if ( ! dr->Host().empty() ) { - new_dm = new DNS_Mapping(dr->ReqHost(), h, ttl); - prev_dm = nullptr; + new_mapping = new DNS_Mapping(dr->Host(), h, ttl); + prev_mapping = nullptr; - if ( dr->ReqIsTxt() ) + if ( dr->IsTxt() ) { - TextMap::iterator it = text_mappings.find(dr->ReqHost()); + TextMap::iterator it = text_mappings.find(dr->Host()); if ( it == text_mappings.end() ) - text_mappings[dr->ReqHost()] = new_dm; + text_mappings[dr->Host()] = new_mapping; else { - prev_dm = it->second; - it->second = new_dm; + prev_mapping = it->second; + it->second = new_mapping; } - if ( new_dm->Failed() && prev_dm && prev_dm->Valid() ) + if ( new_mapping->Failed() && prev_mapping && prev_mapping->Valid() ) { - text_mappings[dr->ReqHost()] = prev_dm; + text_mappings[dr->Host()] = prev_mapping; keep_prev = true; } } else { - HostMap::iterator it = host_mappings.find(dr->ReqHost()); + HostMap::iterator it = host_mappings.find(dr->Host()); if ( it == host_mappings.end() ) { - host_mappings[dr->ReqHost()].first = new_dm->Type() == AF_INET ? new_dm : nullptr; + host_mappings[dr->Host()].first = new_mapping->Type() == AF_INET ? new_mapping + : nullptr; - host_mappings[dr->ReqHost()].second = new_dm->Type() == AF_INET ? nullptr : new_dm; + host_mappings[dr->Host()].second = new_mapping->Type() == AF_INET ? nullptr + : new_mapping; } else { - if ( new_dm->Type() == AF_INET ) + if ( new_mapping->Type() == AF_INET ) { - prev_dm = it->second.first; - it->second.first = new_dm; + prev_mapping = it->second.first; + it->second.first = new_mapping; } else { - prev_dm = it->second.second; - it->second.second = new_dm; + prev_mapping = it->second.second; + it->second.second = new_mapping; } } - if ( new_dm->Failed() && prev_dm && prev_dm->Valid() ) + if ( new_mapping->Failed() && prev_mapping && prev_mapping->Valid() ) { // Put previous, valid entry back - CompareMappings // will generate a corresponding warning. - if ( prev_dm->Type() == AF_INET ) - host_mappings[dr->ReqHost()].first = prev_dm; + if ( prev_mapping->Type() == AF_INET ) + host_mappings[dr->Host()].first = prev_mapping; else - host_mappings[dr->ReqHost()].second = prev_dm; + host_mappings[dr->Host()].second = prev_mapping; keep_prev = true; } @@ -561,60 +886,60 @@ void DNS_Mgr::AddResult(DNS_Mgr_Request* dr, struct hostent* h, uint32_t ttl) } else { - new_dm = new DNS_Mapping(dr->ReqAddr(), h, ttl); - AddrMap::iterator it = addr_mappings.find(dr->ReqAddr()); - prev_dm = (it == addr_mappings.end()) ? 0 : it->second; - addr_mappings[dr->ReqAddr()] = new_dm; + new_mapping = new DNS_Mapping(dr->Addr(), h, ttl); + AddrMap::iterator it = addr_mappings.find(dr->Addr()); + prev_mapping = (it == addr_mappings.end()) ? 0 : it->second; + addr_mappings[dr->Addr()] = new_mapping; - if ( new_dm->Failed() && prev_dm && prev_dm->Valid() ) + if ( new_mapping->Failed() && prev_mapping && prev_mapping->Valid() ) { - addr_mappings[dr->ReqAddr()] = prev_dm; + addr_mappings[dr->Addr()] = prev_mapping; keep_prev = true; } } - if ( prev_dm && ! dr->ReqIsTxt() ) - CompareMappings(prev_dm, new_dm); + if ( prev_mapping && ! dr->IsTxt() ) + CompareMappings(prev_mapping, new_mapping); if ( keep_prev ) - delete new_dm; + delete new_mapping; else - delete prev_dm; + delete prev_mapping; } -void DNS_Mgr::CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm) +void DNS_Mgr::CompareMappings(DNS_Mapping* prev_mapping, DNS_Mapping* new_mapping) { - if ( prev_dm->Failed() ) + if ( prev_mapping->Failed() ) { - if ( new_dm->Failed() ) + if ( new_mapping->Failed() ) // Nothing changed. return; - Event(dns_mapping_valid, new_dm); + Event(dns_mapping_valid, new_mapping); return; } - else if ( new_dm->Failed() ) + else if ( new_mapping->Failed() ) { - Event(dns_mapping_unverified, prev_dm); + Event(dns_mapping_unverified, prev_mapping); return; } - auto prev_s = prev_dm->Host(); - auto new_s = new_dm->Host(); + auto prev_s = prev_mapping->Host(); + auto new_s = new_mapping->Host(); if ( prev_s || new_s ) { if ( ! prev_s ) - Event(dns_mapping_new_name, new_dm); + Event(dns_mapping_new_name, new_mapping); else if ( ! new_s ) - Event(dns_mapping_lost_name, prev_dm); + Event(dns_mapping_lost_name, prev_mapping); else if ( ! Bstr_eq(new_s->AsString(), prev_s->AsString()) ) - Event(dns_mapping_name_changed, prev_dm, new_dm); + Event(dns_mapping_name_changed, prev_mapping, new_mapping); } - auto prev_a = prev_dm->Addrs(); - auto new_a = new_dm->Addrs(); + auto prev_a = prev_mapping->Addrs(); + auto new_a = new_mapping->Addrs(); if ( ! prev_a || ! new_a ) { @@ -626,7 +951,7 @@ void DNS_Mgr::CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm) auto new_delta = AddrListDelta(new_a.get(), prev_a.get()); if ( prev_delta->Length() > 0 || new_delta->Length() > 0 ) - Event(dns_mapping_altered, new_dm, std::move(prev_delta), std::move(new_delta)); + Event(dns_mapping_altered, new_mapping, std::move(prev_delta), std::move(new_delta)); } ListValPtr DNS_Mgr::AddrListDelta(ListVal* al1, ListVal* al2) @@ -653,15 +978,6 @@ ListValPtr DNS_Mgr::AddrListDelta(ListVal* al1, ListVal* al2) return delta; } -void DNS_Mgr::DumpAddrList(FILE* f, ListVal* al) - { - for ( int i = 0; i < al->Length(); ++i ) - { - const IPAddr& al_i = al->Idx(i)->AsAddr(); - fprintf(f, "%s%s", i > 0 ? "," : "", al_i.AsString().c_str()); - } - } - void DNS_Mgr::LoadCache(const std::string& path) { FILE* f = fopen(path.c_str(), "r"); @@ -738,43 +1054,20 @@ void DNS_Mgr::Save(FILE* f, const HostMap& m) } } -const char* DNS_Mgr::LookupAddrInCache(const IPAddr& addr) - { - AddrMap::iterator it = addr_mappings.find(addr); - - if ( it == addr_mappings.end() ) - return nullptr; - - DNS_Mapping* d = it->second; - - if ( d->Expired() ) - { - addr_mappings.erase(it); - delete d; - return nullptr; - } - - // The escapes in the following strings are to avoid having it - // interpreted as a trigraph sequence. - return d->names.empty() ? "<\?\?\?>" : d->names[0].c_str(); - } - -TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name) +TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_expired, + bool check_failed) { HostMap::iterator it = host_mappings.find(name); if ( it == host_mappings.end() ) - { - it = host_mappings.begin(); return nullptr; - } DNS_Mapping* d4 = it->second.first; DNS_Mapping* d6 = it->second.second; - if ( ! d4 || d4->names.empty() || ! d6 || d6->names.empty() ) + if ( (! d4 || d4->names.empty()) && (! d6 || d6->names.empty()) ) return nullptr; - if ( d4->Expired() || d6->Expired() ) + if ( cleanup_expired && ((d4 && d4->Expired()) || (d6 && d6->Expired())) ) { host_mappings.erase(it); delete d4; @@ -782,13 +1075,52 @@ TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name) return nullptr; } + if ( check_failed && ((d4 && d4->Failed()) || (d6 && d6->Failed())) ) + { + reporter->Warning("Can't resolve host: %s", name.c_str()); + return empty_addr_set(); + } + auto tv4 = d4->AddrsSet(); - auto tv6 = d6->AddrsSet(); - tv4->AddTo(tv6.get(), false); - return tv6; + + if ( d6 ) + { + auto tv6 = d6->AddrsSet(); + tv4->AddTo(tv6.get(), false); + return tv6; + } + + return tv4; } -const char* DNS_Mgr::LookupTextInCache(const std::string& name) +StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired, bool check_failed) + { + AddrMap::iterator it = addr_mappings.find(addr); + if ( it == addr_mappings.end() ) + return nullptr; + + DNS_Mapping* d = it->second; + + if ( cleanup_expired && d->Expired() ) + { + addr_mappings.erase(it); + delete d; + return nullptr; + } + else if ( check_failed && d->Failed() ) + { + std::string s(addr); + reporter->Warning("can't resolve IP address: %s", s.c_str()); + return make_intrusive(s); + } + + if ( d->Host() ) + return d->Host(); + + return make_intrusive("<\?\?\?>"); + } + +StringValPtr DNS_Mgr::LookupTextInCache(const std::string& name, bool cleanup_expired) { TextMap::iterator it = text_mappings.find(name); if ( it == text_mappings.end() ) @@ -796,164 +1128,24 @@ const char* DNS_Mgr::LookupTextInCache(const std::string& name) DNS_Mapping* d = it->second; - if ( d->Expired() ) + if ( cleanup_expired && d->Expired() ) { text_mappings.erase(it); delete d; return nullptr; } - // The escapes in the following strings are to avoid having it - // interpreted as a trigraph sequence. - return d->names.empty() ? "<\?\?\?>" : d->names[0].c_str(); - } + if ( d->Host() ) + return d->Host(); -static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, TableValPtr result) - { - callback->Resolved(result.get()); - - // Don't delete this if testing because we need it to look at the results of the - // request. It'll get deleted by the test when finished. - if ( ! doctest::is_running_in_test ) - delete callback; - } - -static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, const char* result) - { - callback->Resolved(result); - - // Don't delete this if testing because we need it to look at the results of the - // request. It'll get deleted by the test when finished. - if ( ! doctest::is_running_in_test ) - delete callback; - } - -void DNS_Mgr::AsyncLookupAddr(const IPAddr& host, LookupCallback* callback) - { - // This should have been run already from InitPostScript(), but just run it again just - // in case it hadn't. - InitSource(); - - if ( mode == DNS_FAKE ) - { - resolve_lookup_cb(callback, fake_addr_lookup_result(host)); - return; - } - - // Do we already know the answer? - const char* name = LookupAddrInCache(host); - if ( name ) - { - resolve_lookup_cb(callback, name); - return; - } - - AsyncRequest* req = nullptr; - - // Have we already a request waiting for this host? - AsyncRequestAddrMap::iterator i = asyncs_addrs.find(host); - if ( i != asyncs_addrs.end() ) - req = i->second; - else - { - // A new one. - req = new AsyncRequest; - req->host = host; - asyncs_queued.push_back(req); - asyncs_addrs.insert(AsyncRequestAddrMap::value_type(host, req)); - } - - req->callbacks.push_back(callback); - - IssueAsyncRequests(); - } - -void DNS_Mgr::AsyncLookupName(const std::string& name, LookupCallback* callback) - { - // This should have been run already from InitPostScript(), but just run it again just - // in case it hadn't. - InitSource(); - - if ( mode == DNS_FAKE ) - { - resolve_lookup_cb(callback, fake_name_lookup_result(name.c_str())); - return; - } - - // Do we already know the answer? - auto addrs = LookupNameInCache(name); - if ( addrs ) - { - resolve_lookup_cb(callback, std::move(addrs)); - return; - } - - AsyncRequest* req = nullptr; - - // Have we already a request waiting for this host? - AsyncRequestNameMap::iterator i = asyncs_names.find(name); - if ( i != asyncs_names.end() ) - req = i->second; - else - { - // A new one. - req = new AsyncRequest; - req->name = name; - asyncs_queued.push_back(req); - asyncs_names.insert(AsyncRequestNameMap::value_type(name, req)); - } - - req->callbacks.push_back(callback); - - IssueAsyncRequests(); - } - -void DNS_Mgr::AsyncLookupNameText(const std::string& name, LookupCallback* callback) - { - // This should have been run already from InitPostScript(), but just run it again just - // in case it hadn't. - InitSource(); - - if ( mode == DNS_FAKE ) - { - resolve_lookup_cb(callback, fake_text_lookup_result(name.c_str())); - return; - } - - // Do we already know the answer? - const char* txt = LookupTextInCache(name); - - if ( txt ) - { - resolve_lookup_cb(callback, txt); - return; - } - - AsyncRequest* req = nullptr; - - // Have we already a request waiting for this host? - AsyncRequestTextMap::iterator i = asyncs_texts.find(name); - if ( i != asyncs_texts.end() ) - req = i->second; - else - { - // A new one. - req = new AsyncRequest; - req->name = name; - req->is_txt = true; - asyncs_queued.push_back(req); - asyncs_texts.insert(AsyncRequestTextMap::value_type(name, req)); - } - - req->callbacks.push_back(callback); - - IssueAsyncRequests(); + return make_intrusive("<\?\?\?>"); } void DNS_Mgr::IssueAsyncRequests() { while ( ! asyncs_queued.empty() && asyncs_pending < MAX_PENDING_REQUESTS ) { + DNS_Request* dns_req = nullptr; AsyncRequest* req = asyncs_queued.front(); asyncs_queued.pop_front(); @@ -961,30 +1153,48 @@ void DNS_Mgr::IssueAsyncRequests() req->time = util::current_time(); if ( req->IsAddrReq() ) - { - auto* m_req = new DNS_Mgr_Request(req->host); - m_req->MakeRequest(channel); - } + dns_req = new DNS_Request(req->addr, true); else if ( req->is_txt ) - { - auto* m_req = new DNS_Mgr_Request(req->name.c_str(), AF_INET, req->is_txt); - m_req->MakeRequest(channel); - } + dns_req = new DNS_Request(req->host.c_str(), AF_UNSPEC, T_TXT, true); else - { - // If only one request type succeeds, don't consider it a failure. - auto* m_req4 = new DNS_Mgr_Request(req->name.c_str(), AF_INET, req->is_txt); - m_req4->MakeRequest(channel); - auto* m_req6 = new DNS_Mgr_Request(req->name.c_str(), AF_INET6, req->is_txt); - m_req6->MakeRequest(channel); - } + // We pass T_A here, but because we're passing AF_UNSPEC MakeRequest() will + // have c-ares attempt to lookup both ipv4 and ipv6 at the same time. + dns_req = new DNS_Request(req->host.c_str(), AF_UNSPEC, T_A, true); + + dns_req->MakeRequest(channel); asyncs_timeouts.push(req); - ++asyncs_pending; } } +void DNS_Mgr::CheckAsyncHostRequest(const std::string& host, bool timeout) + { + // Note that this code is a mirror of that for CheckAsyncAddrRequest. + + AsyncRequestNameMap::iterator i = asyncs_names.find(host); + + if ( i != asyncs_names.end() ) + { + if ( timeout ) + { + ++failed; + i->second->Timeout(); + } + else if ( auto addrs = LookupNameInCache(host, true, false) ) + { + ++successful; + i->second->Resolved(addrs); + } + else + return; + + delete i->second; + asyncs_names.erase(i); + --asyncs_pending; + } + } + void DNS_Mgr::CheckAsyncAddrRequest(const IPAddr& addr, bool timeout) { // Note that this code is a mirror of that for CheckAsyncHostRequest. @@ -995,101 +1205,57 @@ void DNS_Mgr::CheckAsyncAddrRequest(const IPAddr& addr, bool timeout) if ( i != asyncs_addrs.end() ) { - const char* name = LookupAddrInCache(addr); - if ( name ) - { - ++successful; - i->second->Resolved(name); - } - - else if ( timeout ) + if ( timeout ) { ++failed; i->second->Timeout(); } - + else if ( auto name = LookupAddrInCache(addr, true, false) ) + { + ++successful; + i->second->Resolved(name->CheckString()); + } else return; + delete i->second; asyncs_addrs.erase(i); --asyncs_pending; - - // Don't delete the request. That will be done once it - // eventually times out. } } -void DNS_Mgr::CheckAsyncTextRequest(const char* host, bool timeout) +void DNS_Mgr::CheckAsyncTextRequest(const std::string& host, bool timeout) { // Note that this code is a mirror of that for CheckAsyncAddrRequest. AsyncRequestTextMap::iterator i = asyncs_texts.find(host); if ( i != asyncs_texts.end() ) { - const char* name = LookupTextInCache(host); - if ( name ) - { - ++successful; - i->second->Resolved(name); - } - - else if ( timeout ) + if ( timeout ) { AsyncRequestTextMap::iterator it = asyncs_texts.begin(); ++failed; i->second->Timeout(); } - - else - return; - - asyncs_texts.erase(i); - --asyncs_pending; - - // Don't delete the request. That will be done once it - // eventually times out. - } - } - -void DNS_Mgr::CheckAsyncHostRequest(const char* host, bool timeout) - { - // Note that this code is a mirror of that for CheckAsyncAddrRequest. - - AsyncRequestNameMap::iterator i = asyncs_names.find(host); - - if ( i != asyncs_names.end() ) - { - auto addrs = LookupNameInCache(host); - - if ( addrs ) + else if ( auto name = LookupTextInCache(host, true) ) { ++successful; - i->second->Resolved(addrs.get()); + i->second->Resolved(name->CheckString()); } - - else if ( timeout ) - { - ++failed; - i->second->Timeout(); - } - else return; - asyncs_names.erase(i); + delete i->second; + asyncs_texts.erase(i); --asyncs_pending; - - // Don't delete the request. That will be done once it - // eventually times out. } } void DNS_Mgr::Flush() { - Process(); + Resolve(); - HostMap::iterator it; - for ( it = host_mappings.begin(); it != host_mappings.end(); ++it ) + for ( HostMap::iterator it = host_mappings.begin(); it != host_mappings.end(); ++it ) { delete it->second.first; delete it->second.second; @@ -1116,67 +1282,21 @@ double DNS_Mgr::GetNextTimeout() void DNS_Mgr::Process() { - while ( ! asyncs_timeouts.empty() ) + // If iosource_mgr says that we got a result on the socket fd, we don't have to ask c-ares + // to retrieve it for us. We have the file descriptor already, just call ares_process_fd() + // with it. Unfortunately, we may also have sockets close during this call, so we need to + // to make a copy of the list first. Having a list change while looping over it can + // cause segfaults. + decltype(socket_fds) temp_fds{socket_fds}; + + for ( int fd : temp_fds ) { - AsyncRequest* req = asyncs_timeouts.top(); - - if ( req->time + DNS_TIMEOUT > util::current_time() && ! run_state::terminating ) - break; - - if ( ! req->processed ) - { - if ( req->IsAddrReq() ) - CheckAsyncAddrRequest(req->host, true); - else if ( req->is_txt ) - CheckAsyncTextRequest(req->name.c_str(), true); - else - CheckAsyncHostRequest(req->name.c_str(), true); - } - - asyncs_timeouts.pop(); - delete req; + // double check this one wasn't removed already before trying to process it + if ( socket_fds.count(fd) != 0 ) + ares_process_fd(channel, fd, ARES_SOCKET_BAD); } - Resolve(); - - // TODO: what does the rest below do? - /* - char err[NB_DNS_ERRSIZE]; - struct nb_dns_result r; - - int status = nb_dns_activity(nb_dns, &r, err); - - if ( status < 0 ) - reporter->Warning("NB-DNS error in DNS_Mgr::Process (%s)", err); - - else if ( status > 0 ) - { - DNS_Mgr_Request* dr = (DNS_Mgr_Request*)r.cookie; - - bool do_host_timeout = true; - if ( dr->ReqHost() && host_mappings.find(dr->ReqHost()) == host_mappings.end() ) - // Don't timeout when this is the first result in an expected pair - // (one result each for A and AAAA queries). - do_host_timeout = false; - - if ( dr->RequestPending() ) - { - AddResult(dr, &r); - dr->RequestDone(); - } - - if ( ! dr->ReqHost() ) - CheckAsyncAddrRequest(dr->ReqAddr(), true); - else if ( dr->ReqIsTxt() ) - CheckAsyncTextRequest(dr->ReqHost(), do_host_timeout); - else - CheckAsyncHostRequest(dr->ReqHost(), do_host_timeout); - - IssueAsyncRequests(); - - delete dr; - } - */ + IssueAsyncRequests(); } void DNS_Mgr::GetStats(Stats* stats) @@ -1191,14 +1311,7 @@ void DNS_Mgr::GetStats(Stats* stats) stats->cached_texts = text_mappings.size(); } -void DNS_Mgr::TestProcess() - { - // Only allow usage of this method when running unit tests. - assert(doctest::is_running_in_test); - Process(); - } - -void DNS_Mgr::AsyncRequest::Resolved(const char* name) +void DNS_Mgr::AsyncRequest::Resolved(const std::string& name) { for ( const auto& cb : callbacks ) { @@ -1211,7 +1324,7 @@ void DNS_Mgr::AsyncRequest::Resolved(const char* name) processed = true; } -void DNS_Mgr::AsyncRequest::Resolved(TableVal* addrs) +void DNS_Mgr::AsyncRequest::Resolved(TableValPtr addrs) { for ( const auto& cb : callbacks ) { @@ -1252,7 +1365,7 @@ TableValPtr DNS_Mgr::empty_addr_set() ////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////// -static std::vector get_result_addresses(TableVal* addrs) +static std::vector get_result_addresses(TableValPtr addrs) { std::vector results; @@ -1275,12 +1388,12 @@ class TestCallback : public DNS_Mgr::LookupCallback { public: TestCallback() { } - void Resolved(const char* name) override + void Resolved(const std::string& name) override { host_result = name; done = true; } - void Resolved(TableVal* addrs) override + void Resolved(TableValPtr addrs) override { addr_results = get_result_addresses(addrs); done = true; @@ -1297,7 +1410,27 @@ public: bool timeout = false; }; -TEST_CASE("dns_mgr prime,save,load") +/** + * Derived testing version of DNS_Mgr so that the Process() method can be exposed + * publically. If new unit tests are added, this class should be used over using + * DNS_Mgr directly. + */ +class TestDNS_Mgr final : public DNS_Mgr + { +public: + explicit TestDNS_Mgr(DNS_MgrMode mode) : DNS_Mgr(mode) { } + void Process(); + }; + +void TestDNS_Mgr::Process() + { + // Only allow usage of this method when running unit tests. + assert(doctest::is_running_in_test); + Resolve(); + IssueAsyncRequests(); + } + +TEST_CASE("dns_mgr priming") { char prefix[] = "/tmp/zeek-unit-test-XXXXXX"; auto tmpdir = mkdtemp(prefix); @@ -1305,18 +1438,21 @@ TEST_CASE("dns_mgr prime,save,load") // Create a manager to prime the cache, make a few requests, and the save // the result. This tests that the priming code will create the requests but // wait for Resolve() to actually make the requests. - DNS_Mgr mgr(DNS_PRIME); + TestDNS_Mgr mgr(DNS_PRIME); + dns_mgr = &mgr; mgr.SetDir(tmpdir); mgr.InitPostScript(); auto host_result = mgr.LookupHost("one.one.one.one"); REQUIRE(host_result != nullptr); - CHECK(host_result->EqualTo(DNS_Mgr::empty_addr_set())); + CHECK(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); IPAddr ones("1.1.1.1"); auto addr_result = mgr.LookupAddr(ones); CHECK(strcmp(addr_result->CheckString(), "") == 0); + // This should wait until we have all of the results back from the above + // requests. mgr.Resolve(); // Save off the resulting values from Resolve() into a file on disk @@ -1324,7 +1460,8 @@ TEST_CASE("dns_mgr prime,save,load") REQUIRE(mgr.Save()); // Make a second DNS manager and reload the cache that we just saved. - DNS_Mgr mgr2(DNS_FORCE); + TestDNS_Mgr mgr2(DNS_FORCE); + dns_mgr = &mgr2; mgr2.SetDir(tmpdir); mgr2.InitPostScript(); @@ -1332,11 +1469,15 @@ TEST_CASE("dns_mgr prime,save,load") // data out of the cache. host_result = mgr2.LookupHost("one.one.one.one"); REQUIRE(host_result != nullptr); - CHECK_FALSE(host_result->EqualTo(DNS_Mgr::empty_addr_set())); + CHECK_FALSE(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); addr_result = mgr2.LookupAddr(ones); REQUIRE(addr_result != nullptr); CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); + + // Clean up cache file and the temp directory + unlink(mgr2.CacheFile().c_str()); + rmdir(tmpdir); } TEST_CASE("dns_mgr alternate server") @@ -1344,7 +1485,9 @@ TEST_CASE("dns_mgr alternate server") char* old_server = getenv("ZEEK_DNS_RESOLVER"); setenv("ZEEK_DNS_RESOLVER", "1.1.1.1", 1); - DNS_Mgr mgr(DNS_DEFAULT); + TestDNS_Mgr mgr(DNS_DEFAULT); + dns_mgr = &mgr; + mgr.InitPostScript(); auto result = mgr.LookupAddr("1.1.1.1"); @@ -1353,7 +1496,7 @@ TEST_CASE("dns_mgr alternate server") // FIXME: This won't run on systems without IPv6 connectivity. // setenv("ZEEK_DNS_RESOLVER", "2606:4700:4700::1111", 1); - // DNS_Mgr mgr2(DNS_DEFAULT, true); + // TestDNS_Mgr mgr2(DNS_DEFAULT, true); // mgr2.InitPostScript(); // result = mgr2.LookupAddr("1.1.1.1"); // mgr2.Resolve(); @@ -1369,15 +1512,16 @@ TEST_CASE("dns_mgr alternate server") TEST_CASE("dns_mgr default mode") { - DNS_Mgr mgr(DNS_DEFAULT); + TestDNS_Mgr mgr(DNS_DEFAULT); + dns_mgr = &mgr; mgr.InitPostScript(); IPAddr ones("1.1.1.1"); auto host_result = mgr.LookupHost("one.one.one.one"); REQUIRE(host_result != nullptr); - CHECK_FALSE(host_result->EqualTo(DNS_Mgr::empty_addr_set())); + CHECK_FALSE(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); - auto addrs_from_request = get_result_addresses(host_result.get()); + auto addrs_from_request = get_result_addresses(host_result); auto it = std::find(addrs_from_request.begin(), addrs_from_request.end(), ones); CHECK(it != addrs_from_request.end()); @@ -1393,18 +1537,19 @@ TEST_CASE("dns_mgr default mode") TEST_CASE("dns_mgr async host") { - DNS_Mgr mgr(DNS_DEFAULT); + TestDNS_Mgr mgr(DNS_DEFAULT); + dns_mgr = &mgr; mgr.InitPostScript(); TestCallback cb{}; - mgr.AsyncLookupName("one.one.one.one", &cb); + mgr.LookupHost("one.one.one.one", &cb); // This shouldn't take any longer than DNS_TIMEOUT+1 seconds, so bound it // just in case of some failure we're not aware of yet. int count = 0; while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { - mgr.TestProcess(); + mgr.Process(); sleep(1); if ( ! cb.timeout ) count++; @@ -1424,18 +1569,19 @@ TEST_CASE("dns_mgr async host") TEST_CASE("dns_mgr async addr") { - DNS_Mgr mgr(DNS_DEFAULT); + TestDNS_Mgr mgr(DNS_DEFAULT); + dns_mgr = &mgr; mgr.InitPostScript(); TestCallback cb{}; - mgr.AsyncLookupAddr(IPAddr{"1.1.1.1"}, &cb); + mgr.LookupAddr(IPAddr{"1.1.1.1"}, &cb); // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it // just in case of some failure we're not aware of yet. int count = 0; while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { - mgr.TestProcess(); + mgr.Process(); sleep(1); if ( ! cb.timeout ) count++; @@ -1450,18 +1596,19 @@ TEST_CASE("dns_mgr async addr") TEST_CASE("dns_mgr async text") { - DNS_Mgr mgr(DNS_DEFAULT); + TestDNS_Mgr mgr(DNS_DEFAULT); + dns_mgr = &mgr; mgr.InitPostScript(); TestCallback cb{}; - mgr.AsyncLookupNameText("unittest.zeek.org", &cb); + mgr.Lookup("unittest.zeek.org", T_TXT, &cb); // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it // just in case of some failure we're not aware of yet. int count = 0; while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { - mgr.TestProcess(); + mgr.Process(); sleep(1); if ( ! cb.timeout ) count++; @@ -1482,7 +1629,7 @@ TEST_CASE("dns_mgr timeouts") // server that lets you connect but never returns any responses, always // resulting in a timeout. setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); - DNS_Mgr mgr(DNS_DEFAULT); + TestDNS_Mgr mgr(DNS_DEFAULT); dns_mgr = &mgr; mgr.InitPostScript(); @@ -1492,7 +1639,7 @@ TEST_CASE("dns_mgr timeouts") auto host_result = mgr.LookupHost("one.one.one.one"); REQUIRE(host_result != nullptr); - auto addresses = get_result_addresses(host_result.get()); + auto addresses = get_result_addresses(host_result); CHECK(addresses.size() == 0); if ( old_server ) @@ -1509,19 +1656,19 @@ TEST_CASE("dns_mgr async timeouts") // server that lets you connect but never returns any responses, always // resulting in a timeout. setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); - DNS_Mgr mgr(DNS_DEFAULT); + TestDNS_Mgr mgr(DNS_DEFAULT); dns_mgr = &mgr; mgr.InitPostScript(); TestCallback cb{}; - mgr.AsyncLookupNameText("unittest.zeek.org", &cb); + mgr.Lookup("unittest.zeek.org", T_TXT, &cb); - // This shouldn't take any longer than DNS_TIMEOUT +2 seconds, so bound it + // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it // just in case of some failure we're not aware of yet. int count = 0; while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { - mgr.TestProcess(); + mgr.Process(); sleep(1); if ( ! cb.timeout ) count++; diff --git a/src/DNS_Mgr.h b/src/DNS_Mgr.h index e5bfba9689..5f6204f45d 100644 --- a/src/DNS_Mgr.h +++ b/src/DNS_Mgr.h @@ -2,7 +2,7 @@ #pragma once -#include +#include #include #include #include @@ -14,6 +14,18 @@ #include "zeek/iosource/IOSource.h" #include "zeek/util.h" +// These are defined in ares headers but we don't want to have to include +// those headers here and create install dependencies on them. +struct ares_channeldata; +typedef struct ares_channeldata* ares_channel; +#ifndef T_PTR +#define T_PTR 12 +#endif + +#ifndef T_TXT +#define T_TXT 16 +#endif + namespace zeek { class Val; @@ -31,8 +43,8 @@ using StringValPtr = IntrusivePtr; namespace zeek::detail { -class DNS_Mgr_Request; class DNS_Mapping; +class DNS_Request; enum DNS_MgrMode { @@ -42,9 +54,44 @@ enum DNS_MgrMode DNS_FAKE, // don't look up names, just return dummy results }; -class DNS_Mgr final : public iosource::IOSource +class DNS_Mgr : public iosource::IOSource { public: + /** + * Base class for callback handling for asynchronous lookups. + */ + class LookupCallback + { + public: + virtual ~LookupCallback() = default; + + /** + * Called when an address lookup finishes. + * + * @param name The resulting name from the lookup. + */ + virtual void Resolved(const std::string& name){}; + + /** + * Called when a name lookup finishes. + * + * @param addrs A table of the resulting addresses from the lookup. + */ + virtual void Resolved(TableValPtr addrs){}; + + /** + * Generic callback method for all request types. + * + * @param val A Val containing the data from the query. + */ + virtual void Resolved(ValPtr data, int request_type) { } + + /** + * Called when a timeout request occurs. + */ + virtual void Timeout() = 0; + }; + explicit DNS_Mgr(DNS_MgrMode mode); ~DNS_Mgr() override; @@ -61,27 +108,79 @@ public: void Flush(); /** - * Looks up the address(es) of a given host and returns a set of addr. - * This is a synchronous method and will block until results are ready. + * Looks up the address(es) of a given host and returns a set of addresses. + * This is a shorthand method for doing A/AAAA requests. This is a + * synchronous request and will block until the request completes or times + * out. * - * @param host The host name to look up an address for. - * @return A set of addresses. + * @param host The hostname to lookup an address for. + * @return A set of addresses for the host. */ - TableValPtr LookupHost(const char* host); + TableValPtr LookupHost(const std::string& host); /** - * Looks up the hostname of a given address. This is a synchronous method - * and will block until results are ready. + * Looks up the hostname of a given address. This is a shorthand method for + * doing PTR requests. This is a synchronous request and will block until + * the request completes or times out. * * @param host The addr to lookup a hostname for. - * @return The hostname. + * @return The hostname for the address. */ StringValPtr LookupAddr(const IPAddr& addr); + /** + * Performs a generic request to the DNS server. This is a synchronous + * request and will block until the request completes or times out. + * + * @param name The name or address to make a request for. If this is an + * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). + * Note that calling LookupAddr for PTR requests does this conversion + * automatically. + * @param request_type The type of request to make. This should be one of + * the type values defined in arpa/nameser.h or ares_nameser.h. + * @return The requested data. + */ + ValPtr Lookup(const std::string& name, int request_type); + + /** + * Looks up the address(es) of a given host. This is a shorthand method + * for doing A/AAAA requests. This is an asynchronous request. The + * response will be handled via the provided callback object. + * + * @param host The hostname to lookup an address for. + * @param callback A callback object for handling the response. + */ + void LookupHost(const std::string& host, LookupCallback* callback); + + /** + * Looks up the hostname of a given address. This is a shorthand method for + * doing PTR requests. This is an asynchronous request. The response will + * be handled via the provided callback object. + * + * @param host The addr to lookup a hostname for. + * @param callback A callback object for handling the response. + */ + void LookupAddr(const IPAddr& addr, LookupCallback* callback); + + /** + * Performs a generic request to the DNS server. This is an asynchronous + * request. The response will be handled via the provided callback + * object. + * + * @param name The name or address to make a request for. If this is an + * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). + * Note that calling LookupAddr for PTR requests does this conversion + * automatically. + * @param request_type The type of request to make. This should be one of + * the type values defined in arpa/nameser.h or ares_nameser.h. + * @param callback A callback object for handling the response. + */ + void Lookup(const std::string& name, int request_type, LookupCallback* callback); + /** * Sets the directory where to store DNS data when Save() is called. */ - void SetDir(const char* arg_dir) { dir = arg_dir; } + void SetDir(const std::string& arg_dir) { dir = arg_dir; } /** * Waits for responses to become available or a timeout to occur, @@ -94,61 +193,6 @@ public: */ bool Save(); - /** - * Base class for callback handling for asynchronous lookups. - */ - class LookupCallback - { - public: - virtual ~LookupCallback() = default; - - /** - * Called when an address lookup finishes. - * - * @param name The resulting name from the lookup. - */ - virtual void Resolved(const char* name){}; - - /** - * Called when a name lookup finishes. - * - * @param addrs A table of the resulting addresses from the lookup. - */ - virtual void Resolved(TableVal* addrs){}; - - /** - * Called when a timeout request occurs. - */ - virtual void Timeout() = 0; - }; - - /** - * Schedules an asynchronous request to lookup a hostname for an IP address. - * This is the equivalent of an "A" or "AAAA" request, depending on if the - * address is ipv4 or ipv6. - * - * @param host The address to lookup names for. - * @param callback A callback object to call when the request completes. - */ - void AsyncLookupAddr(const IPAddr& host, LookupCallback* callback); - - /** - * Schedules an asynchronous request to lookup an address for a hostname. - * This is the equivalent of a "PTR" request. - * - * @param host The hostname to look up addresses for. - * @param callback A callback object to call when the request completes. - */ - void AsyncLookupName(const std::string& name, LookupCallback* callback); - - /** - * Schedules an asynchronous TXT request for a hostname. - * - * @param host The address to lookup names for. - * @param callback A callback object to call when the request completes. - */ - void AsyncLookupNameText(const std::string& name, LookupCallback* callback); - struct Stats { unsigned long requests; // These count only async requests. @@ -175,7 +219,7 @@ public: * @param h A hostent structure containing the actual result data. * @param ttl A ttl value contained in the response from the server. */ - void AddResult(DNS_Mgr_Request* dr, struct hostent* h, uint32_t ttl); + void AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl); /** * Returns an empty set of addresses, used in various error cases and during @@ -184,18 +228,30 @@ public: static TableValPtr empty_addr_set(); /** - * This method is used to call the private Process() method during unit testing - * and shouldn't be used otherwise. + * Returns the full path to the file used to store the DNS cache. */ - void TestProcess(); + std::string CacheFile() const { return cache_name; } + + /** + * Used by the c-ares socket call back to register/unregister a socket file descriptor. + */ + void RegisterSocket(int fd, bool active); protected: friend class LookupCallback; - friend class DNS_Mgr_Request; + friend class DNS_Request; - const char* LookupAddrInCache(const IPAddr& addr); - TableValPtr LookupNameInCache(const std::string& name); - const char* LookupTextInCache(const std::string& name); + StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, + bool check_failed = false); + TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, + bool check_failed = false); + StringValPtr LookupTextInCache(const std::string& name, bool cleanup_expired = false); + + // Finish the request if we have a result. If not, time it out if + // requested. + void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout); + void CheckAsyncHostRequest(const std::string& host, bool timeout); + void CheckAsyncTextRequest(const std::string& host, bool timeout); void Event(EventHandlerPtr e, DNS_Mapping* dm); void Event(EventHandlerPtr e, DNS_Mapping* dm, ListValPtr l1, ListValPtr l2); @@ -205,7 +261,6 @@ protected: void CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm); ListValPtr AddrListDelta(ListVal* al1, ListVal* al2); - void DumpAddrList(FILE* f, ListVal* al); using HostMap = std::map>; using AddrMap = std::map; @@ -217,12 +272,6 @@ protected: // Issue as many queued async requests as slots are available. void IssueAsyncRequests(); - // Finish the request if we have a result. If not, time it out if - // requested. - void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout); - void CheckAsyncHostRequest(const char* host, bool timeout); - void CheckAsyncTextRequest(const char* host, bool timeout); - // IOSource interface. void Process() override; void InitSource() override; @@ -235,9 +284,6 @@ protected: AddrMap addr_mappings; TextMap text_mappings; - using DNS_mgr_request_list = PList; - DNS_mgr_request_list requests; - std::string cache_name; std::string dir; // directory in which cache_name resides @@ -247,26 +293,30 @@ protected: RecordTypePtr dm_rec; ares_channel channel; - bool ipv6_resolver = false; using CallbackList = std::list; struct AsyncRequest { double time = 0.0; - IPAddr host; - std::string name; + IPAddr addr; + std::string host; CallbackList callbacks; bool is_txt = false; bool processed = false; - bool IsAddrReq() const { return name.empty(); } + bool IsAddrReq() const { return host.empty(); } - void Resolved(const char* name); - void Resolved(TableVal* addrs); + void Resolved(const std::string& name); + void Resolved(TableValPtr addrs); void Timeout(); }; + struct AsyncRequestCompare + { + bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; } + }; + using AsyncRequestAddrMap = std::map; AsyncRequestAddrMap asyncs_addrs; @@ -279,18 +329,15 @@ protected: using QueuedList = std::list; QueuedList asyncs_queued; - struct AsyncRequestCompare - { - bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; } - }; - using TimeoutQueue = std::priority_queue, AsyncRequestCompare>; TimeoutQueue asyncs_timeouts; - unsigned long num_requests; - unsigned long successful; - unsigned long failed; + unsigned long num_requests = 0; + unsigned long successful = 0; + unsigned long failed = 0; + + std::set socket_fds; }; extern DNS_Mgr* dns_mgr; diff --git a/src/zeek-setup.cc b/src/zeek-setup.cc index 551dd003fc..952db62b4f 100644 --- a/src/zeek-setup.cc +++ b/src/zeek-setup.cc @@ -342,8 +342,7 @@ static void terminate_zeek() delete packet_mgr; delete analyzer_mgr; delete file_mgr; - delete dns_mgr; - // broker_mgr, timer_mgr, and supervisor are deleted via iosource_mgr + // broker_mgr, timer_mgr, supervisor, and dns_mgr are deleted via iosource_mgr delete iosource_mgr; delete event_registry; delete log_mgr; @@ -757,7 +756,6 @@ SetupResult setup(int argc, char** argv, Options* zopts) file_mgr->InitPostScript(); dns_mgr->InitPostScript(); - dns_mgr->LookupHost("www.apple.com"); // dns_mgr->LookupAddr("17.253.144.10"); #ifdef USE_PERFTOOLS_DEBUG diff --git a/src/zeek.bif b/src/zeek.bif index 3fd41e0901..4665831d65 100644 --- a/src/zeek.bif +++ b/src/zeek.bif @@ -3642,8 +3642,8 @@ function dump_packet%(pkt: pcap_packet, file_name: string%) : bool class LookupHostCallback : public zeek::detail::DNS_Mgr::LookupCallback { public: - LookupHostCallback(zeek::detail::trigger::Trigger* arg_trigger, const zeek::detail::CallExpr* arg_call, - bool arg_lookup_name) + LookupHostCallback(zeek::detail::trigger::Trigger* arg_trigger, + const zeek::detail::CallExpr* arg_call, bool arg_lookup_name) { Ref(arg_trigger); trigger = arg_trigger; @@ -3657,7 +3657,7 @@ public: } // Overridden from zeek::detail::DNS_Mgr:Lookup:Callback. - void Resolved(const char* name) override + void Resolved(const std::string& name) override { zeek::Val* result = new zeek::StringVal(name); trigger->Cache(call, result); @@ -3665,10 +3665,10 @@ public: trigger->Release(); } - void Resolved(zeek::TableVal* addrs) override + void Resolved(zeek::TableValPtr addrs) override { // No Ref() for addrs. - trigger->Cache(call, addrs); + trigger->Cache(call, addrs.get()); trigger->Release(); } @@ -3724,7 +3724,7 @@ function lookup_addr%(host: addr%) : string frame->SetDelayed(); trigger->Hold(); - zeek::detail::dns_mgr->AsyncLookupAddr(host->AsAddr(), + zeek::detail::dns_mgr->LookupAddr(host->AsAddr(), new LookupHostCallback(trigger, frame->GetCall(), true)); return nullptr; %} @@ -3753,7 +3753,7 @@ function lookup_hostname_txt%(host: string%) : string frame->SetDelayed(); trigger->Hold(); - zeek::detail::dns_mgr->AsyncLookupNameText(host->CheckString(), + zeek::detail::dns_mgr->Lookup(host->CheckString(), T_TXT, new LookupHostCallback(trigger, frame->GetCall(), true)); return nullptr; %} @@ -3782,7 +3782,7 @@ function lookup_hostname%(host: string%) : addr_set frame->SetDelayed(); trigger->Hold(); - zeek::detail::dns_mgr->AsyncLookupName(host->CheckString(), + zeek::detail::dns_mgr->LookupHost(host->CheckString(), new LookupHostCallback(trigger, frame->GetCall(), false)); return nullptr; %}