Add a simple FD_Set wrapper/helper class.

This commit is contained in:
Jon Siwek 2014-09-09 16:28:04 -05:00
parent cf66bd8b69
commit 59c54a0fc6
16 changed files with 149 additions and 90 deletions

View file

@ -630,11 +630,11 @@ bool ChunkedIOFd::IsFillingUp()
return stats.pending > MAX_BUFFERED_CHUNKS_SOFT; return stats.pending > MAX_BUFFERED_CHUNKS_SOFT;
} }
std::vector<int> ChunkedIOFd::FdSupplements() const iosource::FD_Set ChunkedIOFd::ExtraReadFDs() const
{ {
std::vector<int> rval; iosource::FD_Set rval;
rval.push_back(write_flare.FD()); rval.Insert(write_flare.FD());
rval.push_back(read_flare.FD()); rval.Insert(read_flare.FD());
return rval; return rval;
} }
@ -1140,10 +1140,10 @@ bool ChunkedIOSSL::IsFillingUp()
return false; return false;
} }
std::vector<int> ChunkedIOSSL::FdSupplements() const iosource::FD_Set ChunkedIOSSL::ExtraReadFDs() const
{ {
std::vector<int> rval; iosource::FD_Set rval;
rval.push_back(write_flare.FD()); rval.Insert(write_flare.FD());
return rval; return rval;
} }

View file

@ -7,8 +7,8 @@
#include "List.h" #include "List.h"
#include "util.h" #include "util.h"
#include "Flare.h" #include "Flare.h"
#include "iosource/FD_Set.h"
#include <list> #include <list>
#include <vector>
#ifdef NEED_KRB5_H #ifdef NEED_KRB5_H
# include <krb5.h> # include <krb5.h>
@ -98,8 +98,8 @@ public:
// Returns supplementary file descriptors that become read-ready in order // Returns supplementary file descriptors that become read-ready in order
// to signal that there is some work that can be performed. // to signal that there is some work that can be performed.
virtual std::vector<int> FdSupplements() const virtual iosource::FD_Set ExtraReadFDs() const
{ return std::vector<int>(); } { return iosource::FD_Set(); }
// Makes sure that no additional protocol data is written into // Makes sure that no additional protocol data is written into
// the output stream. If this is activated, the output cannot // the output stream. If this is activated, the output cannot
@ -183,7 +183,7 @@ public:
virtual void Clear(); virtual void Clear();
virtual bool Eof() { return eof; } virtual bool Eof() { return eof; }
virtual int Fd() { return fd; } virtual int Fd() { return fd; }
virtual std::vector<int> FdSupplements() const; virtual iosource::FD_Set ExtraReadFDs() const;
virtual void Stats(char* buffer, int length); virtual void Stats(char* buffer, int length);
private: private:
@ -271,7 +271,7 @@ public:
virtual void Clear(); virtual void Clear();
virtual bool Eof() { return eof; } virtual bool Eof() { return eof; }
virtual int Fd() { return socket; } virtual int Fd() { return socket; }
virtual std::vector<int> FdSupplements() const; virtual iosource::FD_Set ExtraReadFDs() const;
virtual void Stats(char* buffer, int length); virtual void Stats(char* buffer, int length);
private: private:
@ -340,8 +340,8 @@ public:
virtual bool Eof() { return io->Eof(); } virtual bool Eof() { return io->Eof(); }
virtual int Fd() { return io->Fd(); } virtual int Fd() { return io->Fd(); }
virtual std::vector<int> FdSupplements() const virtual iosource::FD_Set ExtraReadFDs() const
{ return io->FdSupplements(); } { return io->ExtraReadFDs(); }
virtual void Stats(char* buffer, int length); virtual void Stats(char* buffer, int length);
void EnableCompression(int level) void EnableCompression(int level)

View file

@ -1216,10 +1216,10 @@ void DNS_Mgr::IssueAsyncRequests()
} }
} }
void DNS_Mgr::GetFds(std::vector<int>* read, std::vector<int>* write, void DNS_Mgr::GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except) iosource::FD_Set* except)
{ {
read->push_back(nb_dns_fd(nb_dns)); read->Insert(nb_dns_fd(nb_dns));
} }
double DNS_Mgr::NextTimestamp(double* network_time) double DNS_Mgr::NextTimestamp(double* network_time)

