Add unit testing for DNS_Mgr and related classes

This commit is contained in:
Tim Wojtulewicz 2022-01-14 10:21:26 -07:00
parent 824bc372c5
commit e6e9144da6
8 changed files with 870 additions and 311 deletions

2
cmake

@ -1 +1 @@
Subproject commit b1c9e55d1e837d46fc36312b567245013cf9646a Subproject commit 649c319f88e2966931892d55adb2ee50f278662b

@ -1 +1 @@
Subproject commit 0af190f90572abc90366471f36e6feb1b817d2ab Subproject commit 6cbb3d65877f80326c047364583f506ce58758ba

View file

@ -286,6 +286,7 @@ set(MAIN_SRCS
Desc.cc Desc.cc
Dict.cc Dict.cc
Discard.cc Discard.cc
DNS_Mapping.cc
DNS_Mgr.cc DNS_Mgr.cc
EquivClass.cc EquivClass.cc
Event.cc Event.cc

387
src/DNS_Mapping.cc Normal file
View file

@ -0,0 +1,387 @@
#include "zeek/DNS_Mapping.h"
#include "zeek/3rdparty/doctest.h"
#include "zeek/DNS_Mgr.h"
namespace zeek::detail
{
DNS_Mapping::DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl)
{
Init(h);
req_host = host;
req_ttl = ttl;
if ( names.empty() )
names.push_back(host);
}
DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl)
{
Init(h);
req_addr = addr;
req_ttl = ttl;
}
DNS_Mapping::DNS_Mapping(FILE* f)
{
Clear();
init_failed = true;
req_ttl = 0;
creation_time = 0;
char buf[512];
if ( ! fgets(buf, sizeof(buf), f) )
{
no_mapping = true;
return;
}
char req_buf[512 + 1], name_buf[512 + 1];
int is_req_host;
int failed_local;
int num_addrs;
if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf,
&failed_local, name_buf, &map_type, &num_addrs, &req_ttl) != 8 )
return;
failed = static_cast<bool>(failed_local);
if ( is_req_host )
req_host = req_buf;
else
req_addr = IPAddr(req_buf);
names.push_back(name_buf);
for ( int i = 0; i < num_addrs; ++i )
{
if ( ! fgets(buf, sizeof(buf), f) )
return;
char* newline = strchr(buf, '\n');
if ( newline )
*newline = '\0';
addrs.emplace_back(IPAddr(buf));
}
init_failed = false;
}
ListValPtr DNS_Mapping::Addrs()
{
if ( failed )
return nullptr;
if ( ! addrs_val )
{
addrs_val = make_intrusive<ListVal>(TYPE_ADDR);
for ( const auto& addr : addrs )
addrs_val->Append(make_intrusive<AddrVal>(addr));
}
return addrs_val;
}
TableValPtr DNS_Mapping::AddrsSet()
{
auto l = Addrs();
if ( ! l || l->Length() == 0 )
return DNS_Mgr::empty_addr_set();
return l->ToSetVal();
}
StringValPtr DNS_Mapping::Host()
{
if ( failed || names.empty() )
return nullptr;
if ( ! host_val )
host_val = make_intrusive<StringVal>(names[0]);
return host_val;
}
void DNS_Mapping::Init(struct hostent* h)
{
no_mapping = false;
init_failed = false;
creation_time = util::current_time();
host_val = nullptr;
addrs_val = nullptr;
if ( ! h )
{
Clear();
return;
}
map_type = h->h_addrtype;
if ( h->h_name )
// for now, just use the official name
// TODO: this could easily be expanded to include all of the aliases as well
names.push_back(h->h_name);
for ( int i = 0; h->h_addr_list[i] != NULL; ++i )
{
if ( h->h_addrtype == AF_INET )
addrs.push_back(IPAddr(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network));
else if ( h->h_addrtype == AF_INET6 )
addrs.push_back(IPAddr(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network));
}
failed = false;
}
void DNS_Mapping::Clear()
{
names.clear();
host_val = nullptr;
addrs.clear();
addrs_val = nullptr;
no_mapping = false;
map_type = 0;
failed = true;
}
void DNS_Mapping::Save(FILE* f) const
{
fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(),
req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed,
names.empty() ? "*" : names[0].c_str(), map_type, addrs.size(), req_ttl);
for ( const auto& addr : addrs )
fprintf(f, "%s\n", addr.AsString().c_str());
}
//////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////
TEST_CASE("dns_mapping init null hostent")
{
DNS_Mapping mapping("www.apple.com", nullptr, 123);
CHECK(! mapping.Valid());
CHECK(mapping.Addrs() == nullptr);
CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set()));
CHECK(mapping.Host() == nullptr);
}
TEST_CASE("dns_mapping init host")
{
IPAddr addr("1.2.3.4");
in4_addr in4;
addr.CopyIPv4(&in4);
struct hostent he;
he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL;
he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping("testing.home", &he, 123);
CHECK(mapping.Valid());
CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified);
CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0);
CHECK(mapping.ReqStr() == "testing.home");
auto lva = mapping.Addrs();
REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4");
auto tvas = mapping.AddrsSet();
REQUIRE(tvas != nullptr);
CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set()));
auto svh = mapping.Host();
REQUIRE(svh != nullptr);
CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name;
}
TEST_CASE("dns_mapping init addr")
{
IPAddr addr("1.2.3.4");
in4_addr in4;
addr.CopyIPv4(&in4);
struct hostent he;
he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL;
he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping(addr, &he, 123);
CHECK(mapping.Valid());
CHECK(mapping.ReqAddr() == addr);
CHECK(mapping.ReqHost() == nullptr);
CHECK(mapping.ReqStr() == "1.2.3.4");
auto lva = mapping.Addrs();
REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4");
auto tvas = mapping.AddrsSet();
REQUIRE(tvas != nullptr);
CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set()));
auto svh = mapping.Host();
REQUIRE(svh != nullptr);
CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name;
}
TEST_CASE("dns_mapping save reload")
{
IPAddr addr("1.2.3.4");
in4_addr in4;
addr.CopyIPv4(&in4);
struct hostent he;
he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL;
he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data());
// Create a temporary file in memory and fseek to the end of it so we're at
// EOF for the next bit.
char buffer[4096];
memset(buffer, 0, 4096);
FILE* tmpfile = fmemopen(buffer, 4096, "r+");
fseek(tmpfile, 0, SEEK_END);
// Try loading from the file at EOF. This should cause a mapping failure.
DNS_Mapping mapping(tmpfile);
CHECK(mapping.NoMapping());
rewind(tmpfile);
// Try reading from the empty file. This should cause an init failure.
DNS_Mapping mapping2(tmpfile);
CHECK(mapping2.InitFailed());
rewind(tmpfile);
// Save a valid mapping into the file and rewind to the start.
DNS_Mapping mapping3(addr, &he, 123);
mapping3.Save(tmpfile);
rewind(tmpfile);
// Test loading the mapping back out of the file
DNS_Mapping mapping4(tmpfile);
fclose(tmpfile);
CHECK(mapping4.Valid());
CHECK(mapping4.ReqAddr() == addr);
CHECK(mapping4.ReqHost() == nullptr);
CHECK(mapping4.ReqStr() == "1.2.3.4");
auto lva = mapping4.Addrs();
REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4");
auto tvas = mapping4.AddrsSet();
REQUIRE(tvas != nullptr);
CHECK(tvas != DNS_Mgr::empty_addr_set());
auto svh = mapping4.Host();
REQUIRE(svh != nullptr);
CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name;
}
TEST_CASE("dns_mapping multiple addresses")
{
IPAddr addr("1.2.3.4");
in4_addr in4_1;
addr.CopyIPv4(&in4_1);
IPAddr addr2("5.6.7.8");
in4_addr in4_2;
addr2.CopyIPv4(&in4_2);
struct hostent he;
he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL;
he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4_1, &in4_2, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping("testing.home", &he, 123);
CHECK(mapping.Valid());
auto lva = mapping.Addrs();
REQUIRE(lva != nullptr);
CHECK(lva->Length() == 2);
auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4");
lvae = lva->Idx(1)->AsAddrVal();
REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "5.6.7.8");
delete[] he.h_name;
}
TEST_CASE("dns_mapping ipv6")
{
IPAddr addr("64:ff9b:1::");
in6_addr in6;
addr.CopyIPv6(&in6);
struct hostent he;
he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL;
he.h_addrtype = AF_INET6;
he.h_length = sizeof(in6_addr);
std::vector<in6_addr*> addrs = {&in6, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping(addr, &he, 123);
CHECK(mapping.Valid());
CHECK(mapping.ReqAddr() == addr);
CHECK(mapping.ReqHost() == nullptr);
CHECK(mapping.ReqStr() == "64:ff9b:1::");
auto lva = mapping.Addrs();
REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "64:ff9b:1::");
delete[] he.h_name;
}
} // namespace zeek::detail

