diff --git a/.gitmodules b/.gitmodules index 0c5cccce1a..cb20683187 100644 --- a/.gitmodules +++ b/.gitmodules @@ -55,3 +55,6 @@ [submodule "auxil/c-ares"] path = auxil/c-ares url = https://github.com/c-ares/c-ares +[submodule "auxil/out_ptr"] + path = auxil/out_ptr + url = https://github.com/soasis/out_ptr.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 402a0311a9..5f7df8ffdb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -481,6 +481,7 @@ include(GetArchitecture) include(RequireCXX17) include(FindKqueue) include(FindCAres) +include_directories(BEFORE "auxil/out_ptr/include") if ( (OPENSSL_VERSION VERSION_EQUAL "1.1.0") OR (OPENSSL_VERSION VERSION_GREATER "1.1.0") ) set(ZEEK_HAVE_OPENSSL_1_1 true CACHE INTERNAL "" FORCE) diff --git a/auxil/out_ptr b/auxil/out_ptr new file mode 160000 index 0000000000..ea379b2f35 --- /dev/null +++ b/auxil/out_ptr @@ -0,0 +1 @@ +Subproject commit ea379b2f35e28d6ee894e05ad4c26ed60a613d30 diff --git a/src/DNS_Mgr.cc b/src/DNS_Mgr.cc index ab96a8421f..ed7f0ffee1 100644 --- a/src/DNS_Mgr.cc +++ b/src/DNS_Mgr.cc @@ -25,6 +25,9 @@ #include #endif +#include +using ztd::out_ptr::out_ptr; + #include #include #include @@ -173,6 +176,15 @@ static const char* request_type_string(int request_type) } } +struct ares_deleter + { + void operator()(char* s) const { ares_free_string(s); } + void operator()(unsigned char* s) const { ares_free_string(s); } + void operator()(ares_addrinfo* s) const { ares_freeaddrinfo(s); } + void operator()(struct hostent* h) const { ares_free_hostent(h); } + void operator()(struct ares_txt_reply* h) const { ares_free_data(h); } + }; + namespace zeek::detail { static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result); @@ -205,7 +217,7 @@ private: IPAddr addr; int request_type = 0; // Query type bool async = false; - unsigned char* query = nullptr; + std::unique_ptr query; static uint16_t request_id; }; @@ -225,11 +237,7 @@ DNS_Request::DNS_Request(const IPAddr& addr, bool async) : addr(addr), async(asy request_type = T_PTR; } -DNS_Request::~DNS_Request() - { - if ( query ) - ares_free_string(query); - } +DNS_Request::~DNS_Request() { } void DNS_Request::MakeRequest(ares_channel channel, DNS_Mgr* mgr) { @@ -258,17 +266,18 @@ void DNS_Request::MakeRequest(ares_channel channel, DNS_Mgr* mgr) else query_host = host; - unsigned char* query = NULL; + std::unique_ptr query_str; int len = 0; int status = ares_create_query(query_host.c_str(), C_IN, request_type, - DNS_Request::request_id, 1, &query, &len, 0); + DNS_Request::request_id, 1, + out_ptr(query_str), &len, 0); if ( status != ARES_SUCCESS ) return; // Store this so it can be destroyed when the request is destroyed. - this->query = query; - ares_send(channel, query, len, query_cb, req_data); + this->query = std::move(query_str); + ares_send(channel, this->query.get(), len, query_cb, req_data); } } @@ -305,27 +314,22 @@ static int get_ttl(unsigned char* abuf, int alen, int* ttl) { int status; long len; - char* hostname = NULL; + std::unique_ptr hostname; *ttl = DNS_TIMEOUT; unsigned char* aptr = abuf + HFIXEDSZ; - status = ares_expand_name(aptr, abuf, alen, &hostname, &len); + status = ares_expand_name(aptr, abuf, alen, out_ptr(hostname), &len); if ( status != ARES_SUCCESS ) - { - ares_free_string(hostname); return status; - } + if ( aptr + len + QFIXEDSZ > abuf + alen ) - { - ares_free_string(hostname); return ARES_EBADRESP; - } aptr += len + QFIXEDSZ; - ares_free_string(hostname); + hostname.reset(); - status = ares_expand_name(aptr, abuf, alen, &hostname, &len); + status = ares_expand_name(aptr, abuf, alen, out_ptr(hostname), &len); if ( status != ARES_SUCCESS ) return status; @@ -333,8 +337,6 @@ static int get_ttl(unsigned char* abuf, int alen, int* ttl) return ARES_EBADRESP; aptr += len; - ares_free_string(hostname); - *ttl = DNS_RR_TTL(aptr); return status; @@ -348,6 +350,7 @@ static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinf { auto arg_data = reinterpret_cast(arg); const auto [req, mgr] = *arg_data; + std::unique_ptr res_ptr(result); if ( status != ARES_SUCCESS ) { @@ -387,8 +390,9 @@ static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinf // Push a null on the end so the addr list has a final point during later parsing. addrs.push_back(NULL); - struct hostent he; - memset(&he, 0, sizeof(struct hostent)); + struct hostent he + { + }; he.h_name = util::copy_string(result->name); he.h_addrtype = AF_INET; he.h_length = sizeof(in_addr); @@ -404,8 +408,9 @@ static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinf // Push a null on the end so the addr list has a final point during later parsing. addrs6.push_back(NULL); - struct hostent he; - memset(&he, 0, sizeof(struct hostent)); + struct hostent he + { + }; he.h_name = util::copy_string(result->name); he.h_addrtype = AF_INET6; he.h_length = sizeof(in6_addr); @@ -418,9 +423,10 @@ static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinf } req->ProcessAsyncResult(timeouts > 0, mgr); - ares_freeaddrinfo(result); - delete arg_data; + + // TODO: might need to turn these into unique_ptr as well? delete req; + delete arg_data; } static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, int len) @@ -455,25 +461,24 @@ static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, in { case T_PTR: { - struct hostent* he; + std::unique_ptr he; if ( req->Addr().GetFamily() == IPv4 ) { struct in_addr addr; req->Addr().CopyIPv4(&addr); - status = ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET, &he); + status = ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET, + out_ptr(he)); } else { struct in6_addr addr; req->Addr().CopyIPv6(&addr); - status = ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET6, &he); + status = ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET6, + out_ptr(he)); } if ( status == ARES_SUCCESS ) - { - mgr->AddResult(req, he, ttl); - ares_free_hostent(he); - } + mgr->AddResult(req, he.get(), ttl); else { // See above for why DNS_TIMEOUT here. @@ -483,8 +488,8 @@ static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, in } case T_TXT: { - struct ares_txt_reply* reply; - int r = ares_parse_txt_reply(buf, len, &reply); + std::unique_ptr reply; + int r = ares_parse_txt_reply(buf, len, out_ptr(reply)); if ( r == ARES_SUCCESS ) { // Use a hostent to send the data into AddResult(). We only care about @@ -495,13 +500,13 @@ static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, in // rest away. There really isn't a good reason for this, we just haven't // ever done so. It would likely require some changes to the output from // Lookup(), since right now it only returns one value. - struct hostent he; - memset(&he, 0, sizeof(struct hostent)); + struct hostent he + { + }; he.h_name = util::copy_string(reinterpret_cast(reply->txt)); mgr->AddResult(req, &he, ttl); delete[] he.h_name; - ares_free_data(reply); } else {