View file

@ -132,8 +132,8 @@ protected:
void DoProcess(bool flush); void DoProcess(bool flush);
// IOSource interface. // IOSource interface.
virtual void GetFds(std::vector<int>* read, std::vector<int>* write, virtual void GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except); iosource::FD_Set* except);
virtual double NextTimestamp(double* network_time); virtual double NextTimestamp(double* network_time);
virtual void Process(); virtual void Process();
virtual const char* Tag() { return "DNS_Mgr"; } virtual const char* Tag() { return "DNS_Mgr"; }

View file

@ -1367,17 +1367,14 @@ void RemoteSerializer::Unregister(ID* id)
} }
} }
void RemoteSerializer::GetFds(std::vector<int>* read, std::vector<int>* write, void RemoteSerializer::GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except) iosource::FD_Set* except)
{ {
read->push_back(io->Fd()); read->Insert(io->Fd());
std::vector<int> supp = io->FdSupplements(); read->Aggregate(io->ExtraReadFDs());
for ( size_t i = 0; i < supp.size(); ++i )
read->push_back(supp[i]);
if ( io->CanWrite() ) if ( io->CanWrite() )
write->push_back(io->Fd()); write->Insert(io->Fd());
} }
double RemoteSerializer::NextTimestamp(double* local_network_time) double RemoteSerializer::NextTimestamp(double* local_network_time)
@ -3390,11 +3387,9 @@ void SocketComm::Run()
FD_ZERO(&fd_write); FD_ZERO(&fd_write);
FD_ZERO(&fd_except); FD_ZERO(&fd_except);
int max_fd = 0; int max_fd = io->Fd();
FD_SET(io->Fd(), &fd_read); FD_SET(io->Fd(), &fd_read);
max_fd = io->Fd(); max_fd = std::max(max_fd, io->ExtraReadFDs().Set(&fd_read));
fd_vector_set(io->FdSupplements(), &fd_read, &max_fd);
loop_over_list(peers, i) loop_over_list(peers, i)
{ {
@ -3403,7 +3398,8 @@ void SocketComm::Run()
FD_SET(peers[i]->io->Fd(), &fd_read); FD_SET(peers[i]->io->Fd(), &fd_read);
if ( peers[i]->io->Fd() > max_fd ) if ( peers[i]->io->Fd() > max_fd )
max_fd = peers[i]->io->Fd(); max_fd = peers[i]->io->Fd();
fd_vector_set(peers[i]->io->FdSupplements(), &fd_read, &max_fd); max_fd = std::max(max_fd,
peers[i]->io->ExtraReadFDs().Set(&fd_read));
} }
else else
{ {

View file

@ -140,8 +140,8 @@ public:
void Finish(); void Finish();
// Overidden from IOSource: // Overidden from IOSource:
virtual void GetFds(std::vector<int>* read, std::vector<int>* write, virtual void GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except); iosource::FD_Set* except);
virtual double NextTimestamp(double* local_network_time); virtual double NextTimestamp(double* local_network_time);
virtual void Process(); virtual void Process();
virtual TimerMgr::Tag* GetCurrentTag(); virtual TimerMgr::Tag* GetCurrentTag();

View file

@ -1068,10 +1068,10 @@ void EventPlayer::GotFunctionCall(const char* name, double time,
// We don't replay function calls. // We don't replay function calls.
} }
void EventPlayer::GetFds(std::vector<int>* read, std::vector<int>* write, void EventPlayer::GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except) iosource::FD_Set* except)
{ {
read->push_back(fd); read->Insert(fd);
} }
double EventPlayer::NextTimestamp(double* local_network_time) double EventPlayer::NextTimestamp(double* local_network_time)

View file