78
src/DNS_Mapping.h Normal file
View file

@ -0,0 +1,78 @@
#pragma once
#include <netdb.h>
#include <sys/socket.h>
#include <cstdint>
#include <string>
#include "zeek/IPAddr.h"
#include "zeek/Val.h"
namespace zeek::detail
{
class DNS_Mapping
{
public:
DNS_Mapping() = delete;
DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl);
DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl);
DNS_Mapping(FILE* f);
bool NoMapping() const { return no_mapping; }
bool InitFailed() const { return init_failed; }
~DNS_Mapping() = default;
// Returns nil if this was an address request.
// TODO: fix this an uses of this to just return the empty string
const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); }
const IPAddr& ReqAddr() const { return req_addr; }
std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; }
ListValPtr Addrs();
TableValPtr AddrsSet(); // addresses returned as a set
StringValPtr Host();
double CreationTime() const { return creation_time; }
void Save(FILE* f) const;
bool Failed() const { return failed; }
bool Valid() const { return ! failed; }
bool Expired() const
{
if ( ! req_host.empty() && addrs.empty() )
return false; // nothing to expire
return util::current_time() > (creation_time + req_ttl);
}
int Type() const { return map_type; }
protected:
friend class DNS_Mgr;
void Init(struct hostent* h);
void Clear();
std::string req_host;
IPAddr req_addr;
uint32_t req_ttl = 0;
// This class supports multiple names per address, but we only store one of them.
std::vector<std::string> names;
StringValPtr host_val;
std::vector<IPAddr> addrs;
ListValPtr addrs_val;
double creation_time = 0.0;
int map_type = 0;
bool no_mapping = false; // when initializing from a file, immediately hit EOF
bool init_failed = false;
bool failed = false;
};
} // namespace zeek::detail

