diff --git a/cmake b/cmake index b1c9e55d1e..649c319f88 160000 --- a/cmake +++ b/cmake @@ -1 +1 @@ -Subproject commit b1c9e55d1e837d46fc36312b567245013cf9646a +Subproject commit 649c319f88e2966931892d55adb2ee50f278662b diff --git a/src/3rdparty b/src/3rdparty index 0af190f905..6cbb3d6587 160000 --- a/src/3rdparty +++ b/src/3rdparty @@ -1 +1 @@ -Subproject commit 0af190f90572abc90366471f36e6feb1b817d2ab +Subproject commit 6cbb3d65877f80326c047364583f506ce58758ba diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 63478cc610..f5158bcb50 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -286,6 +286,7 @@ set(MAIN_SRCS Desc.cc Dict.cc Discard.cc + DNS_Mapping.cc DNS_Mgr.cc EquivClass.cc Event.cc diff --git a/src/DNS_Mapping.cc b/src/DNS_Mapping.cc new file mode 100644 index 0000000000..1e98788904 --- /dev/null +++ b/src/DNS_Mapping.cc @@ -0,0 +1,387 @@ +#include "zeek/DNS_Mapping.h" + +#include "zeek/3rdparty/doctest.h" +#include "zeek/DNS_Mgr.h" + +namespace zeek::detail + { + +DNS_Mapping::DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl) + { + Init(h); + req_host = host; + req_ttl = ttl; + + if ( names.empty() ) + names.push_back(host); + } + +DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) + { + Init(h); + req_addr = addr; + req_ttl = ttl; + } + +DNS_Mapping::DNS_Mapping(FILE* f) + { + Clear(); + init_failed = true; + + req_ttl = 0; + creation_time = 0; + + char buf[512]; + + if ( ! fgets(buf, sizeof(buf), f) ) + { + no_mapping = true; + return; + } + + char req_buf[512 + 1], name_buf[512 + 1]; + int is_req_host; + int failed_local; + int num_addrs; + + if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, + &failed_local, name_buf, &map_type, &num_addrs, &req_ttl) != 8 ) + return; + + failed = static_cast(failed_local); + + if ( is_req_host ) + req_host = req_buf; + else + req_addr = IPAddr(req_buf); + + names.push_back(name_buf); + + for ( int i = 0; i < num_addrs; ++i ) + { + if ( ! fgets(buf, sizeof(buf), f) ) + return; + + char* newline = strchr(buf, '\n'); + if ( newline ) + *newline = '\0'; + + addrs.emplace_back(IPAddr(buf)); + } + + init_failed = false; + } + +ListValPtr DNS_Mapping::Addrs() + { + if ( failed ) + return nullptr; + + if ( ! addrs_val ) + { + addrs_val = make_intrusive(TYPE_ADDR); + + for ( const auto& addr : addrs ) + addrs_val->Append(make_intrusive(addr)); + } + + return addrs_val; + } + +TableValPtr DNS_Mapping::AddrsSet() + { + auto l = Addrs(); + + if ( ! l || l->Length() == 0 ) + return DNS_Mgr::empty_addr_set(); + + return l->ToSetVal(); + } + +StringValPtr DNS_Mapping::Host() + { + if ( failed || names.empty() ) + return nullptr; + + if ( ! host_val ) + host_val = make_intrusive(names[0]); + + return host_val; + } + +void DNS_Mapping::Init(struct hostent* h) + { + no_mapping = false; + init_failed = false; + creation_time = util::current_time(); + host_val = nullptr; + addrs_val = nullptr; + + if ( ! h ) + { + Clear(); + return; + } + + map_type = h->h_addrtype; + if ( h->h_name ) + // for now, just use the official name + // TODO: this could easily be expanded to include all of the aliases as well + names.push_back(h->h_name); + + for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) + { + if ( h->h_addrtype == AF_INET ) + addrs.push_back(IPAddr(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network)); + else if ( h->h_addrtype == AF_INET6 ) + addrs.push_back(IPAddr(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network)); + } + + failed = false; + } + +void DNS_Mapping::Clear() + { + names.clear(); + host_val = nullptr; + addrs.clear(); + addrs_val = nullptr; + no_mapping = false; + map_type = 0; + failed = true; + } + +void DNS_Mapping::Save(FILE* f) const + { + fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), + req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, + names.empty() ? "*" : names[0].c_str(), map_type, addrs.size(), req_ttl); + + for ( const auto& addr : addrs ) + fprintf(f, "%s\n", addr.AsString().c_str()); + } + +////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////// + +TEST_CASE("dns_mapping init null hostent") + { + DNS_Mapping mapping("www.apple.com", nullptr, 123); + + CHECK(! mapping.Valid()); + CHECK(mapping.Addrs() == nullptr); + CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set())); + CHECK(mapping.Host() == nullptr); + } + +TEST_CASE("dns_mapping init host") + { + IPAddr addr("1.2.3.4"); + in4_addr in4; + addr.CopyIPv4(&in4); + + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); + + std::vector addrs = {&in4, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); + + DNS_Mapping mapping("testing.home", &he, 123); + CHECK(mapping.Valid()); + CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified); + CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0); + CHECK(mapping.ReqStr() == "testing.home"); + + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); + + auto tvas = mapping.AddrsSet(); + REQUIRE(tvas != nullptr); + CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); + + auto svh = mapping.Host(); + REQUIRE(svh != nullptr); + CHECK(svh->ToStdString() == "testing.home"); + + delete[] he.h_name; + } + +TEST_CASE("dns_mapping init addr") + { + IPAddr addr("1.2.3.4"); + in4_addr in4; + addr.CopyIPv4(&in4); + + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); + + std::vector addrs = {&in4, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); + + DNS_Mapping mapping(addr, &he, 123); + CHECK(mapping.Valid()); + CHECK(mapping.ReqAddr() == addr); + CHECK(mapping.ReqHost() == nullptr); + CHECK(mapping.ReqStr() == "1.2.3.4"); + + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); + + auto tvas = mapping.AddrsSet(); + REQUIRE(tvas != nullptr); + CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); + + auto svh = mapping.Host(); + REQUIRE(svh != nullptr); + CHECK(svh->ToStdString() == "testing.home"); + + delete[] he.h_name; + } + +TEST_CASE("dns_mapping save reload") + { + IPAddr addr("1.2.3.4"); + in4_addr in4; + addr.CopyIPv4(&in4); + + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); + + std::vector addrs = {&in4, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); + + // Create a temporary file in memory and fseek to the end of it so we're at + // EOF for the next bit. + char buffer[4096]; + memset(buffer, 0, 4096); + FILE* tmpfile = fmemopen(buffer, 4096, "r+"); + fseek(tmpfile, 0, SEEK_END); + + // Try loading from the file at EOF. This should cause a mapping failure. + DNS_Mapping mapping(tmpfile); + CHECK(mapping.NoMapping()); + rewind(tmpfile); + + // Try reading from the empty file. This should cause an init failure. + DNS_Mapping mapping2(tmpfile); + CHECK(mapping2.InitFailed()); + rewind(tmpfile); + + // Save a valid mapping into the file and rewind to the start. + DNS_Mapping mapping3(addr, &he, 123); + mapping3.Save(tmpfile); + rewind(tmpfile); + + // Test loading the mapping back out of the file + DNS_Mapping mapping4(tmpfile); + fclose(tmpfile); + CHECK(mapping4.Valid()); + CHECK(mapping4.ReqAddr() == addr); + CHECK(mapping4.ReqHost() == nullptr); + CHECK(mapping4.ReqStr() == "1.2.3.4"); + + auto lva = mapping4.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); + + auto tvas = mapping4.AddrsSet(); + REQUIRE(tvas != nullptr); + CHECK(tvas != DNS_Mgr::empty_addr_set()); + + auto svh = mapping4.Host(); + REQUIRE(svh != nullptr); + CHECK(svh->ToStdString() == "testing.home"); + + delete[] he.h_name; + } + +TEST_CASE("dns_mapping multiple addresses") + { + IPAddr addr("1.2.3.4"); + in4_addr in4_1; + addr.CopyIPv4(&in4_1); + + IPAddr addr2("5.6.7.8"); + in4_addr in4_2; + addr2.CopyIPv4(&in4_2); + + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); + + std::vector addrs = {&in4_1, &in4_2, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); + + DNS_Mapping mapping("testing.home", &he, 123); + CHECK(mapping.Valid()); + + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 2); + + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); + + lvae = lva->Idx(1)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "5.6.7.8"); + + delete[] he.h_name; + } + +TEST_CASE("dns_mapping ipv6") + { + IPAddr addr("64:ff9b:1::"); + in6_addr in6; + addr.CopyIPv6(&in6); + + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET6; + he.h_length = sizeof(in6_addr); + + std::vector addrs = {&in6, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); + + DNS_Mapping mapping(addr, &he, 123); + CHECK(mapping.Valid()); + CHECK(mapping.ReqAddr() == addr); + CHECK(mapping.ReqHost() == nullptr); + CHECK(mapping.ReqStr() == "64:ff9b:1::"); + + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "64:ff9b:1::"); + + delete[] he.h_name; + } + + } // namespace zeek::detail diff --git a/src/DNS_Mapping.h b/src/DNS_Mapping.h new file mode 100644 index 0000000000..ef56ade2b3 --- /dev/null +++ b/src/DNS_Mapping.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include + +#include "zeek/IPAddr.h" +#include "zeek/Val.h" + +namespace zeek::detail + { + +class DNS_Mapping + { +public: + DNS_Mapping() = delete; + DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl); + DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); + DNS_Mapping(FILE* f); + + bool NoMapping() const { return no_mapping; } + bool InitFailed() const { return init_failed; } + + ~DNS_Mapping() = default; + + // Returns nil if this was an address request. + // TODO: fix this an uses of this to just return the empty string + const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); } + const IPAddr& ReqAddr() const { return req_addr; } + std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; } + + ListValPtr Addrs(); + TableValPtr AddrsSet(); // addresses returned as a set + StringValPtr Host(); + + double CreationTime() const { return creation_time; } + + void Save(FILE* f) const; + + bool Failed() const { return failed; } + bool Valid() const { return ! failed; } + + bool Expired() const + { + if ( ! req_host.empty() && addrs.empty() ) + return false; // nothing to expire + + return util::current_time() > (creation_time + req_ttl); + } + + int Type() const { return map_type; } + +protected: + friend class DNS_Mgr; + + void Init(struct hostent* h); + void Clear(); + + std::string req_host; + IPAddr req_addr; + uint32_t req_ttl = 0; + + // This class supports multiple names per address, but we only store one of them. + std::vector names; + StringValPtr host_val; + + std::vector addrs; + ListValPtr addrs_val; + + double creation_time = 0.0; + int map_type = 0; + bool no_mapping = false; // when initializing from a file, immediately hit EOF + bool init_failed = false; + bool failed = false; + }; + + } // namespace zeek::detail diff --git a/src/DNS_Mgr.cc b/src/DNS_Mgr.cc index 8d776b7351..4b3b4b9201 100644 --- a/src/DNS_Mgr.cc +++ b/src/DNS_Mgr.cc @@ -27,7 +27,10 @@ #endif #include #include +#include +#include "zeek/3rdparty/doctest.h" +#include "zeek/DNS_Mapping.h" #include "zeek/Event.h" #include "zeek/Expr.h" #include "zeek/Hash.h" @@ -102,272 +105,6 @@ int DNS_Mgr_Request::MakeRequest(nb_dns_info* nb_dns) } } -class DNS_Mapping - { -public: - DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl); - DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); - DNS_Mapping(FILE* f); - - bool NoMapping() const { return no_mapping; } - bool InitFailed() const { return init_failed; } - - ~DNS_Mapping(); - - // Returns nil if this was an address request. - const char* ReqHost() const { return req_host; } - IPAddr ReqAddr() const { return req_addr; } - string ReqStr() const { return req_host ? req_host : req_addr.AsString(); } - - ListValPtr Addrs(); - TableValPtr AddrsSet(); // addresses returned as a set - StringValPtr Host(); - - double CreationTime() const { return creation_time; } - - void Save(FILE* f) const; - - bool Failed() const { return failed; } - bool Valid() const { return ! failed; } - - bool Expired() const - { - if ( req_host && num_addrs == 0 ) - return false; // nothing to expire - - return util::current_time() > (creation_time + req_ttl); - } - - int Type() const { return map_type; } - -protected: - friend class DNS_Mgr; - - void Init(struct hostent* h); - void Clear(); - - char* req_host; - IPAddr req_addr; - uint32_t req_ttl; - - int num_names; - char** names; - StringValPtr host_val; - - int num_addrs; - IPAddr* addrs; - ListValPtr addrs_val; - - double creation_time; - int map_type; - bool no_mapping; // when initializing from a file, immediately hit EOF - bool init_failed; - bool failed; - }; - -void DNS_Mgr_mapping_delete_func(void* v) - { - delete (DNS_Mapping*)v; - } - -static TableValPtr empty_addr_set() - { - auto addr_t = base_type(TYPE_ADDR); - auto set_index = make_intrusive(addr_t); - set_index->Append(std::move(addr_t)); - auto s = make_intrusive(std::move(set_index), nullptr); - return make_intrusive(std::move(s)); - } - -DNS_Mapping::DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl) - { - Init(h); - req_host = util::copy_string(host); - req_ttl = ttl; - - if ( names && ! names[0] ) - names[0] = util::copy_string(host); - } - -DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) - { - Init(h); - req_addr = addr; - req_host = nullptr; - req_ttl = ttl; - } - -DNS_Mapping::DNS_Mapping(FILE* f) - { - Clear(); - init_failed = true; - - req_host = nullptr; - req_ttl = 0; - creation_time = 0; - - char buf[512]; - - if ( ! fgets(buf, sizeof(buf), f) ) - { - no_mapping = true; - return; - } - - char req_buf[512 + 1], name_buf[512 + 1]; - int is_req_host; - int failed_local; - - if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, - &failed_local, name_buf, &map_type, &num_addrs, &req_ttl) != 8 ) - return; - - failed = static_cast(failed_local); - - if ( is_req_host ) - req_host = util::copy_string(req_buf); - else - req_addr = IPAddr(req_buf); - - num_names = 1; - names = new char*[num_names]; - names[0] = util::copy_string(name_buf); - - if ( num_addrs > 0 ) - { - addrs = new IPAddr[num_addrs]; - - for ( int i = 0; i < num_addrs; ++i ) - { - if ( ! fgets(buf, sizeof(buf), f) ) - { - num_addrs = i; - return; - } - - char* newline = strchr(buf, '\n'); - if ( newline ) - *newline = '\0'; - - addrs[i] = IPAddr(buf); - } - } - else - addrs = nullptr; - - init_failed = false; - } - -DNS_Mapping::~DNS_Mapping() - { - delete[] req_host; - - if ( names ) - { - for ( int i = 0; i < num_names; ++i ) - delete[] names[i]; - delete[] names; - } - - delete[] addrs; - } - -ListValPtr DNS_Mapping::Addrs() - { - if ( failed ) - return nullptr; - - if ( ! addrs_val ) - { - addrs_val = make_intrusive(TYPE_ADDR); - - for ( int i = 0; i < num_addrs; ++i ) - addrs_val->Append(make_intrusive(addrs[i])); - } - - return addrs_val; - } - -TableValPtr DNS_Mapping::AddrsSet() - { - auto l = Addrs(); - - if ( ! l ) - return empty_addr_set(); - - return l->ToSetVal(); - } - -StringValPtr DNS_Mapping::Host() - { - if ( failed || num_names == 0 || ! names[0] ) - return nullptr; - - if ( ! host_val ) - host_val = make_intrusive(names[0]); - - return host_val; - } - -void DNS_Mapping::Init(struct hostent* h) - { - no_mapping = false; - init_failed = false; - creation_time = util::current_time(); - host_val = nullptr; - addrs_val = nullptr; - - if ( ! h ) - { - Clear(); - return; - } - - map_type = h->h_addrtype; - num_names = 1; // for now, just use official name - names = new char*[num_names]; - names[0] = h->h_name ? util::copy_string(h->h_name) : nullptr; - - for ( num_addrs = 0; h->h_addr_list[num_addrs]; ++num_addrs ) - ; - - if ( num_addrs > 0 ) - { - addrs = new IPAddr[num_addrs]; - for ( int i = 0; i < num_addrs; ++i ) - if ( h->h_addrtype == AF_INET ) - addrs[i] = IPAddr(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network); - else if ( h->h_addrtype == AF_INET6 ) - addrs[i] = IPAddr(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network); - } - else - addrs = nullptr; - - failed = false; - } - -void DNS_Mapping::Clear() - { - num_names = num_addrs = 0; - names = nullptr; - addrs = nullptr; - host_val = nullptr; - addrs_val = nullptr; - no_mapping = false; - map_type = 0; - failed = true; - } - -void DNS_Mapping::Save(FILE* f) const - { - fprintf(f, "%.0f %d %s %d %s %d %d %" PRIu32 "\n", creation_time, req_host != nullptr, - req_host ? req_host : req_addr.AsString().c_str(), failed, - (names && names[0]) ? names[0] : "*", map_type, num_addrs, req_ttl); - - for ( int i = 0; i < num_addrs; ++i ) - fprintf(f, "%s\n", addrs[i].AsString().c_str()); - } - DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) { did_init = false; @@ -410,6 +147,8 @@ void DNS_Mgr::InitSource() nb_dns = nb_dns_init(err); else { + // nb_dns expects a sockaddr, so copy the address out of the IPAddr + // object into one so it can be passed. struct sockaddr_storage ss = {0}; if ( dns_resolver_addr.GetFamily() == IPv4 ) @@ -430,7 +169,7 @@ void DNS_Mgr::InitSource() if ( nb_dns ) { - if ( ! iosource_mgr->RegisterFd(nb_dns_fd(nb_dns), this) ) + if ( ! doctest::is_running_in_test && ! iosource_mgr->RegisterFd(nb_dns_fd(nb_dns), this) ) reporter->FatalError("Failed to register nb_dns file descriptor with iosource_mgr"); } else @@ -443,10 +182,18 @@ void DNS_Mgr::InitSource() void DNS_Mgr::InitPostScript() { - dm_rec = id::find_type("dns_mapping"); + if ( ! doctest::is_running_in_test ) + { + dm_rec = id::find_type("dns_mapping"); - // Registering will call Init() - iosource_mgr->Register(this, true); + // Registering will call Init() + iosource_mgr->Register(this, true); + } + else + { + // This would normally be called when registering the iosource above. + InitSource(); + } const char* cache_dir = dir ? dir : "."; cache_name = new char[strlen(cache_dir) + 64]; @@ -535,10 +282,16 @@ TableValPtr DNS_Mgr::LookupHost(const char* name) } } -ValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) +StringValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) { + if ( mode == DNS_FAKE ) + return make_intrusive(fake_addr_lookup_result(addr)); + InitSource(); + if ( ! nb_dns ) + return make_intrusive(""); + if ( mode != DNS_PRIME ) { AddrMap::iterator it = addr_mappings.find(addr); @@ -703,6 +456,9 @@ void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* old_dm, DNS_Mapping* new_dm) ValPtr DNS_Mgr::BuildMappingVal(DNS_Mapping* dm) { + if ( ! dm_rec ) + return nullptr; + auto r = make_intrusive(dm_rec); r->AssignTime(0, dm->CreationTime()); @@ -962,7 +718,7 @@ const char* DNS_Mgr::LookupAddrInCache(const IPAddr& addr) // The escapes in the following strings are to avoid having it // interpreted as a trigraph sequence. - return d->names ? d->names[0] : "<\?\?\?>"; + return d->names.empty() ? "<\?\?\?>" : d->names[0].c_str(); } TableValPtr DNS_Mgr::LookupNameInCache(const string& name) @@ -977,7 +733,7 @@ TableValPtr DNS_Mgr::LookupNameInCache(const string& name) DNS_Mapping* d4 = it->second.first; DNS_Mapping* d6 = it->second.second; - if ( ! d4 || ! d4->names || ! d6 || ! d6->names ) + if ( ! d4 || d4->names.empty() || ! d6 || d6->names.empty() ) return nullptr; if ( d4->Expired() || d6->Expired() ) @@ -1011,19 +767,27 @@ const char* DNS_Mgr::LookupTextInCache(const string& name) // The escapes in the following strings are to avoid having it // interpreted as a trigraph sequence. - return d->names ? d->names[0] : "<\?\?\?>"; + return d->names.empty() ? "<\?\?\?>" : d->names[0].c_str(); } static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, TableValPtr result) { callback->Resolved(result.get()); - delete callback; + + // Don't delete this if testing because we need it to look at the results of the + // request. It'll get deleted by the test when finished. + if ( ! doctest::is_running_in_test ) + delete callback; } static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, const char* result) { callback->Resolved(result); - delete callback; + + // Don't delete this if testing because we need it to look at the results of the + // request. It'll get deleted by the test when finished. + if ( ! doctest::is_running_in_test ) + delete callback; } void DNS_Mgr::AsyncLookupAddr(const IPAddr& host, LookupCallback* callback) @@ -1445,4 +1209,351 @@ void DNS_Mgr::Terminate() iosource_mgr->UnregisterFd(nb_dns_fd(nb_dns), this); } +void DNS_Mgr::TestProcess() + { + // Only allow usage of this method when running unit tests. + assert(doctest::is_running_in_test); + Process(); + } + +void DNS_Mgr::AsyncRequest::Resolved(const char* name) + { + for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i ) + { + (*i)->Resolved(name); + if ( ! doctest::is_running_in_test ) + delete *i; + } + + callbacks.clear(); + processed = true; + } + +void DNS_Mgr::AsyncRequest::Resolved(TableVal* addrs) + { + for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i ) + { + (*i)->Resolved(addrs); + if ( ! doctest::is_running_in_test ) + delete *i; + } + + callbacks.clear(); + processed = true; + } + +void DNS_Mgr::AsyncRequest::Timeout() + { + for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i ) + { + (*i)->Timeout(); + if ( ! doctest::is_running_in_test ) + delete *i; + } + + callbacks.clear(); + processed = true; + } + +TableValPtr DNS_Mgr::empty_addr_set() + { + auto addr_t = base_type(TYPE_ADDR); + auto set_index = make_intrusive(addr_t); + set_index->Append(std::move(addr_t)); + auto s = make_intrusive(std::move(set_index), nullptr); + return make_intrusive(std::move(s)); + } + +////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////// + +static std::vector get_result_addresses(TableVal* addrs) + { + std::vector results; + + auto m = addrs->ToMap(); + for ( const auto& [k, v] : m ) + { + auto lv = cast_intrusive(k); + auto lvv = lv->Vals(); + for ( const auto& addr : lvv ) + { + auto addr_ptr = cast_intrusive(addr); + results.push_back(addr_ptr->Get()); + } + } + + return results; + } + +class TestCallback : public DNS_Mgr::LookupCallback + { +public: + TestCallback() { } + void Resolved(const char* name) override + { + host_result = name; + done = true; + } + void Resolved(TableVal* addrs) override + { + addr_results = get_result_addresses(addrs); + done = true; + } + void Timeout() override + { + timeout = true; + done = true; + } + + std::string host_result; + std::vector addr_results; + bool done = false; + bool timeout = false; + }; + +TEST_CASE("dns_mgr prime,save,load") + { + char prefix[] = "/tmp/zeek-unit-test-XXXXXX"; + auto tmpdir = mkdtemp(prefix); + + // Create a manager to prime the cache, make a few requests, and the save + // the result. This tests that the priming code will create the requests but + // wait for Resolve() to actually make the requests. + DNS_Mgr mgr(DNS_PRIME); + mgr.SetDir(tmpdir); + mgr.InitPostScript(); + + auto host_result = mgr.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + CHECK(host_result->EqualTo(DNS_Mgr::empty_addr_set())); + + IPAddr ones("1.1.1.1"); + auto addr_result = mgr.LookupAddr(ones); + CHECK(strcmp(addr_result->CheckString(), "") == 0); + + mgr.Verify(); + mgr.Resolve(); + + // Save off the resulting values from Resolve() into a file on disk + // in the tmpdir created by mkdtemp. + REQUIRE(mgr.Save()); + + // Make a second DNS manager and reload the cache that we just saved. + DNS_Mgr mgr2(DNS_FORCE); + mgr2.SetDir(tmpdir); + mgr2.InitPostScript(); + + // Make the same two requests, but verify that we're correctly getting + // data out of the cache. + host_result = mgr2.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + CHECK_FALSE(host_result->EqualTo(DNS_Mgr::empty_addr_set())); + + addr_result = mgr2.LookupAddr(ones); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); + } + +TEST_CASE("dns_mgr alternate server") + { + char* old_server = getenv("ZEEK_DNS_RESOLVER"); + + setenv("ZEEK_DNS_RESOLVER", "1.1.1.1", 1); + DNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); + + auto result = mgr.LookupAddr("1.1.1.1"); + REQUIRE(result != nullptr); + CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0); + + // FIXME: This won't run on systems without IPv6 connectivity. + // setenv("ZEEK_DNS_RESOLVER", "2606:4700:4700::1111", 1); + // DNS_Mgr mgr2(DNS_DEFAULT, true); + // mgr2.InitPostScript(); + // result = mgr2.LookupAddr("1.1.1.1"); + // mgr2.Verify(); + // mgr2.Resolve(); + + // result = mgr2.LookupAddr("1.1.1.1"); + // CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0); + + if ( old_server ) + setenv("ZEEK_DNS_RESOLVER", old_server, 1); + else + unsetenv("ZEEK_DNS_RESOLVER"); + } + +TEST_CASE("dns_mgr default mode") + { + DNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); + + IPAddr ones("1.1.1.1"); + auto host_result = mgr.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + CHECK_FALSE(host_result->EqualTo(DNS_Mgr::empty_addr_set())); + + auto addrs_from_request = get_result_addresses(host_result.get()); + auto it = std::find(addrs_from_request.begin(), addrs_from_request.end(), ones); + CHECK(it != addrs_from_request.end()); + + auto addr_result = mgr.LookupAddr(ones); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); + + IPAddr bad("240.0.0.0"); + addr_result = mgr.LookupAddr(bad); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "240.0.0.0") == 0); + } + +TEST_CASE("dns_mgr async host") + { + DNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); + + TestCallback cb{}; + mgr.AsyncLookupName("one.one.one.one", &cb); + + // This shouldn't take any longer than DNS_TIMEOUT+1 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) + { + mgr.TestProcess(); + sleep(1); + if ( ! cb.timeout ) + count++; + } + + REQUIRE(count < (DNS_TIMEOUT + 1)); + if ( ! cb.timeout ) + { + REQUIRE_FALSE(cb.addr_results.empty()); + IPAddr ones("1.1.1.1"); + auto it = std::find(cb.addr_results.begin(), cb.addr_results.end(), ones); + CHECK(it != cb.addr_results.end()); + } + + mgr.Flush(); + } + +TEST_CASE("dns_mgr async addr") + { + DNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); + + TestCallback cb{}; + mgr.AsyncLookupAddr(IPAddr{"1.1.1.1"}, &cb); + + // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) + { + mgr.TestProcess(); + sleep(1); + if ( ! cb.timeout ) + count++; + } + + REQUIRE(count < (DNS_TIMEOUT + 1)); + if ( ! cb.timeout ) + REQUIRE(cb.host_result == "one.one.one.one"); + + mgr.Flush(); + } + +TEST_CASE("dns_mgr async text") + { + DNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); + + TestCallback cb{}; + mgr.AsyncLookupNameText("unittest.zeek.org", &cb); + + // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) + { + mgr.TestProcess(); + sleep(1); + if ( ! cb.timeout ) + count++; + } + + REQUIRE(count < (DNS_TIMEOUT + 1)); + if ( ! cb.timeout ) + REQUIRE(cb.host_result == "testing dns_mgr"); + + mgr.Flush(); + } + +TEST_CASE("dns_mgr timeouts") + { + char* old_server = getenv("ZEEK_DNS_RESOLVER"); + + // This is the address for blackhole.webpagetest.org, which provides a DNS + // server that lets you connect but never returns any responses, always + // resulting in a timeout. + setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); + DNS_Mgr mgr(DNS_DEFAULT); + dns_mgr = &mgr; + + mgr.InitPostScript(); + auto addr_result = mgr.LookupAddr("1.1.1.1"); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "1.1.1.1") == 0); + + auto host_result = mgr.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + auto addresses = get_result_addresses(host_result.get()); + CHECK(addresses.size() == 0); + + if ( old_server ) + setenv("ZEEK_DNS_RESOLVER", old_server, 1); + else + unsetenv("ZEEK_DNS_RESOLVER"); + } + +TEST_CASE("dns_mgr async timeouts") + { + char* old_server = getenv("ZEEK_DNS_RESOLVER"); + + // This is the address for blackhole.webpagetest.org, which provides a DNS + // server that lets you connect but never returns any responses, always + // resulting in a timeout. + setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); + DNS_Mgr mgr(DNS_DEFAULT); + dns_mgr = &mgr; + mgr.InitPostScript(); + + TestCallback cb{}; + mgr.AsyncLookupNameText("unittest.zeek.org", &cb); + + // This shouldn't take any longer than DNS_TIMEOUT +2 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) + { + mgr.TestProcess(); + sleep(1); + if ( ! cb.timeout ) + count++; + } + + REQUIRE(count < (DNS_TIMEOUT + 1)); + CHECK(cb.timeout); + + mgr.Flush(); + + if ( old_server ) + setenv("ZEEK_DNS_RESOLVER", old_server, 1); + else + unsetenv("ZEEK_DNS_RESOLVER"); + } + } // namespace zeek::detail diff --git a/src/DNS_Mgr.h b/src/DNS_Mgr.h index ef9b61db74..a8c4f6f048 100644 --- a/src/DNS_Mgr.h +++ b/src/DNS_Mgr.h @@ -21,11 +21,13 @@ class RecordType; class Val; class ListVal; class TableVal; +class StringVal; template class IntrusivePtr; using ValPtr = IntrusivePtr; using ListValPtr = IntrusivePtr; using TableValPtr = IntrusivePtr; +using StringValPtr = IntrusivePtr; } // namespace zeek @@ -65,7 +67,7 @@ public: // a set of addr. TableValPtr LookupHost(const char* host); - ValPtr LookupAddr(const IPAddr& addr); + StringValPtr LookupAddr(const IPAddr& addr); // Define the directory where to store the data. void SetDir(const char* arg_dir) { dir = util::copy_string(arg_dir); } @@ -109,6 +111,14 @@ public: void Terminate(); + static TableValPtr empty_addr_set(); + + /** + * This method is used to call the private Process() method during unit testing + * and shouldn't be used otherwise. + */ + void TestProcess(); + protected: friend class LookupCallback; friend class DNS_Mgr_Request; @@ -183,38 +193,9 @@ protected: bool IsAddrReq() const { return name.empty(); } - void Resolved(const char* name) - { - for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i ) - { - (*i)->Resolved(name); - delete *i; - } - callbacks.clear(); - processed = true; - } - - void Resolved(TableVal* addrs) - { - for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i ) - { - (*i)->Resolved(addrs); - delete *i; - } - callbacks.clear(); - processed = true; - } - - void Timeout() - { - for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i ) - { - (*i)->Timeout(); - delete *i; - } - callbacks.clear(); - processed = true; - } + void Resolved(const char* name); + void Resolved(TableVal* addrs); + void Timeout(); }; using AsyncRequestAddrMap = std::map; diff --git a/src/Val.h b/src/Val.h index afa77c8818..b7cae59690 100644 --- a/src/Val.h +++ b/src/Val.h @@ -824,6 +824,7 @@ public: // so errors can arise for compound sets such as sets-of-sets. // See https://github.com/zeek/zeek/issues/151. bool EqualTo(const TableVal& v) const; + bool EqualTo(const TableValPtr& v) const { return EqualTo(*(v.get())); } // Returns true if this set is a subset (not necessarily proper) // of the given set.