@ -355,8 +355,8 @@ public:
EventPlayer(const char* file); EventPlayer(const char* file);
virtual ~EventPlayer(); virtual ~EventPlayer();
virtual void GetFds(std::vector<int>* read, std::vector<int>* write, virtual void GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except); iosource::FD_Set* except);
virtual double NextTimestamp(double* local_network_time); virtual double NextTimestamp(double* local_network_time);
virtual void Process(); virtual void Process();
virtual const char* Tag() { return "EventPlayer"; } virtual const char* Tag() { return "EventPlayer"; }

83
src/iosource/FD_Set.h Normal file
View file

@ -0,0 +1,83 @@
#ifndef BRO_FD_SET_H
#define BRO_FD_SET_H
#include <set>
#include <sys/select.h>
namespace iosource {
/**
* A container holding a set of file descriptors.
*/
class FD_Set {
public:
/**
* Constructor. The set is initially empty.
*/
FD_Set()
: max(-1), fds()
{ }
/**
* Insert a file descriptor in to the set.
* @param fd the fd to insert in the set.
* @return false if fd was already in the set, else true.
*/
bool Insert(int fd)
{
if ( max < fd ) max = fd;
return fds.insert(fd).second;
}
/**
* Inserts all the file descriptors from another set in to this one.
* @param other a file descriptor set to merge in to this one.
*/
void Aggregate(const FD_Set& other)
{
for ( std::set<int>::const_iterator it = other.fds.begin();
it != other.fds.end(); ++it )
Insert(*it);
}
/**
* Empties the set.
*/
void Clear()
{ max = -1; fds.clear(); }
/**
* Insert file descriptors in to a fd_set for use with select().
* @return the greatest file descriptor inserted.
*/
int Set(fd_set* set) const
{
for ( std::set<int>::const_iterator it = fds.begin(); it != fds.end();
++it )
FD_SET(*it, set);
return max;
}
/**
* @return Whether a file descriptor belonging to this set is within the
* fd_set arugment.
*/
bool Ready(fd_set* set) const
{
for ( std::set<int>::const_iterator it = fds.begin(); it != fds.end();
++it )
if ( FD_ISSET(*it, set) )
return true;
return false;
}
private:
int max;
std::set<int> fds;
};
} // namespace bro
#endif // BRO_FD_SET_H

View file

