diff --git a/src/DNS_Mgr.cc b/src/DNS_Mgr.cc index ed7f0ffee1..3fc2404dfa 100644 --- a/src/DNS_Mgr.cc +++ b/src/DNS_Mgr.cc @@ -537,10 +537,10 @@ static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, in static void sock_cb(void* data, int s, int read, int write) { auto mgr = reinterpret_cast(data); - mgr->RegisterSocket(s, read == 1); + mgr->RegisterSocket(s, read == 1, write == 1); } -DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) : mode(arg_mode) +DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) : IOSource(true), mode(arg_mode) { ares_library_init(ARES_LIB_INIT_ALL); } @@ -554,17 +554,28 @@ DNS_Mgr::~DNS_Mgr() ares_library_cleanup(); } -void DNS_Mgr::RegisterSocket(int fd, bool active) +void DNS_Mgr::RegisterSocket(int fd, bool read, bool write) { - if ( active && socket_fds.count(fd) == 0 ) + if ( read && socket_fds.count(fd) == 0 ) { socket_fds.insert(fd); - iosource_mgr->RegisterFd(fd, this); + iosource_mgr->RegisterFd(fd, this, IOSource::READ); } - else if ( ! active && socket_fds.count(fd) != 0 ) + else if ( ! read && socket_fds.count(fd) != 0 ) { socket_fds.erase(fd); - iosource_mgr->UnregisterFd(fd, this); + iosource_mgr->UnregisterFd(fd, this, IOSource::READ); + } + + if ( write && write_socket_fds.count(fd) == 0 ) + { + write_socket_fds.insert(fd); + iosource_mgr->RegisterFd(fd, this, IOSource::WRITE); + } + else if ( ! write && write_socket_fds.count(fd) != 0 ) + { + write_socket_fds.erase(fd); + iosource_mgr->UnregisterFd(fd, this, IOSource::WRITE); } } @@ -1385,20 +1396,13 @@ double DNS_Mgr::GetNextTimeout() return run_state::network_time + DNS_TIMEOUT; } -void DNS_Mgr::Process() +void DNS_Mgr::ProcessFd(int fd, int flags) { - // If iosource_mgr says that we got a result on the socket fd, we don't have to ask c-ares - // to retrieve it for us. We have the file descriptor already, just call ares_process_fd() - // with it. Unfortunately, we may also have sockets close during this call, so we need to - // to make a copy of the list first. Having a list change while looping over it can - // cause segfaults. - decltype(socket_fds) temp_fds{socket_fds}; - - for ( int fd : temp_fds ) + if ( socket_fds.count(fd) != 0 ) { - // double check this one wasn't removed already before trying to process it - if ( socket_fds.count(fd) != 0 ) - ares_process_fd(channel, fd, ARES_SOCKET_BAD); + int read_fd = (flags & IOSource::ProcessFlags::READ) != 0 ? fd : ARES_SOCKET_BAD; + int write_fd = (flags & IOSource::ProcessFlags::WRITE) != 0 ? fd : ARES_SOCKET_BAD; + ares_process_fd(channel, read_fd, write_fd); } IssueAsyncRequests(); diff --git a/src/DNS_Mgr.h b/src/DNS_Mgr.h index a392ee9f93..9834f68a09 100644 --- a/src/DNS_Mgr.h +++ b/src/DNS_Mgr.h @@ -240,7 +240,9 @@ public: /** * Used by the c-ares socket call back to register/unregister a socket file descriptor. */ - void RegisterSocket(int fd, bool active); + void RegisterSocket(int fd, bool read, bool write); + + ares_channel& GetChannel() { return channel; } protected: friend class LookupCallback; @@ -277,7 +279,8 @@ protected: void IssueAsyncRequests(); // IOSource interface. - void Process() override; + void Process() override { } + void ProcessFd(int fd, int flags) override; void InitSource() override; const char* Tag() override { return "DNS_Mgr"; } double GetNextTimeout() override; @@ -337,6 +340,7 @@ protected: unsigned long failed = 0; std::set socket_fds; + std::set write_socket_fds; }; extern DNS_Mgr* dns_mgr; diff --git a/src/RunState.cc b/src/RunState.cc index 30abb987e0..76ac99e79d 100644 --- a/src/RunState.cc +++ b/src/RunState.cc @@ -283,7 +283,7 @@ void run_loop() { util::detail::set_processing_status("RUNNING", "run_loop"); - std::vector ready; + iosource::Manager::ReadySources ready; ready.reserve(iosource_mgr->TotalSize()); while ( iosource_mgr->Size() || (BifConst::exit_only_after_terminate && ! terminating) ) @@ -310,11 +310,16 @@ void run_loop() if ( ! ready.empty() ) { - for ( auto src : ready ) + for ( const auto& src : ready ) { - DBG_LOG(DBG_MAINLOOP, "processing source %s", src->Tag()); - current_iosrc = src; - src->Process(); + auto* iosrc = src.src; + + DBG_LOG(DBG_MAINLOOP, "processing source %s", iosrc->Tag()); + current_iosrc = iosrc; + if ( iosrc->ImplementsProcessFd() && src.fd != -1 ) + iosrc->ProcessFd(src.fd, src.flags); + else + iosrc->Process(); } } else if ( (have_pending_timers || communication_enabled || diff --git a/src/iosource/IOSource.h b/src/iosource/IOSource.h index 8e863f78e1..16fc76a6c5 100644 --- a/src/iosource/IOSource.h +++ b/src/iosource/IOSource.h @@ -12,10 +12,20 @@ namespace zeek::iosource class IOSource { public: + enum ProcessFlags + { + READ = 0x01, + WRITE = 0x02 + }; + /** * Constructor. + * + * @param process_fd A flag for indicating whether the child class implements + * the ProcessFd() method. This is used by the run loop for dispatching to the + * appropriate process method. */ - IOSource() { closed = false; } + IOSource(bool process_fd = false) : implements_process_fd(process_fd) { } /** * Destructor. @@ -66,6 +76,19 @@ public: */ virtual void Process() = 0; + /** + * Optional process method that allows an IOSource to only process + * the file descriptor that is found ready and not every possible + * descriptor. If this method is implemented, true must be passed + * to the IOSource constructor via the child class. + * + * @param fd The file descriptor to process. + * @param flags Flags indicating what type of event is being + * processed. + */ + virtual void ProcessFd(int fd, int flags) { } + bool ImplementsProcessFd() const { return implements_process_fd; } + /** * Returns a descriptive tag representing the source for debugging. * @@ -84,7 +107,8 @@ protected: void SetClosed(bool is_closed) { closed = is_closed; } private: - bool closed; + bool closed = false; + bool implements_process_fd = false; }; } // namespace zeek::iosource diff --git a/src/iosource/Manager.cc b/src/iosource/Manager.cc index 8f669512fa..6c9876c398 100644 --- a/src/iosource/Manager.cc +++ b/src/iosource/Manager.cc @@ -103,7 +103,7 @@ void Manager::Wakeup(const std::string& where) wakeup->Ping(where); } -void Manager::FindReadySources(std::vector* ready) +void Manager::FindReadySources(ReadySources* ready) { ready->clear(); @@ -155,7 +155,7 @@ void Manager::FindReadySources(std::vector* ready) if ( timeout == 0 && ! time_to_poll ) { added = true; - ready->push_back(timeout_src); + ready->push_back({timeout_src, -1, 0}); } } @@ -167,13 +167,13 @@ void Manager::FindReadySources(std::vector* ready) // Avoid calling Poll() if we can help it since on very // high-traffic networks, we spend too much time in // Poll() and end up dropping packets. - ready->push_back(pkt_src); + ready->push_back({pkt_src, -1, 0}); } else { if ( ! run_state::pseudo_realtime && ! time_to_poll ) // A pcap file is always ready to process unless it's suspended - ready->push_back(pkt_src); + ready->push_back({pkt_src, -1, 0}); } } } @@ -189,7 +189,7 @@ void Manager::FindReadySources(std::vector* ready) Poll(ready, timeout, timeout_src); } -void Manager::Poll(std::vector* ready, double timeout, IOSource* timeout_src) +void Manager::Poll(ReadySources* ready, double timeout, IOSource* timeout_src) { struct timespec kqueue_timeout; ConvertTimeout(timeout, kqueue_timeout); @@ -205,7 +205,7 @@ void Manager::Poll(std::vector* ready, double timeout, IOSource* time else if ( ret == 0 ) { if ( timeout_src ) - ready->push_back(timeout_src); + ready->push_back({timeout_src, -1, 0}); } else { @@ -217,7 +217,15 @@ void Manager::Poll(std::vector* ready, double timeout, IOSource* time { std::map::const_iterator it = fd_map.find(events[i].ident); if ( it != fd_map.end() ) - ready->push_back(it->second); + ready->push_back({it->second, static_cast(events[i].ident), + IOSource::ProcessFlags::READ}); + } + else if ( events[i].filter == EVFILT_WRITE ) + { + std::map::const_iterator it = write_fd_map.find(events[i].ident); + if ( it != write_fd_map.end() ) + ready->push_back({it->second, static_cast(events[i].ident), + IOSource::ProcessFlags::WRITE}); } } } @@ -240,41 +248,97 @@ void Manager::ConvertTimeout(double timeout, struct timespec& spec) } } -bool Manager::RegisterFd(int fd, IOSource* src) +bool Manager::RegisterFd(int fd, IOSource* src, int flags) { - struct kevent event; - EV_SET(&event, fd, EVFILT_READ, EV_ADD, 0, 0, NULL); - int ret = kevent(event_queue, &event, 1, NULL, 0, NULL); - if ( ret != -1 ) - { - events.push_back({}); - DBG_LOG(DBG_MAINLOOP, "Registered fd %d from %s", fd, src->Tag()); - fd_map[fd] = src; + std::vector new_events; - Wakeup("RegisterFd"); - return true; - } - else + if ( (flags & IOSource::READ) != 0 ) { - reporter->Error("Failed to register fd %d from %s: %s", fd, src->Tag(), strerror(errno)); - return false; + if ( fd_map.count(fd) == 0 ) + { + new_events.push_back({}); + EV_SET(&(new_events.back()), fd, EVFILT_READ, EV_ADD, 0, 0, NULL); + } } + if ( (flags & IOSource::WRITE) != 0 ) + { + if ( write_fd_map.count(fd) == 0 ) + { + new_events.push_back({}); + EV_SET(&(new_events.back()), fd, EVFILT_WRITE, EV_ADD, 0, 0, NULL); + } + } + + if ( ! new_events.empty() ) + { + int ret = kevent(event_queue, new_events.data(), new_events.size(), NULL, 0, NULL); + if ( ret != -1 ) + { + DBG_LOG(DBG_MAINLOOP, "Registered fd %d from %s", fd, src->Tag()); + for ( const auto& a : new_events ) + events.push_back({}); + + if ( (flags & IOSource::READ) != 0 ) + fd_map[fd] = src; + if ( (flags & IOSource::WRITE) != 0 ) + write_fd_map[fd] = src; + + Wakeup("RegisterFd"); + return true; + } + else + { + reporter->Error("Failed to register fd %d from %s: %s (flags %d)", fd, src->Tag(), + strerror(errno), flags); + return false; + } + } + + return true; } -bool Manager::UnregisterFd(int fd, IOSource* src) +bool Manager::UnregisterFd(int fd, IOSource* src, int flags) { - if ( fd_map.find(fd) != fd_map.end() ) + std::vector new_events; + + if ( (flags & IOSource::READ) != 0 ) { - struct kevent event; - EV_SET(&event, fd, EVFILT_READ, EV_DELETE, 0, 0, NULL); - int ret = kevent(event_queue, &event, 1, NULL, 0, NULL); + if ( fd_map.count(fd) != 0 ) + { + new_events.push_back({}); + EV_SET(&(new_events.back()), fd, EVFILT_READ, EV_DELETE, 0, 0, NULL); + } + } + if ( (flags & IOSource::WRITE) != 0 ) + { + if ( write_fd_map.count(fd) != 0 ) + { + new_events.push_back({}); + EV_SET(&(new_events.back()), fd, EVFILT_WRITE, EV_DELETE, 0, 0, NULL); + } + } + + if ( ! new_events.empty() ) + { + int ret = kevent(event_queue, new_events.data(), new_events.size(), NULL, 0, NULL); if ( ret != -1 ) + { DBG_LOG(DBG_MAINLOOP, "Unregistered fd %d from %s", fd, src->Tag()); + for ( const auto& a : new_events ) + events.pop_back(); - fd_map.erase(fd); + if ( (flags & IOSource::READ) != 0 ) + fd_map.erase(fd); + if ( (flags & IOSource::WRITE) != 0 ) + write_fd_map.erase(fd); - Wakeup("UnregisterFd"); - return true; + Wakeup("UnregisterFd"); + return true; + } + + // We don't care about failure here. If it failed to unregister, it's likely because + // the file descriptor was already closed, and kqueue already automatically removed + // it. } else { @@ -282,6 +346,8 @@ bool Manager::UnregisterFd(int fd, IOSource* src) src->Tag()); return false; } + + return true; } void Manager::Register(IOSource* src, bool dont_count, bool manage_lifetime) diff --git a/src/iosource/Manager.h b/src/iosource/Manager.h index 5053bbd132..1e2bbbface 100644 --- a/src/iosource/Manager.h +++ b/src/iosource/Manager.h @@ -29,6 +29,15 @@ class PktDumper; class Manager { public: + struct ReadySource + { + IOSource* src = nullptr; + int fd = -1; + int flags = 0; + }; + + using ReadySources = std::vector; + /** * Constructor. */ @@ -110,22 +119,25 @@ public: * * @param ready A vector used to return the set of sources that are ready. */ - void FindReadySources(std::vector* ready); + void FindReadySources(ReadySources* ready); /** * Registers a file descriptor and associated IOSource with the manager - * to be checked during FindReadySources. + * to be checked during FindReadySources. This will register the file + * descriptor to check for read events. * * @param fd A file descriptor pointing at some resource that should be * checked for readiness. * @param src The IOSource that owns the file descriptor. + * @param flags A combination of values from IOSource::ProcessFlags for + * which modes we should register for this file descriptor. */ - bool RegisterFd(int fd, IOSource* src); + bool RegisterFd(int fd, IOSource* src, int flags = IOSource::READ); /** * Unregisters a file descriptor from the FindReadySources checks. */ - bool UnregisterFd(int fd, IOSource* src); + bool UnregisterFd(int fd, IOSource* src, int flags = IOSource::READ); /** * Forces the poll in FindReadySources to wake up immediately. This method @@ -147,7 +159,7 @@ private: * @param timeout_src The source associated with the current timeout value. * This is typically a timer manager object. */ - void Poll(std::vector* ready, double timeout, IOSource* timeout_src); + void Poll(ReadySources* ready, double timeout, IOSource* timeout_src); /** * Converts a double timeout value into a timespec struct used for calls @@ -208,6 +220,7 @@ private: int event_queue = -1; std::map fd_map; + std::map write_fd_map; // This is only used for the output of the call to kqueue in FindReadySources(). // The actual events are stored as part of the queue.