Rework DNS_Mgr API to be more consistent and to support more request types

This commit is contained in:
Tim Wojtulewicz 2022-01-11 17:07:19 -07:00
parent 336c6ae5c2
commit 9f197aa458
6 changed files with 920 additions and 721 deletions

View file

@ -2,7 +2,7 @@
#pragma once
#include <ares.h>
#include <netdb.h>
#include <list>
#include <map>
#include <queue>
@ -14,6 +14,18 @@
#include "zeek/iosource/IOSource.h"
#include "zeek/util.h"
// These are defined in ares headers but we don't want to have to include
// those headers here and create install dependencies on them.
struct ares_channeldata;
typedef struct ares_channeldata* ares_channel;
#ifndef T_PTR
#define T_PTR 12
#endif
#ifndef T_TXT
#define T_TXT 16
#endif
namespace zeek
{
class Val;
@ -31,8 +43,8 @@ using StringValPtr = IntrusivePtr<StringVal>;
namespace zeek::detail
{
class DNS_Mgr_Request;
class DNS_Mapping;
class DNS_Request;
enum DNS_MgrMode
{
@ -42,9 +54,44 @@ enum DNS_MgrMode
DNS_FAKE, // don't look up names, just return dummy results
};
class DNS_Mgr final : public iosource::IOSource
class DNS_Mgr : public iosource::IOSource
{
public:
/**
* Base class for callback handling for asynchronous lookups.
*/
class LookupCallback
{
public:
virtual ~LookupCallback() = default;
/**
* Called when an address lookup finishes.
*
* @param name The resulting name from the lookup.
*/
virtual void Resolved(const std::string& name){};
/**
* Called when a name lookup finishes.
*
* @param addrs A table of the resulting addresses from the lookup.
*/
virtual void Resolved(TableValPtr addrs){};
/**
* Generic callback method for all request types.
*
* @param val A Val containing the data from the query.
*/
virtual void Resolved(ValPtr data, int request_type) { }
/**
* Called when a timeout request occurs.
*/
virtual void Timeout() = 0;
};
explicit DNS_Mgr(DNS_MgrMode mode);
~DNS_Mgr() override;
@ -61,27 +108,79 @@ public:
void Flush();
/**
* Looks up the address(es) of a given host and returns a set of addr.
* This is a synchronous method and will block until results are ready.
* Looks up the address(es) of a given host and returns a set of addresses.
* This is a shorthand method for doing A/AAAA requests. This is a
* synchronous request and will block until the request completes or times
* out.
*
* @param host The host name to look up an address for.
* @return A set of addresses.
* @param host The hostname to lookup an address for.
* @return A set of addresses for the host.
*/
TableValPtr LookupHost(const char* host);
TableValPtr LookupHost(const std::string& host);
/**
* Looks up the hostname of a given address. This is a synchronous method
* and will block until results are ready.
* Looks up the hostname of a given address. This is a shorthand method for
* doing PTR requests. This is a synchronous request and will block until
* the request completes or times out.
*
* @param host The addr to lookup a hostname for.
* @return The hostname.
* @return The hostname for the address.
*/
StringValPtr LookupAddr(const IPAddr& addr);
/**
* Performs a generic request to the DNS server. This is a synchronous
* request and will block until the request completes or times out.
*
* @param name The name or address to make a request for. If this is an
* address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa).
* Note that calling LookupAddr for PTR requests does this conversion
* automatically.
* @param request_type The type of request to make. This should be one of
* the type values defined in arpa/nameser.h or ares_nameser.h.
* @return The requested data.
*/
ValPtr Lookup(const std::string& name, int request_type);
/**
* Looks up the address(es) of a given host. This is a shorthand method
* for doing A/AAAA requests. This is an asynchronous request. The
* response will be handled via the provided callback object.
*
* @param host The hostname to lookup an address for.
* @param callback A callback object for handling the response.
*/
void LookupHost(const std::string& host, LookupCallback* callback);
/**
* Looks up the hostname of a given address. This is a shorthand method for
* doing PTR requests. This is an asynchronous request. The response will
* be handled via the provided callback object.
*
* @param host The addr to lookup a hostname for.
* @param callback A callback object for handling the response.
*/
void LookupAddr(const IPAddr& addr, LookupCallback* callback);
/**
* Performs a generic request to the DNS server. This is an asynchronous
* request. The response will be handled via the provided callback
* object.
*
* @param name The name or address to make a request for. If this is an
* address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa).
* Note that calling LookupAddr for PTR requests does this conversion
* automatically.
* @param request_type The type of request to make. This should be one of
* the type values defined in arpa/nameser.h or ares_nameser.h.
* @param callback A callback object for handling the response.
*/
void Lookup(const std::string& name, int request_type, LookupCallback* callback);
/**
* Sets the directory where to store DNS data when Save() is called.
*/
void SetDir(const char* arg_dir) { dir = arg_dir; }
void SetDir(const std::string& arg_dir) { dir = arg_dir; }
/**
* Waits for responses to become available or a timeout to occur,
@ -94,61 +193,6 @@ public:
*/
bool Save();
/**
* Base class for callback handling for asynchronous lookups.
*/
class LookupCallback
{
public:
virtual ~LookupCallback() = default;
/**
* Called when an address lookup finishes.
*
* @param name The resulting name from the lookup.
*/
virtual void Resolved(const char* name){};
/**
* Called when a name lookup finishes.
*
* @param addrs A table of the resulting addresses from the lookup.
*/
virtual void Resolved(TableVal* addrs){};
/**
* Called when a timeout request occurs.
*/
virtual void Timeout() = 0;
};
/**
* Schedules an asynchronous request to lookup a hostname for an IP address.
* This is the equivalent of an "A" or "AAAA" request, depending on if the
* address is ipv4 or ipv6.
*
* @param host The address to lookup names for.
* @param callback A callback object to call when the request completes.
*/
void AsyncLookupAddr(const IPAddr& host, LookupCallback* callback);
/**
* Schedules an asynchronous request to lookup an address for a hostname.
* This is the equivalent of a "PTR" request.
*
* @param host The hostname to look up addresses for.
* @param callback A callback object to call when the request completes.
*/
void AsyncLookupName(const std::string& name, LookupCallback* callback);
/**
* Schedules an asynchronous TXT request for a hostname.
*
* @param host The address to lookup names for.
* @param callback A callback object to call when the request completes.
*/
void AsyncLookupNameText(const std::string& name, LookupCallback* callback);
struct Stats
{
unsigned long requests; // These count only async requests.
@ -175,7 +219,7 @@ public:
* @param h A hostent structure containing the actual result data.
* @param ttl A ttl value contained in the response from the server.
*/
void AddResult(DNS_Mgr_Request* dr, struct hostent* h, uint32_t ttl);
void AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl);
/**
* Returns an empty set of addresses, used in various error cases and during
@ -184,18 +228,30 @@ public:
static TableValPtr empty_addr_set();
/**
* This method is used to call the private Process() method during unit testing
* and shouldn't be used otherwise.
* Returns the full path to the file used to store the DNS cache.
*/
void TestProcess();
std::string CacheFile() const { return cache_name; }
/**
* Used by the c-ares socket call back to register/unregister a socket file descriptor.
*/
void RegisterSocket(int fd, bool active);
protected:
friend class LookupCallback;
friend class DNS_Mgr_Request;
friend class DNS_Request;
const char* LookupAddrInCache(const IPAddr& addr);
TableValPtr LookupNameInCache(const std::string& name);
const char* LookupTextInCache(const std::string& name);
StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false,
bool check_failed = false);
TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false,
bool check_failed = false);
StringValPtr LookupTextInCache(const std::string& name, bool cleanup_expired = false);
// Finish the request if we have a result. If not, time it out if
// requested.
void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout);
void CheckAsyncHostRequest(const std::string& host, bool timeout);
void CheckAsyncTextRequest(const std::string& host, bool timeout);
void Event(EventHandlerPtr e, DNS_Mapping* dm);
void Event(EventHandlerPtr e, DNS_Mapping* dm, ListValPtr l1, ListValPtr l2);
@ -205,7 +261,6 @@ protected:
void CompareMappings(DNS_Mapping* prev_dm, DNS_Mapping* new_dm);
ListValPtr AddrListDelta(ListVal* al1, ListVal* al2);
void DumpAddrList(FILE* f, ListVal* al);
using HostMap = std::map<std::string, std::pair<DNS_Mapping*, DNS_Mapping*>>;
using AddrMap = std::map<IPAddr, DNS_Mapping*>;
@ -217,12 +272,6 @@ protected:
// Issue as many queued async requests as slots are available.
void IssueAsyncRequests();
// Finish the request if we have a result. If not, time it out if
// requested.
void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout);
void CheckAsyncHostRequest(const char* host, bool timeout);
void CheckAsyncTextRequest(const char* host, bool timeout);
// IOSource interface.
void Process() override;
void InitSource() override;
@ -235,9 +284,6 @@ protected:
AddrMap addr_mappings;
TextMap text_mappings;
using DNS_mgr_request_list = PList<DNS_Mgr_Request>;
DNS_mgr_request_list requests;
std::string cache_name;
std::string dir; // directory in which cache_name resides
@ -247,26 +293,30 @@ protected:
RecordTypePtr dm_rec;
ares_channel channel;
bool ipv6_resolver = false;
using CallbackList = std::list<LookupCallback*>;
struct AsyncRequest
{
double time = 0.0;
IPAddr host;
std::string name;
IPAddr addr;
std::string host;
CallbackList callbacks;
bool is_txt = false;
bool processed = false;
bool IsAddrReq() const { return name.empty(); }
bool IsAddrReq() const { return host.empty(); }
void Resolved(const char* name);
void Resolved(TableVal* addrs);
void Resolved(const std::string& name);
void Resolved(TableValPtr addrs);
void Timeout();
};
struct AsyncRequestCompare
{
bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; }
};
using AsyncRequestAddrMap = std::map<IPAddr, AsyncRequest*>;
AsyncRequestAddrMap asyncs_addrs;
@ -279,18 +329,15 @@ protected:
using QueuedList = std::list<AsyncRequest*>;
QueuedList asyncs_queued;
struct AsyncRequestCompare
{
bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; }
};
using TimeoutQueue =
std::priority_queue<AsyncRequest*, std::vector<AsyncRequest*>, AsyncRequestCompare>;
TimeoutQueue asyncs_timeouts;
unsigned long num_requests;
unsigned long successful;
unsigned long failed;
unsigned long num_requests = 0;
unsigned long successful = 0;
unsigned long failed = 0;
std::set<int> socket_fds;
};
extern DNS_Mgr* dns_mgr;