@ -8,8 +8,7 @@ extern "C" {
} }
#include <string> #include <string>
#include <vector> #include "FD_Set.h"
#include "Timer.h" #include "Timer.h"
namespace iosource { namespace iosource {
@ -62,8 +61,7 @@ public:
* *
* @param except Pointer to container where to insert a except descriptor. * @param except Pointer to container where to insert a except descriptor.
*/ */
virtual void GetFds(std::vector<int>* read, std::vector<int>* write, virtual void GetFds(FD_Set* read, FD_Set* write, FD_Set* except) = 0;
std::vector<int>* except) = 0;
/** /**
* Returns the timestamp (in \a global network time) associated with * Returns the timestamp (in \a global network time) associated with

View file

@ -44,15 +44,6 @@ void Manager::RemoveAll()
dont_counts = sources.size(); dont_counts = sources.size();
} }
static void fd_vector_set(const std::vector<int>& fds, fd_set* set, int* max)
{
for ( size_t i = 0; i < fds.size(); ++i )
{
FD_SET(fds[i], set);
*max = ::max(fds[i], *max);
}
}
IOSource* Manager::FindSoonest(double* ts) IOSource* Manager::FindSoonest(double* ts)
{ {
// Remove sources which have gone dry. For simplicity, we only // Remove sources which have gone dry. For simplicity, we only
@ -124,14 +115,9 @@ IOSource* Manager::FindSoonest(double* ts)
// be ready. // be ready.
continue; continue;
src->fd_read.clear(); src->Clear();
src->fd_write.clear();
src->fd_except.clear();
src->src->GetFds(&src->fd_read, &src->fd_write, &src->fd_except); src->src->GetFds(&src->fd_read, &src->fd_write, &src->fd_except);
src->SetFds(&fd_read, &fd_write, &fd_except, &maxx);
fd_vector_set(src->fd_read, &fd_read, &maxx);
fd_vector_set(src->fd_write, &fd_write, &maxx);
fd_vector_set(src->fd_except, &fd_except, &maxx);
} }
// We can't block indefinitely even when all sources are dry: // We can't block indefinitely even when all sources are dry:
@ -316,21 +302,10 @@ PktDumper* Manager::OpenPktDumper(const string& path, bool append)
return pd; return pd;
} }
static bool fd_vector_ready(const std::vector<int>& fds, fd_set* set) void Manager::Source::SetFds(fd_set* read, fd_set* write, fd_set* except,
int* maxx) const
{ {
for ( size_t i = 0; i < fds.size(); ++i ) *maxx = std::max(*maxx, fd_read.Set(read));
if ( FD_ISSET(fds[i], set) ) *maxx = std::max(*maxx, fd_write.Set(write));
return true; *maxx = std::max(*maxx, fd_except.Set(except));
return false;
}
bool Manager::Source::Ready(fd_set* read, fd_set* write, fd_set* except) const
{
if ( fd_vector_ready(fd_read, read) ||
fd_vector_ready(fd_write, write) ||
fd_vector_ready(fd_except, except) )
return true;
return false;
} }

View file

@ -5,8 +5,7 @@
#include <string> #include <string>
#include <list> #include <list>
#include <vector> #include "iosource/FD_Set.h"
#include <sys/select.h>
namespace iosource { namespace iosource {
@ -115,11 +114,19 @@ private:
struct Source { struct Source {
IOSource* src; IOSource* src;
std::vector<int> fd_read; FD_Set fd_read;
std::vector<int> fd_write; FD_Set fd_write;
std::vector<int> fd_except; FD_Set fd_except;
bool Ready(fd_set* read, fd_set* write, fd_set* except) const; bool Ready(fd_set* read, fd_set* write, fd_set* except) const
{ return fd_read.Ready(read) || fd_write.Ready(write) ||
fd_except.Ready(except); }
void SetFds(fd_set* read, fd_set* write, fd_set* except,
int* maxx) const;
void Clear()
{ fd_read.Clear(); fd_write.Clear(); fd_except.Clear(); }
}; };
typedef std::list<Source*> SourceList; typedef std::list<Source*> SourceList;

View file

@ -218,8 +218,8 @@ void PktSrc::Done()
Close(); Close();
} }
void PktSrc::GetFds(std::vector<int>* read, std::vector<int>* write, void PktSrc::GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except) iosource::FD_Set* except)
{ {
if ( pseudo_realtime ) if ( pseudo_realtime )
{ {
@ -230,7 +230,7 @@ void PktSrc::GetFds(std::vector<int>* read, std::vector<int>* write,
} }
if ( IsOpen() && props.selectable_fd >= 0 ) if ( IsOpen() && props.selectable_fd >= 0 )
read->push_back(props.selectable_fd); read->Insert(props.selectable_fd);
} }
double PktSrc::NextTimestamp(double* local_network_time) double PktSrc::NextTimestamp(double* local_network_time)

View file

@ -388,8 +388,8 @@ private:
// IOSource interface implementation. // IOSource interface implementation.
virtual void Init(); virtual void Init();
virtual void Done(); virtual void Done();
virtual void GetFds(std::vector<int>* read, std::vector<int>* write, virtual void GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except); iosource::FD_Set* except);
virtual double NextTimestamp(double* local_network_time); virtual double NextTimestamp(double* local_network_time);
virtual void Process(); virtual void Process();
virtual const char* Tag(); virtual const char* Tag();

View file

@ -65,8 +65,8 @@ void Manager::AddMsgThread(MsgThread* thread)
msg_threads.push_back(thread); msg_threads.push_back(thread);
} }
void Manager::GetFds(std::vector<int>* read, std::vector<int>* write, void Manager::GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except) iosource::FD_Set* except)
{ {
} }

View file

@ -103,8 +103,8 @@ protected:
/** /**
* Part of the IOSource interface. * Part of the IOSource interface.
*/ */
virtual void GetFds(std::vector<int>* read, std::vector<int>* write, virtual void GetFds(iosource::FD_Set* read, iosource::FD_Set* write,
std::vector<int>* except); iosource::FD_Set* except);
/** /**
* Part of the IOSource interface. * Part of the IOSource interface.