Use shared_ptrs for DNS_Mapping objects

This commit is contained in:
Tim Wojtulewicz 2022-03-28 14:59:48 -07:00
parent b531ec97ef
commit c4cac72fd7
4 changed files with 32 additions and 36 deletions

View file

@ -171,7 +171,7 @@ void DNS_Mapping::Save(FILE* f) const
fprintf(f, "%s\n", addr.AsString().c_str());
}
void DNS_Mapping::Merge(DNS_Mapping* other)
void DNS_Mapping::Merge(const DNS_MappingPtr& other)
{
std::copy(other->names.begin(), other->names.end(), std::back_inserter(names));
std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs));

View file

@ -11,6 +11,9 @@
namespace zeek::detail
{
class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Mapping
{
public:
@ -51,7 +54,7 @@ public:
return util::current_time() > (creation_time + req_ttl);
}
void Merge(DNS_Mapping* other);
void Merge(const DNS_MappingPtr& other);
static void InitializeCache(FILE* f);
static bool ValidateCacheVersion(FILE* f);

View file

@ -961,25 +961,25 @@ void DNS_Mgr::Resolve()
}
}
void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* dm)
void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& dm)
{
if ( e )
event_mgr.Enqueue(e, BuildMappingVal(dm));
}
void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* dm, ListValPtr l1, ListValPtr l2)
void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2)
{
if ( e )
event_mgr.Enqueue(e, BuildMappingVal(dm), l1->ToSetVal(), l2->ToSetVal());
}
void DNS_Mgr::Event(EventHandlerPtr e, DNS_Mapping* old_dm, DNS_Mapping* new_dm)
void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm)
{
if ( e )
event_mgr.Enqueue(e, BuildMappingVal(old_dm), BuildMappingVal(new_dm));
}
ValPtr DNS_Mgr::BuildMappingVal(DNS_Mapping* dm)
ValPtr DNS_Mgr::BuildMappingVal(const DNS_MappingPtr& dm)
{
if ( ! dm_rec )
return nullptr;
@ -1002,14 +1002,14 @@ void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool m
{
// TODO: the existing code doesn't handle hostname aliases at all. Should we?
DNS_Mapping* new_mapping = nullptr;
DNS_Mapping* prev_mapping = nullptr;
DNS_MappingPtr new_mapping = nullptr;
DNS_MappingPtr prev_mapping = nullptr;
bool keep_prev = true;
MappingMap::iterator it;
if ( dr->RequestType() == T_PTR )
{
new_mapping = new DNS_Mapping(dr->Addr(), h, ttl);
new_mapping = std::make_shared<DNS_Mapping>(dr->Addr(), h, ttl);
it = all_mappings.find(dr->Addr());
if ( it == all_mappings.end() )
{
@ -1021,7 +1021,7 @@ void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool m
}
else
{
new_mapping = new DNS_Mapping(dr->Host(), h, ttl, dr->RequestType());
new_mapping = std::make_shared<DNS_Mapping>(dr->Host(), h, ttl, dr->RequestType());
auto key = std::make_pair(dr->RequestType(), dr->Host());
it = all_mappings.find(key);
@ -1055,12 +1055,12 @@ void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool m
CompareMappings(prev_mapping, new_mapping);
if ( keep_prev )
delete new_mapping;
new_mapping.reset();
else
delete prev_mapping;
prev_mapping.reset();
}
void DNS_Mgr::CompareMappings(DNS_Mapping* prev_mapping, DNS_Mapping* new_mapping)
void DNS_Mgr::CompareMappings(const DNS_MappingPtr& prev_mapping, const DNS_MappingPtr& new_mapping)
{
if ( prev_mapping->Failed() )
{
@ -1100,14 +1100,14 @@ void DNS_Mgr::CompareMappings(DNS_Mapping* prev_mapping, DNS_Mapping* new_mappin
return;
}
auto prev_delta = AddrListDelta(prev_a.get(), new_a.get());
auto new_delta = AddrListDelta(new_a.get(), prev_a.get());
auto prev_delta = AddrListDelta(prev_a, new_a);
auto new_delta = AddrListDelta(new_a, prev_a);
if ( prev_delta->Length() > 0 || new_delta->Length() > 0 )
Event(dns_mapping_altered, new_mapping, std::move(prev_delta), std::move(new_delta));
}
ListValPtr DNS_Mgr::AddrListDelta(ListVal* al1, ListVal* al2)
ListValPtr DNS_Mgr::AddrListDelta(ListValPtr al1, ListValPtr al2)
{
auto delta = make_intrusive<ListVal>(TYPE_ADDR);
@ -1142,8 +1142,8 @@ void DNS_Mgr::LoadCache(const std::string& path)
return;
// Loop until we find a mapping that doesn't initialize correctly.
DNS_Mapping* m = new DNS_Mapping(f);
for ( ; ! m->NoMapping() && ! m->InitFailed(); m = new DNS_Mapping(f) )
auto m = std::make_shared<DNS_Mapping>(f);
for ( ; ! m->NoMapping() && ! m->InitFailed(); m = std::make_shared<DNS_Mapping>(f) )
{
if ( m->ReqHost() )
all_mappings.insert_or_assign(std::make_pair(m->ReqType(), m->ReqHost()), m);
@ -1154,7 +1154,6 @@ void DNS_Mgr::LoadCache(const std::string& path)
if ( ! m->NoMapping() )
reporter->FatalError("DNS cache corrupted");
delete m;
fclose(f);
}
@ -1192,7 +1191,7 @@ TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_exp
if ( it == all_mappings.end() )
return nullptr;
DNS_Mapping* d = it->second;
auto d = it->second;
if ( ! d || d->names.empty() )
return nullptr;
@ -1200,7 +1199,6 @@ TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_exp
if ( cleanup_expired && (d && d->Expired()) )
{
all_mappings.erase(it);
delete d;
return nullptr;
}
@ -1219,12 +1217,11 @@ StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired
if ( it == all_mappings.end() )
return nullptr;
DNS_Mapping* d = it->second;
auto d = it->second;
if ( cleanup_expired && d->Expired() )
{
all_mappings.erase(it);
delete d;
return nullptr;
}
else if ( check_failed && d->Failed() )
@ -1247,12 +1244,11 @@ StringValPtr DNS_Mgr::LookupOtherInCache(const std::string& name, int request_ty
if ( it == all_mappings.end() )
return nullptr;
DNS_Mapping* d = it->second;
auto d = it->second;
if ( cleanup_expired && d->Expired() )
{
all_mappings.erase(it);
delete d;
return nullptr;
}
@ -1373,10 +1369,6 @@ void DNS_Mgr::CheckAsyncOtherRequest(const std::string& host, bool timeout, int
void DNS_Mgr::Flush()
{
Resolve();
for ( auto& [key, mapping] : all_mappings )
delete mapping;
all_mappings.clear();
}

View file

@ -45,6 +45,7 @@ using StringValPtr = IntrusivePtr<StringVal>;
namespace zeek::detail
{
class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Request;
enum DNS_MgrMode
@ -258,17 +259,17 @@ protected:
void CheckAsyncHostRequest(const std::string& host, bool timeout);
void CheckAsyncOtherRequest(const std::string& host, bool timeout, int request_type);
void Event(EventHandlerPtr e, DNS_Mapping* dm);
void Event(EventHandlerPtr e, DNS_Mapping* dm, ListValPtr l1, ListValPtr l2);
void Event(EventHandlerPtr e, DNS_Mapping* old_dm, DNS_Mapping* new_dm);
void Event(EventHandlerPtr e, const DNS_MappingPtr& dm);
void Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2);
void Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm);
ValPtr BuildMappingVal(DNS_Mapping* dm);
ValPtr BuildMappingVal(const DNS_MappingPtr& dm);
void CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm);
ListValPtr AddrListDelta(ListVal* al1, ListVal* al2);
void CompareMappings(const DNS_MappingPtr& prev_dm, const DNS_MappingPtr& new_dm);
ListValPtr AddrListDelta(ListValPtr al1, ListValPtr al2);
using MappingKey = std::variant<IPAddr, std::pair<int, std::string>>;
using MappingMap = std::map<MappingKey, DNS_Mapping*>;
using MappingMap = std::map<MappingKey, DNS_MappingPtr>;
void LoadCache(const std::string& path);
void Save(FILE* f, const MappingMap& m);