Add new features to IOSource::Manager, used by DNS_Mgr

- iosource_mgr can now track write events to file descriptors as well
  as read events. This adds an argument to both RegisterFd() and
  UnregisterFd() for setting the mode, defaulting to read.
- IOSources can now implement a ProcessFd() method that allows them to
  handle events to single file descriptors instead of of having to
  loop through/track sets of them at processing time.
This commit is contained in:
Tim Wojtulewicz 2022-04-05 21:14:04 -07:00
parent c2bf602d94
commit f9f37b11c6
6 changed files with 179 additions and 63 deletions

View file

@ -537,10 +537,10 @@ static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, in
static void sock_cb(void* data, int s, int read, int write)
{
auto mgr = reinterpret_cast<DNS_Mgr*>(data);
mgr->RegisterSocket(s, read == 1);
mgr->RegisterSocket(s, read == 1, write == 1);
}
DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) : mode(arg_mode)
DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) : IOSource(true), mode(arg_mode)
{
ares_library_init(ARES_LIB_INIT_ALL);
}
@ -554,17 +554,28 @@ DNS_Mgr::~DNS_Mgr()
ares_library_cleanup();
}
void DNS_Mgr::RegisterSocket(int fd, bool active)
void DNS_Mgr::RegisterSocket(int fd, bool read, bool write)
{
if ( active && socket_fds.count(fd) == 0 )
if ( read && socket_fds.count(fd) == 0 )
{
socket_fds.insert(fd);
iosource_mgr->RegisterFd(fd, this);
iosource_mgr->RegisterFd(fd, this, IOSource::READ);
}
else if ( ! active && socket_fds.count(fd) != 0 )
else if ( ! read && socket_fds.count(fd) != 0 )
{
socket_fds.erase(fd);
iosource_mgr->UnregisterFd(fd, this);
iosource_mgr->UnregisterFd(fd, this, IOSource::READ);
}
if ( write && write_socket_fds.count(fd) == 0 )
{
write_socket_fds.insert(fd);
iosource_mgr->RegisterFd(fd, this, IOSource::WRITE);
}
else if ( ! write && write_socket_fds.count(fd) != 0 )
{
write_socket_fds.erase(fd);
iosource_mgr->UnregisterFd(fd, this, IOSource::WRITE);
}
}
@ -1385,20 +1396,13 @@ double DNS_Mgr::GetNextTimeout()
return run_state::network_time + DNS_TIMEOUT;
}
void DNS_Mgr::Process()
void DNS_Mgr::ProcessFd(int fd, int flags)
{
// If iosource_mgr says that we got a result on the socket fd, we don't have to ask c-ares
// to retrieve it for us. We have the file descriptor already, just call ares_process_fd()
// with it. Unfortunately, we may also have sockets close during this call, so we need to
// to make a copy of the list first. Having a list change while looping over it can
// cause segfaults.
decltype(socket_fds) temp_fds{socket_fds};
for ( int fd : temp_fds )
{
// double check this one wasn't removed already before trying to process it
if ( socket_fds.count(fd) != 0 )
ares_process_fd(channel, fd, ARES_SOCKET_BAD);
{
int read_fd = (flags & IOSource::ProcessFlags::READ) != 0 ? fd : ARES_SOCKET_BAD;
int write_fd = (flags & IOSource::ProcessFlags::WRITE) != 0 ? fd : ARES_SOCKET_BAD;
ares_process_fd(channel, read_fd, write_fd);
}
IssueAsyncRequests();

View file

@ -240,7 +240,9 @@ public:
/**
* Used by the c-ares socket call back to register/unregister a socket file descriptor.
*/
void RegisterSocket(int fd, bool active);
void RegisterSocket(int fd, bool read, bool write);
ares_channel& GetChannel() { return channel; }
protected:
friend class LookupCallback;
@ -277,7 +279,8 @@ protected:
void IssueAsyncRequests();
// IOSource interface.
void Process() override;
void Process() override { }
void ProcessFd(int fd, int flags) override;
void InitSource() override;
const char* Tag() override { return "DNS_Mgr"; }
double GetNextTimeout() override;
@ -337,6 +340,7 @@ protected:
unsigned long failed = 0;
std::set<int> socket_fds;
std::set<int> write_socket_fds;
};
extern DNS_Mgr* dns_mgr;

View file

@ -283,7 +283,7 @@ void run_loop()
{
util::detail::set_processing_status("RUNNING", "run_loop");
std::vector<iosource::IOSource*> ready;
iosource::Manager::ReadySources ready;
ready.reserve(iosource_mgr->TotalSize());
while ( iosource_mgr->Size() || (BifConst::exit_only_after_terminate && ! terminating) )
@ -310,11 +310,16 @@ void run_loop()
if ( ! ready.empty() )
{
for ( auto src : ready )
for ( const auto& src : ready )
{
DBG_LOG(DBG_MAINLOOP, "processing source %s", src->Tag());
current_iosrc = src;
src->Process();
auto* iosrc = src.src;
DBG_LOG(DBG_MAINLOOP, "processing source %s", iosrc->Tag());
current_iosrc = iosrc;
if ( iosrc->ImplementsProcessFd() && src.fd != -1 )
iosrc->ProcessFd(src.fd, src.flags);
else
iosrc->Process();
}
}
else if ( (have_pending_timers || communication_enabled ||

View file

@ -12,10 +12,20 @@ namespace zeek::iosource
class IOSource
{
public:
enum ProcessFlags
{
READ = 0x01,
WRITE = 0x02
};
/**
* Constructor.
*
* @param process_fd A flag for indicating whether the child class implements
* the ProcessFd() method. This is used by the run loop for dispatching to the
* appropriate process method.
*/
IOSource() { closed = false; }
IOSource(bool process_fd = false) : implements_process_fd(process_fd) { }
/**
* Destructor.
@ -66,6 +76,19 @@ public:
*/
virtual void Process() = 0;
/**
* Optional process method that allows an IOSource to only process
* the file descriptor that is found ready and not every possible
* descriptor. If this method is implemented, true must be passed
* to the IOSource constructor via the child class.
*
* @param fd The file descriptor to process.
* @param flags Flags indicating what type of event is being
* processed.
*/
virtual void ProcessFd(int fd, int flags) { }
bool ImplementsProcessFd() const { return implements_process_fd; }
/**
* Returns a descriptive tag representing the source for debugging.
*
@ -84,7 +107,8 @@ protected:
void SetClosed(bool is_closed) { closed = is_closed; }
private:
bool closed;
bool closed = false;
bool implements_process_fd = false;
};
} // namespace zeek::iosource

View file

@ -103,7 +103,7 @@ void Manager::Wakeup(const std::string& where)
wakeup->Ping(where);
}
void Manager::FindReadySources(std::vector<IOSource*>* ready)
void Manager::FindReadySources(ReadySources* ready)
{
ready->clear();
@ -155,7 +155,7 @@ void Manager::FindReadySources(std::vector<IOSource*>* ready)
if ( timeout == 0 && ! time_to_poll )
{
added = true;
ready->push_back(timeout_src);
ready->push_back({timeout_src, -1, 0});
}
}
@ -167,13 +167,13 @@ void Manager::FindReadySources(std::vector<IOSource*>* ready)
// Avoid calling Poll() if we can help it since on very
// high-traffic networks, we spend too much time in
// Poll() and end up dropping packets.
ready->push_back(pkt_src);
ready->push_back({pkt_src, -1, 0});
}
else
{
if ( ! run_state::pseudo_realtime && ! time_to_poll )
// A pcap file is always ready to process unless it's suspended
ready->push_back(pkt_src);
ready->push_back({pkt_src, -1, 0});
}
}
}
@ -189,7 +189,7 @@ void Manager::FindReadySources(std::vector<IOSource*>* ready)
Poll(ready, timeout, timeout_src);
}
void Manager::Poll(std::vector<IOSource*>* ready, double timeout, IOSource* timeout_src)
void Manager::Poll(ReadySources* ready, double timeout, IOSource* timeout_src)
{
struct timespec kqueue_timeout;
ConvertTimeout(timeout, kqueue_timeout);
@ -205,7 +205,7 @@ void Manager::Poll(std::vector<IOSource*>* ready, double timeout, IOSource* time
else if ( ret == 0 )
{
if ( timeout_src )
ready->push_back(timeout_src);
ready->push_back({timeout_src, -1, 0});
}
else
{
@ -217,7 +217,15 @@ void Manager::Poll(std::vector<IOSource*>* ready, double timeout, IOSource* time
{
std::map<int, IOSource*>::const_iterator it = fd_map.find(events[i].ident);
if ( it != fd_map.end() )
ready->push_back(it->second);
ready->push_back({it->second, static_cast<int>(events[i].ident),
IOSource::ProcessFlags::READ});
}
else if ( events[i].filter == EVFILT_WRITE )
{
std::map<int, IOSource*>::const_iterator it = write_fd_map.find(events[i].ident);
if ( it != write_fd_map.end() )
ready->push_back({it->second, static_cast<int>(events[i].ident),
IOSource::ProcessFlags::WRITE});
}
}
}
@ -240,48 +248,106 @@ void Manager::ConvertTimeout(double timeout, struct timespec& spec)
}
}
bool Manager::RegisterFd(int fd, IOSource* src)
bool Manager::RegisterFd(int fd, IOSource* src, int flags)
{
struct kevent event;
EV_SET(&event, fd, EVFILT_READ, EV_ADD, 0, 0, NULL);
int ret = kevent(event_queue, &event, 1, NULL, 0, NULL);
std::vector<struct kevent> new_events;
if ( (flags & IOSource::READ) != 0 )
{
if ( fd_map.count(fd) == 0 )
{
new_events.push_back({});
EV_SET(&(new_events.back()), fd, EVFILT_READ, EV_ADD, 0, 0, NULL);
}
}
if ( (flags & IOSource::WRITE) != 0 )
{
if ( write_fd_map.count(fd) == 0 )
{
new_events.push_back({});
EV_SET(&(new_events.back()), fd, EVFILT_WRITE, EV_ADD, 0, 0, NULL);
}
}
if ( ! new_events.empty() )
{
int ret = kevent(event_queue, new_events.data(), new_events.size(), NULL, 0, NULL);
if ( ret != -1 )
{
events.push_back({});
DBG_LOG(DBG_MAINLOOP, "Registered fd %d from %s", fd, src->Tag());
for ( const auto& a : new_events )
events.push_back({});
if ( (flags & IOSource::READ) != 0 )
fd_map[fd] = src;
if ( (flags & IOSource::WRITE) != 0 )
write_fd_map[fd] = src;
Wakeup("RegisterFd");
return true;
}
else
{
reporter->Error("Failed to register fd %d from %s: %s", fd, src->Tag(), strerror(errno));
reporter->Error("Failed to register fd %d from %s: %s (flags %d)", fd, src->Tag(),
strerror(errno), flags);
return false;
}
}
bool Manager::UnregisterFd(int fd, IOSource* src)
{
if ( fd_map.find(fd) != fd_map.end() )
{
struct kevent event;
EV_SET(&event, fd, EVFILT_READ, EV_DELETE, 0, 0, NULL);
int ret = kevent(event_queue, &event, 1, NULL, 0, NULL);
if ( ret != -1 )
DBG_LOG(DBG_MAINLOOP, "Unregistered fd %d from %s", fd, src->Tag());
return true;
}
bool Manager::UnregisterFd(int fd, IOSource* src, int flags)
{
std::vector<struct kevent> new_events;
if ( (flags & IOSource::READ) != 0 )
{
if ( fd_map.count(fd) != 0 )
{
new_events.push_back({});
EV_SET(&(new_events.back()), fd, EVFILT_READ, EV_DELETE, 0, 0, NULL);
}
}
if ( (flags & IOSource::WRITE) != 0 )
{
if ( write_fd_map.count(fd) != 0 )
{
new_events.push_back({});
EV_SET(&(new_events.back()), fd, EVFILT_WRITE, EV_DELETE, 0, 0, NULL);
}
}
if ( ! new_events.empty() )
{
int ret = kevent(event_queue, new_events.data(), new_events.size(), NULL, 0, NULL);
if ( ret != -1 )
{
DBG_LOG(DBG_MAINLOOP, "Unregistered fd %d from %s", fd, src->Tag());
for ( const auto& a : new_events )
events.pop_back();
if ( (flags & IOSource::READ) != 0 )
fd_map.erase(fd);
if ( (flags & IOSource::WRITE) != 0 )
write_fd_map.erase(fd);
Wakeup("UnregisterFd");
return true;
}
// We don't care about failure here. If it failed to unregister, it's likely because
// the file descriptor was already closed, and kqueue already automatically removed
// it.
}
else
{
reporter->Error("Attempted to unregister an unknown file descriptor %d from %s", fd,
src->Tag());
return false;
}
return true;
}
void Manager::Register(IOSource* src, bool dont_count, bool manage_lifetime)

View file

@ -29,6 +29,15 @@ class PktDumper;
class Manager
{
public:
struct ReadySource
{
IOSource* src = nullptr;
int fd = -1;
int flags = 0;
};
using ReadySources = std::vector<ReadySource>;
/**
* Constructor.
*/
@ -110,22 +119,25 @@ public:
*
* @param ready A vector used to return the set of sources that are ready.
*/
void FindReadySources(std::vector<IOSource*>* ready);
void FindReadySources(ReadySources* ready);
/**
* Registers a file descriptor and associated IOSource with the manager
* to be checked during FindReadySources.
* to be checked during FindReadySources. This will register the file
* descriptor to check for read events.
*
* @param fd A file descriptor pointing at some resource that should be
* checked for readiness.
* @param src The IOSource that owns the file descriptor.
* @param flags A combination of values from IOSource::ProcessFlags for
* which modes we should register for this file descriptor.
*/
bool RegisterFd(int fd, IOSource* src);
bool RegisterFd(int fd, IOSource* src, int flags = IOSource::READ);
/**
* Unregisters a file descriptor from the FindReadySources checks.
*/
bool UnregisterFd(int fd, IOSource* src);
bool UnregisterFd(int fd, IOSource* src, int flags = IOSource::READ);
/**
* Forces the poll in FindReadySources to wake up immediately. This method
@ -147,7 +159,7 @@ private:
* @param timeout_src The source associated with the current timeout value.
* This is typically a timer manager object.
*/
void Poll(std::vector<IOSource*>* ready, double timeout, IOSource* timeout_src);
void Poll(ReadySources* ready, double timeout, IOSource* timeout_src);
/**
* Converts a double timeout value into a timespec struct used for calls
@ -208,6 +220,7 @@ private:
int event_queue = -1;
std::map<int, IOSource*> fd_map;
std::map<int, IOSource*> write_fd_map;
// This is only used for the output of the call to kqueue in FindReadySources().
// The actual events are stored as part of the queue.