View file

@ -27,7 +27,10 @@
#endif #endif
#include <stdlib.h> #include <stdlib.h>
#include <algorithm> #include <algorithm>
#include <vector>
#include "zeek/3rdparty/doctest.h"
#include "zeek/DNS_Mapping.h"
#include "zeek/Event.h" #include "zeek/Event.h"
#include "zeek/Expr.h" #include "zeek/Expr.h"
#include "zeek/Hash.h" #include "zeek/Hash.h"
@ -102,272 +105,6 @@ int DNS_Mgr_Request::MakeRequest(nb_dns_info* nb_dns)
} }
} }
class DNS_Mapping
{
public:
DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl);
DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl);
DNS_Mapping(FILE* f);
bool NoMapping() const { return no_mapping; }
bool InitFailed() const { return init_failed; }
~DNS_Mapping();
// Returns nil if this was an address request.
const char* ReqHost() const { return req_host; }
IPAddr ReqAddr() const { return req_addr; }
string ReqStr() const { return req_host ? req_host : req_addr.AsString(); }
ListValPtr Addrs();
TableValPtr AddrsSet(); // addresses returned as a set
StringValPtr Host();
double CreationTime() const { return creation_time; }
void Save(FILE* f) const;
bool Failed() const { return failed; }
bool Valid() const { return ! failed; }
bool Expired() const
{
if ( req_host && num_addrs == 0 )
return false; // nothing to expire
return util::current_time() > (creation_time + req_ttl);
}
int Type() const { return map_type; }
protected:
friend class DNS_Mgr;
void Init(struct hostent* h);
void Clear();
char* req_host;
IPAddr req_addr;
uint32_t req_ttl;
int num_names;
char** names;
StringValPtr host_val;
int num_addrs;
IPAddr* addrs;
ListValPtr addrs_val;
double creation_time;
int map_type;
bool no_mapping; // when initializing from a file, immediately hit EOF
bool init_failed;
bool failed;
};
void DNS_Mgr_mapping_delete_func(void* v)
{
delete (DNS_Mapping*)v;
}
static TableValPtr empty_addr_set()
{
auto addr_t = base_type(TYPE_ADDR);
auto set_index = make_intrusive<TypeList>(addr_t);
set_index->Append(std::move(addr_t));
auto s = make_intrusive<SetType>(std::move(set_index), nullptr);
return make_intrusive<TableVal>(std::move(s));
}
DNS_Mapping::DNS_Mapping(const char* host, struct hostent* h, uint32_t ttl)
{
Init(h);
req_host = util::copy_string(host);
req_ttl = ttl;
if ( names && ! names[0] )
names[0] = util::copy_string(host);
}
DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl)
{
Init(h);
req_addr = addr;
req_host = nullptr;
req_ttl = ttl;
}
DNS_Mapping::DNS_Mapping(FILE* f)
{
Clear();
init_failed = true;
req_host = nullptr;
req_ttl = 0;
creation_time = 0;
char buf[512];
if ( ! fgets(buf, sizeof(buf), f) )
{
no_mapping = true;
return;
}
char req_buf[512 + 1], name_buf[512 + 1];
int is_req_host;
int failed_local;
if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf,
&failed_local, name_buf, &map_type, &num_addrs, &req_ttl) != 8 )
return;
failed = static_cast<bool>(failed_local);
if ( is_req_host )
req_host = util::copy_string(req_buf);
else
req_addr = IPAddr(req_buf);
num_names = 1;
names = new char*[num_names];
names[0] = util::copy_string(name_buf);
if ( num_addrs > 0 )
{
addrs = new IPAddr[num_addrs];
for ( int i = 0; i < num_addrs; ++i )
{
if ( ! fgets(buf, sizeof(buf), f) )
{
num_addrs = i;
return;
}
char* newline = strchr(buf, '\n');
if ( newline )
*newline = '\0';
addrs[i] = IPAddr(buf);
}
}
else
addrs = nullptr;
init_failed = false;
}
DNS_Mapping::~DNS_Mapping()
{
delete[] req_host;
if ( names )
{
for ( int i = 0; i < num_names; ++i )
delete[] names[i];
delete[] names;
}
delete[] addrs;
}
ListValPtr DNS_Mapping::Addrs()
{
if ( failed )
return nullptr;
if ( ! addrs_val )
{
addrs_val = make_intrusive<ListVal>(TYPE_ADDR);
for ( int i = 0; i < num_addrs; ++i )
addrs_val->Append(make_intrusive<AddrVal>(addrs[i]));
}
return addrs_val;
}
TableValPtr DNS_Mapping::AddrsSet()
{
auto l = Addrs();
if ( ! l )
return empty_addr_set();
return l->ToSetVal();
}
StringValPtr DNS_Mapping::Host()
{
if ( failed || num_names == 0 || ! names[0] )
return nullptr;
if ( ! host_val )
host_val = make_intrusive<StringVal>(names[0]);
return host_val;
}
void DNS_Mapping::Init(struct hostent* h)
{
no_mapping = false;
init_failed = false;
creation_time = util::current_time();
host_val = nullptr;
addrs_val = nullptr;
if ( ! h )
{
Clear();
return;
}
map_type = h->h_addrtype;
num_names = 1; // for now, just use official name
names = new char*[num_names];
names[0] = h->h_name ? util::copy_string(h->h_name) : nullptr;
for ( num_addrs = 0; h->h_addr_list[num_addrs]; ++num_addrs )
;
if ( num_addrs > 0 )
{
addrs = new IPAddr[num_addrs];
for ( int i = 0; i < num_addrs; ++i )
if ( h->h_addrtype == AF_INET )
addrs[i] = IPAddr(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network);
else if ( h->h_addrtype == AF_INET6 )
addrs[i] = IPAddr(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network);
}
else
addrs = nullptr;
failed = false;
}
void DNS_Mapping::Clear()
{
num_names = num_addrs = 0;
names = nullptr;
addrs = nullptr;
host_val = nullptr;
addrs_val = nullptr;
no_mapping = false;
map_type = 0;
failed = true;
}
void DNS_Mapping::Save(FILE* f) const
{
fprintf(f, "%.0f %d %s %d %s %d %d %" PRIu32 "\n", creation_time, req_host != nullptr,
req_host ? req_host : req_addr.AsString().c_str(), failed,
(names && names[0]) ? names[0] : "*", map_type, num_addrs, req_ttl);
for ( int i = 0; i < num_addrs; ++i )
fprintf(f, "%s\n", addrs[i].AsString().c_str());
}
DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode)
{ {
did_init = false; did_init = false;
@ -410,6 +147,8 @@ void DNS_Mgr::InitSource()
nb_dns = nb_dns_init(err); nb_dns = nb_dns_init(err);
else else
{ {
// nb_dns expects a sockaddr, so copy the address out of the IPAddr
// object into one so it can be passed.
struct sockaddr_storage ss = {0}; struct sockaddr_storage ss = {0};
if ( dns_resolver_addr.GetFamily() == IPv4 ) if ( dns_resolver_addr.GetFamily() == IPv4 )
@ -430,7 +169,7 @@ void DNS_Mgr::InitSource()
if ( nb_dns ) if ( nb_dns )
{ {
if ( ! iosource_mgr->RegisterFd(nb_dns_fd(nb_dns), this) ) if ( ! doctest::is_running_in_test && ! iosource_mgr->RegisterFd(nb_dns_fd(nb_dns), this) )
reporter->FatalError("Failed to register nb_dns file descriptor with iosource_mgr"); reporter->FatalError("Failed to register nb_dns file descriptor with iosource_mgr");
} }
else else
@ -442,11 +181,19 @@ void DNS_Mgr::InitSource()
} }
void DNS_Mgr::InitPostScript() void DNS_Mgr::InitPostScript()
{
if ( ! doctest::is_running_in_test )
{ {
dm_rec = id::find_type<RecordType>("dns_mapping"); dm_rec = id::find_type<RecordType>("dns_mapping");
// Registering will call Init() // Registering will call Init()
iosource_mgr->Register(this, true); iosource_mgr->Register(this, true);
}
else
{
// This would normally be called when registering the iosource above.
InitSource();
}
const char* cache_dir = dir ? dir : "."; const char* cache_dir = dir ? dir : ".";
cache_name = new char[strlen(cache_dir) + 64]; cache_name = new char[strlen(cache_dir) + 64];
@ -535,10 +282,16 @@ TableValPtr DNS_Mgr::LookupHost(const char* name)
} }
} }
ValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) StringValPtr DNS_Mgr::LookupAddr(const IPAddr& addr)
{ {
if ( mode == DNS_FAKE )
return make_intrusive<StringVal>(fake_addr_lookup_result(addr));
InitSource(); InitSource();
if ( ! nb_dns )
return make_intrusive<StringVal>("<none>");
if ( mode != DNS_PRIME ) if ( mode != DNS_PRIME )
{ {
AddrMap::iterator it = addr_mappings.find(addr); AddrMap::iterator it = addr_mappings.find(addr);
@ -703,6 +456,9 @@ void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* old_dm, DNS_Mapping* new_dm)
ValPtr DNS_Mgr::BuildMappingVal(DNS_Mapping* dm) ValPtr DNS_Mgr::BuildMappingVal(DNS_Mapping* dm)
{ {
if ( ! dm_rec )
return nullptr;
auto r = make_intrusive<RecordVal>(dm_rec); auto r = make_intrusive<RecordVal>(dm_rec);
r->AssignTime(0, dm->CreationTime()); r->AssignTime(0, dm->CreationTime());
@ -962,7 +718,7 @@ const char* DNS_Mgr::LookupAddrInCache(const IPAddr& addr)
// The escapes in the following strings are to avoid having it // The escapes in the following strings are to avoid having it
// interpreted as a trigraph sequence. // interpreted as a trigraph sequence.
return d->names ? d->names[0] : "<\?\?\?>"; return d->names.empty() ? "<\?\?\?>" : d->names[0].c_str();
} }
TableValPtr DNS_Mgr::LookupNameInCache(const string& name) TableValPtr DNS_Mgr::LookupNameInCache(const string& name)
@ -977,7 +733,7 @@ TableValPtr DNS_Mgr::LookupNameInCache(const string& name)
DNS_Mapping* d4 = it->second.first; DNS_Mapping* d4 = it->second.first;
DNS_Mapping* d6 = it->second.second; DNS_Mapping* d6 = it->second.second;
if ( ! d4 || ! d4->names || ! d6 || ! d6->names ) if ( ! d4 || d4->names.empty() || ! d6 || d6->names.empty() )
return nullptr; return nullptr;
if ( d4->Expired() || d6->Expired() ) if ( d4->Expired() || d6->Expired() )
@ -1011,18 +767,26 @@ const char* DNS_Mgr::LookupTextInCache(const string& name)
// The escapes in the following strings are to avoid having it // The escapes in the following strings are to avoid having it
// interpreted as a trigraph sequence. // interpreted as a trigraph sequence.
return d->names ? d->names[0] : "<\?\?\?>"; return d->names.empty() ? "<\?\?\?>" : d->names[0].c_str();
} }
static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, TableValPtr result) static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, TableValPtr result)
{ {
callback->Resolved(result.get()); callback->Resolved(result.get());
// Don't delete this if testing because we need it to look at the results of the
// request. It'll get deleted by the test when finished.
if ( ! doctest::is_running_in_test )
delete callback; delete callback;
} }
static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, const char* result) static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, const char* result)
{ {
callback->Resolved(result); callback->Resolved(result);
// Don't delete this if testing because we need it to look at the results of the
// request. It'll get deleted by the test when finished.
if ( ! doctest::is_running_in_test )
delete callback; delete callback;
} }
@ -1445,4 +1209,351 @@ void DNS_Mgr::Terminate()
iosource_mgr->UnregisterFd(nb_dns_fd(nb_dns), this); iosource_mgr->UnregisterFd(nb_dns_fd(nb_dns), this);
} }
void DNS_Mgr::TestProcess()
{
// Only allow usage of this method when running unit tests.
assert(doctest::is_running_in_test);
Process();
}
void DNS_Mgr::AsyncRequest::Resolved(const char* name)
{
for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i )
{
(*i)->Resolved(name);
if ( ! doctest::is_running_in_test )
delete *i;
}
callbacks.clear();
processed = true;
}
void DNS_Mgr::AsyncRequest::Resolved(TableVal* addrs)
{
for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i )
{
(*i)->Resolved(addrs);
if ( ! doctest::is_running_in_test )
delete *i;
}
callbacks.clear();
processed = true;
}
void DNS_Mgr::AsyncRequest::Timeout()
{
for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i )
{
(*i)->Timeout();
if ( ! doctest::is_running_in_test )
delete *i;
}
callbacks.clear();
processed = true;
}
TableValPtr DNS_Mgr::empty_addr_set()
{
auto addr_t = base_type(TYPE_ADDR);
auto set_index = make_intrusive<TypeList>(addr_t);
set_index->Append(std::move(addr_t));
auto s = make_intrusive<SetType>(std::move(set_index), nullptr);
return make_intrusive<TableVal>(std::move(s));
}
//////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////
static std::vector<IPAddr> get_result_addresses(TableVal* addrs)
{
std::vector<IPAddr> results;
auto m = addrs->ToMap();
for ( const auto& [k, v] : m )
{
auto lv = cast_intrusive<ListVal>(k);
auto lvv = lv->Vals();
for ( const auto& addr : lvv )
{
auto addr_ptr = cast_intrusive<AddrVal>(addr);
results.push_back(addr_ptr->Get());
}
}
return results;
}
class TestCallback : public DNS_Mgr::LookupCallback
{
public:
TestCallback() { }
void Resolved(const char* name) override
{
host_result = name;
done = true;
}
void Resolved(TableVal* addrs) override
{
addr_results = get_result_addresses(addrs);
done = true;
}
void Timeout() override
{
timeout = true;
done = true;
}
std::string host_result;
std::vector<IPAddr> addr_results;
bool done = false;
bool timeout = false;
};
TEST_CASE("dns_mgr prime,save,load")
{
char prefix[] = "/tmp/zeek-unit-test-XXXXXX";
auto tmpdir = mkdtemp(prefix);
// Create a manager to prime the cache, make a few requests, and the save
// the result. This tests that the priming code will create the requests but
// wait for Resolve() to actually make the requests.
DNS_Mgr mgr(DNS_PRIME);
mgr.SetDir(tmpdir);
mgr.InitPostScript();
auto host_result = mgr.LookupHost("one.one.one.one");
REQUIRE(host_result != nullptr);
CHECK(host_result->EqualTo(DNS_Mgr::empty_addr_set()));
IPAddr ones("1.1.1.1");
auto addr_result = mgr.LookupAddr(ones);
CHECK(strcmp(addr_result->CheckString(), "<none>") == 0);
mgr.Verify();
mgr.Resolve();
// Save off the resulting values from Resolve() into a file on disk
// in the tmpdir created by mkdtemp.
REQUIRE(mgr.Save());
// Make a second DNS manager and reload the cache that we just saved.
DNS_Mgr mgr2(DNS_FORCE);
mgr2.SetDir(tmpdir);
mgr2.InitPostScript();
// Make the same two requests, but verify that we're correctly getting
// data out of the cache.
host_result = mgr2.LookupHost("one.one.one.one");
REQUIRE(host_result != nullptr);
CHECK_FALSE(host_result->EqualTo(DNS_Mgr::empty_addr_set()));
addr_result = mgr2.LookupAddr(ones);
REQUIRE(addr_result != nullptr);
CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0);
}
TEST_CASE("dns_mgr alternate server")
{
char* old_server = getenv("ZEEK_DNS_RESOLVER");
setenv("ZEEK_DNS_RESOLVER", "1.1.1.1", 1);
DNS_Mgr mgr(DNS_DEFAULT);
mgr.InitPostScript();
auto result = mgr.LookupAddr("1.1.1.1");
REQUIRE(result != nullptr);
CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0);
// FIXME: This won't run on systems without IPv6 connectivity.
// setenv("ZEEK_DNS_RESOLVER", "2606:4700:4700::1111", 1);
// DNS_Mgr mgr2(DNS_DEFAULT, true);
// mgr2.InitPostScript();
// result = mgr2.LookupAddr("1.1.1.1");
// mgr2.Verify();
// mgr2.Resolve();
// result = mgr2.LookupAddr("1.1.1.1");
// CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0);
if ( old_server )
setenv("ZEEK_DNS_RESOLVER", old_server, 1);
else
unsetenv("ZEEK_DNS_RESOLVER");
}
TEST_CASE("dns_mgr default mode")
{
DNS_Mgr mgr(DNS_DEFAULT);
mgr.InitPostScript();
IPAddr ones("1.1.1.1");
auto host_result = mgr.LookupHost("one.one.one.one");
REQUIRE(host_result != nullptr);
CHECK_FALSE(host_result->EqualTo(DNS_Mgr::empty_addr_set()));
auto addrs_from_request = get_result_addresses(host_result.get());
auto it = std::find(addrs_from_request.begin(), addrs_from_request.end(), ones);
CHECK(it != addrs_from_request.end());
auto addr_result = mgr.LookupAddr(ones);
REQUIRE(addr_result != nullptr);
CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0);
IPAddr bad("240.0.0.0");
addr_result = mgr.LookupAddr(bad);
REQUIRE(addr_result != nullptr);
CHECK(strcmp(addr_result->CheckString(), "240.0.0.0") == 0);
}
TEST_CASE("dns_mgr async host")
{
DNS_Mgr mgr(DNS_DEFAULT);
mgr.InitPostScript();
TestCallback cb{};
mgr.AsyncLookupName("one.one.one.one", &cb);
// This shouldn't take any longer than DNS_TIMEOUT+1 seconds, so bound it
// just in case of some failure we're not aware of yet.
int count = 0;
while ( ! cb.done && (count < DNS_TIMEOUT + 1) )
{
mgr.TestProcess();
sleep(1);
if ( ! cb.timeout )
count++;
}
REQUIRE(count < (DNS_TIMEOUT + 1));
if ( ! cb.timeout )
{
REQUIRE_FALSE(cb.addr_results.empty());
IPAddr ones("1.1.1.1");
auto it = std::find(cb.addr_results.begin(), cb.addr_results.end(), ones);
CHECK(it != cb.addr_results.end());
}
mgr.Flush();
}
TEST_CASE("dns_mgr async addr")
{
DNS_Mgr mgr(DNS_DEFAULT);
mgr.InitPostScript();
TestCallback cb{};
mgr.AsyncLookupAddr(IPAddr{"1.1.1.1"}, &cb);
// This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it
// just in case of some failure we're not aware of yet.
int count = 0;
while ( ! cb.done && (count < DNS_TIMEOUT + 1) )
{
mgr.TestProcess();
sleep(1);
if ( ! cb.timeout )
count++;
}
REQUIRE(count < (DNS_TIMEOUT + 1));
if ( ! cb.timeout )
REQUIRE(cb.host_result == "one.one.one.one");
mgr.Flush();
}
TEST_CASE("dns_mgr async text")
{
DNS_Mgr mgr(DNS_DEFAULT);
mgr.InitPostScript();
TestCallback cb{};
mgr.AsyncLookupNameText("unittest.zeek.org", &cb);
// This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it
// just in case of some failure we're not aware of yet.
int count = 0;
while ( ! cb.done && (count < DNS_TIMEOUT + 1) )
{
mgr.TestProcess();
sleep(1);
if ( ! cb.timeout )
count++;
}
REQUIRE(count < (DNS_TIMEOUT + 1));
if ( ! cb.timeout )
REQUIRE(cb.host_result == "testing dns_mgr");
mgr.Flush();
}
TEST_CASE("dns_mgr timeouts")
{
char* old_server = getenv("ZEEK_DNS_RESOLVER");
// This is the address for blackhole.webpagetest.org, which provides a DNS
// server that lets you connect but never returns any responses, always
// resulting in a timeout.
setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1);
DNS_Mgr mgr(DNS_DEFAULT);
dns_mgr = &mgr;
mgr.InitPostScript();
auto addr_result = mgr.LookupAddr("1.1.1.1");
REQUIRE(addr_result != nullptr);
CHECK(strcmp(addr_result->CheckString(), "1.1.1.1") == 0);
auto host_result = mgr.LookupHost("one.one.one.one");
REQUIRE(host_result != nullptr);
auto addresses = get_result_addresses(host_result.get());
CHECK(addresses.size() == 0);
if ( old_server )
setenv("ZEEK_DNS_RESOLVER", old_server, 1);
else
unsetenv("ZEEK_DNS_RESOLVER");
}
TEST_CASE("dns_mgr async timeouts")
{
char* old_server = getenv("ZEEK_DNS_RESOLVER");
// This is the address for blackhole.webpagetest.org, which provides a DNS
// server that lets you connect but never returns any responses, always
// resulting in a timeout.
setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1);
DNS_Mgr mgr(DNS_DEFAULT);
dns_mgr = &mgr;
mgr.InitPostScript();
TestCallback cb{};
mgr.AsyncLookupNameText("unittest.zeek.org", &cb);
// This shouldn't take any longer than DNS_TIMEOUT +2 seconds, so bound it
// just in case of some failure we're not aware of yet.
int count = 0;
while ( ! cb.done && (count < DNS_TIMEOUT + 1) )
{
mgr.TestProcess();
sleep(1);
if ( ! cb.timeout )
count++;
}
REQUIRE(count < (DNS_TIMEOUT + 1));
CHECK(cb.timeout);
mgr.Flush();
if ( old_server )
setenv("ZEEK_DNS_RESOLVER", old_server, 1);
else
unsetenv("ZEEK_DNS_RESOLVER");
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -21,11 +21,13 @@ class RecordType;
class Val; class Val;
class ListVal; class ListVal;
class TableVal; class TableVal;
class StringVal;
template <class T> class IntrusivePtr; template <class T> class IntrusivePtr;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using ListValPtr = IntrusivePtr<ListVal>; using ListValPtr = IntrusivePtr<ListVal>;
using TableValPtr = IntrusivePtr<TableVal>; using TableValPtr = IntrusivePtr<TableVal>;
using StringValPtr = IntrusivePtr<StringVal>;
} // namespace zeek } // namespace zeek
@ -65,7 +67,7 @@ public:
// a set of addr. // a set of addr.
TableValPtr LookupHost(const char* host); TableValPtr LookupHost(const char* host);
ValPtr LookupAddr(const IPAddr& addr); StringValPtr LookupAddr(const IPAddr& addr);
// Define the directory where to store the data. // Define the directory where to store the data.
void SetDir(const char* arg_dir) { dir = util::copy_string(arg_dir); } void SetDir(const char* arg_dir) { dir = util::copy_string(arg_dir); }
@ -109,6 +111,14 @@ public:
void Terminate(); void Terminate();
static TableValPtr empty_addr_set();
/**
* This method is used to call the private Process() method during unit testing
* and shouldn't be used otherwise.
*/
void TestProcess();
protected: protected:
friend class LookupCallback; friend class LookupCallback;
friend class DNS_Mgr_Request; friend class DNS_Mgr_Request;
@ -183,38 +193,9 @@ protected:
bool IsAddrReq() const { return name.empty(); } bool IsAddrReq() const { return name.empty(); }
void Resolved(const char* name) void Resolved(const char* name);
{ void Resolved(TableVal* addrs);
for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i ) void Timeout();
{
(*i)->Resolved(name);
delete *i;
}
callbacks.clear();
processed = true;
}
void Resolved(TableVal* addrs)
{
for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i )
{
(*i)->Resolved(addrs);
delete *i;
}
callbacks.clear();
processed = true;
}
void Timeout()
{
for ( CallbackList::iterator i = callbacks.begin(); i != callbacks.end(); ++i )
{
(*i)->Timeout();
delete *i;
}
callbacks.clear();
processed = true;
}
}; };
using AsyncRequestAddrMap = std::map<IPAddr, AsyncRequest*>; using AsyncRequestAddrMap = std::map<IPAddr, AsyncRequest*>;

View file

@ -824,6 +824,7 @@ public:
// so errors can arise for compound sets such as sets-of-sets. // so errors can arise for compound sets such as sets-of-sets.
// See https://github.com/zeek/zeek/issues/151. // See https://github.com/zeek/zeek/issues/151.
bool EqualTo(const TableVal& v) const; bool EqualTo(const TableVal& v) const;
bool EqualTo(const TableValPtr& v) const { return EqualTo(*(v.get())); }
// Returns true if this set is a subset (not necessarily proper) // Returns true if this set is a subset (not necessarily proper)
// of the given set. // of the given set.