#include "DebugLogger.h"

#include "MsgThread.h"
#include "Manager.h"

// SIGTERM/SIGINT are used by MsgThread::OnStop(); usleep() comes from unistd.h.
#include <signal.h>
#include <unistd.h>

using namespace threading;

namespace threading {

////// Messages.

// NOTE(review): InputMessage/OutputMessage are used without template
// arguments below, exactly as found. If MsgThread.h declares them as class
// templates, the argument was lost somewhere upstream -- confirm against the
// header.

// Signals child thread to shutdown operation.
class FinishMessage : public InputMessage
{
public:
	// thread: the thread the message is addressed to.
	// network_time: time stamp handed on to the thread's OnFinish().
	FinishMessage(MsgThread* thread, double network_time)
		: InputMessage("Finish", thread),
		  network_time(network_time)
		{ }

	// Runs the thread's OnFinish() hook, then marks the thread as finished
	// regardless of the hook's outcome. Returns the hook's result.
	virtual bool Process()
		{
		bool result = Object()->OnFinish(network_time);
		Object()->Finished();
		return result;
		}

private:
	double network_time;
};

// A dummy message whose only purpose is to unblock the current read
// operation so that the child's Run() method can check the termination
// status.
class UnblockMessage : public InputMessage
{
public:
	UnblockMessage(MsgThread* thread)
		: InputMessage("Unblock", thread)
		{ }

	// Nothing to do; always succeeds.
	virtual bool Process()
		{
		return true;
		}
};

/// Sends a heartbeat to the child thread.
class HeartbeatMessage : public InputMessage
{
public:
	// arg_network_time / arg_current_time: the two time stamps passed on
	// unchanged to the thread's OnHeartbeat().
	HeartbeatMessage(MsgThread* thread, double arg_network_time, double arg_current_time)
		: InputMessage("Heartbeat", thread),
		  network_time(arg_network_time),
		  current_time(arg_current_time)
		{ }

	// Refreshes the thread's OS-level name first, then invokes the
	// OnHeartbeat() hook and returns its result.
	virtual bool Process()
		{
		Object()->HeartbeatInChild();
		return Object()->OnHeartbeat(network_time, current_time);
		}

private:
	double network_time;
	double current_time;
};

// A message from the child to be passed on to the Reporter.
class ReporterMessage : public OutputMessage
{
public:
	// Selects which Reporter method Process() forwards the text to.
	enum Type {
		INFO, WARNING, ERROR, FATAL_ERROR, FATAL_ERROR_WITH_CORE,
		INTERNAL_WARNING, INTERNAL_ERROR
	};

	// arg_type: severity/kind of the report.
	// thread: the thread the message originates from.
	// arg_msg: the text to report; a copy is stored.
	ReporterMessage(Type arg_type, MsgThread* thread, const string& arg_msg)
		: OutputMessage("ReporterMessage", thread),
		  msg(arg_msg),
		  type(arg_type)
		{ }

	virtual bool Process();

private:
	string msg;
	Type type;
};

#ifdef DEBUG

// A debug message from the child to be passed on to the DebugLogger.
class DebugMessage : public OutputMessage
{
public:
	// arg_stream: the debug stream to log to.
	// thread: the thread the message originates from.
	// arg_msg: the text to log; a copy is stored.
	DebugMessage(DebugStream arg_stream, MsgThread* thread, const string& arg_msg)
		: OutputMessage("DebugMessage", thread),
		  msg(arg_msg),
		  stream(arg_stream)
		{ }

	// Logs "<thread name>: <msg>" to the selected stream; always succeeds.
	virtual bool Process()
		{
		string s = Object()->Name() + ": " + msg;
		debug_logger.Log(stream, "%s", s.c_str());
		return true;
		}

private:
	string msg;
	DebugStream stream;
};

#endif

}

////// Methods.

Message::~Message()
	{
	}

// Forwards the stored text, prefixed with the originating thread's name, to
// the Reporter method matching the message type. Always returns true.
bool ReporterMessage::Process()
	{
	string s = Object()->Name() + ": " + msg;
	const char* cmsg = s.c_str();

	switch ( type ) {

	case INFO:
		reporter->Info("%s", cmsg);
		break;

	case WARNING:
		reporter->Warning("%s", cmsg);
		break;

	case ERROR:
		reporter->Error("%s", cmsg);
		break;

	case FATAL_ERROR:
		reporter->FatalError("%s", cmsg);
		break;

	case FATAL_ERROR_WITH_CORE:
		reporter->FatalErrorWithCore("%s", cmsg);
		break;

	case INTERNAL_WARNING:
		reporter->InternalWarning("%s", cmsg);
		break;

	case INTERNAL_ERROR :
		reporter->InternalError("%s", cmsg);
		break;

	default:
		reporter->InternalError("unknown ReporterMessage type %d", type);
	}

	return true;
	}

// Resets the message counters and state flags and registers the new thread
// with the thread manager.
MsgThread::MsgThread() : BasicThread()
	{
	cnt_sent_in = cnt_sent_out = 0;
	finished = false;
	stopped = false;
	thread_mgr->AddMsgThread(this);
	}

// Set by Bro's main signal handler.
extern int signal_val;

// Asks the thread to finish and waits until it has acknowledged. The wait is
// cut short if a further SIGTERM/SIGINT arrives (all threads are killed) or
// if the thread does not react in time (only this thread is killed).
void MsgThread::OnStop()
	{
	if ( stopped )
		return;

	// Signal thread to terminate and wait until it has acknowledged.
	SendIn(new FinishMessage(this, network_time), true);

	// Clear the pending signal so that only one arriving while we wait is
	// noticed below; the previous value is restored afterwards.
	int old_signal_val = signal_val;
	signal_val = 0;

	int cnt = 0;
	bool aborted = false;

	while ( ! finished )
		{
		// Terminate if we get another kill signal.
		if ( signal_val == SIGTERM || signal_val == SIGINT )
			{
			// Abort all threads here so that we won't hang next
			// on another one.
			fprintf(stderr, "received signal while waiting for thread %s, aborting all ...\n", Name().c_str());
			thread_mgr->KillThreads();
			aborted = true;
			break;
			}

		// Insurance against broken threads: with a 1ms sleep per
		// iteration this triggers after roughly ten seconds.
		if ( ++cnt % 10000 == 0 )
			{
			fprintf(stderr, "killing thread %s ...\n", Name().c_str());
			Kill();
			aborted = true;
			break;
			}

		usleep(1000);
		}

	// Mark as finished even when we gave up waiting above.
	Finished();

	signal_val = old_signal_val;

	// One more message to make sure the current queue read operation unblocks.
	if ( ! aborted )
		SendIn(new UnblockMessage(this), true);
	}

// Queues a heartbeat message carrying the current network and wall-clock
// time for the thread.
void MsgThread::Heartbeat()
	{
	SendIn(new HeartbeatMessage(this, network_time, current_time()));
	}

// Invoked from HeartbeatMessage::Process(). Updates the thread's OS-level
// name to include, for each direction, the number of messages sent minus the
// number still sitting in the queue.
void MsgThread::HeartbeatInChild()
	{
	string n = Name();

	n = Fmt("bro: %s (%" PRIu64 "/%" PRIu64 ")", n.c_str(),
		cnt_sent_in - queue_in.Size(),
		cnt_sent_out - queue_out.Size());

	SetOSName(n.c_str());
	}

// Flags the thread as finished, which ends the loops in Run() and OnStop().
void MsgThread::Finished()
	{
	// This is thread-safe "enough", we're the only one ever writing
	// there.
	finished = true;
	}

// The reporting methods below do not report directly; they queue a
// ReporterMessage of the corresponding type via SendOut().

void MsgThread::Info(const char* msg)
	{
	SendOut(new ReporterMessage(ReporterMessage::INFO, this, msg));
	}

void MsgThread::Warning(const char* msg)
	{
	SendOut(new ReporterMessage(ReporterMessage::WARNING, this, msg));
	}

void MsgThread::Error(const char* msg)
	{
	SendOut(new ReporterMessage(ReporterMessage::ERROR, this, msg));
	}

void MsgThread::FatalError(const char* msg)
	{
	SendOut(new ReporterMessage(ReporterMessage::FATAL_ERROR, this, msg));
	}

void MsgThread::FatalErrorWithCore(const char* msg)
	{
	SendOut(new ReporterMessage(ReporterMessage::FATAL_ERROR_WITH_CORE, this, msg));
	}

void MsgThread::InternalWarning(const char* msg)
	{
	SendOut(new ReporterMessage(ReporterMessage::INTERNAL_WARNING, this, msg));
	}

// Unlike the other reporting methods, this does not queue a message: it
// prints to stderr and aborts the process right away.
void MsgThread::InternalError(const char* msg)
	{
	// This one aborts immediately.
	fprintf(stderr, "internal error in thread: %s\n", msg);
	abort();
	}

#ifdef DEBUG

// Queues a debug message for the given stream via SendOut().
void MsgThread::Debug(DebugStream stream, const char* msg)
	{
	SendOut(new DebugMessage(stream, this, msg));
	}

#endif

// Queues a message for the thread and takes ownership of it. While the
// thread is terminating the message is deleted instead of queued, unless
// 'force' is set.
void MsgThread::SendIn(BasicInputMessage* msg, bool force)
	{
	if ( Terminating() && ! force )
		{
		delete msg;
		return;
		}

	DBG_LOG(DBG_THREADING, "Sending '%s' to %s ...", msg->Name().c_str(), Name().c_str());

	queue_in.Put(msg);
	++cnt_sent_in;
	}

// Queues a message coming from the thread and takes ownership of it. While
// the thread is terminating the message is deleted instead of queued, unless
// 'force' is set.
void MsgThread::SendOut(BasicOutputMessage* msg, bool force)
	{
	if ( Terminating() && ! force )
		{
		delete msg;
		return;
		}

	queue_out.Put(msg);
	++cnt_sent_out;
	}

// Takes the next message off the output queue. The caller receives
// ownership of the returned message.
BasicOutputMessage* MsgThread::RetrieveOut()
	{
	BasicOutputMessage* msg = queue_out.Get();
	assert(msg);

	DBG_LOG(DBG_THREADING, "Retrieved '%s' from %s", msg->Name().c_str(), Name().c_str());

	return msg;
	}

// Takes the next message off the input queue. The caller receives ownership
// of the returned message.
BasicInputMessage* MsgThread::RetrieveIn()
	{
	BasicInputMessage* msg = queue_in.Get();
	assert(msg);

#ifdef DEBUG
	// Logged through Debug(), i.e. as a queued DebugMessage, rather than
	// through DBG_LOG directly as RetrieveOut() does.
	string s = Fmt("Retrieved '%s' in %s", msg->Name().c_str(), Name().c_str());
	Debug(DBG_THREADING, s.c_str());
#endif

	return msg;
	}

// Message loop: retrieves and processes input messages until the thread is
// flagged as finished or a message fails to process. Every retrieved message
// is deleted once it has been processed.
void MsgThread::Run()
	{
	while ( ! finished )
		{
		BasicInputMessage* msg = RetrieveIn();

		bool result = msg->Process();

		if ( ! result )
			{
			string s = msg->Name() + " failed, terminating thread (MsgThread)";

			// We own the message; release it on the failure path
			// as well before leaving the loop.
			delete msg;

			Error(s.c_str());
			break;
			}

		delete msg;
		}

	Finished();
	}

// Fills 'stats' with the message counters, the current queue sizes and the
// statistics of both queues.
void MsgThread::GetStats(Stats* stats)
	{
	stats->sent_in = cnt_sent_in;
	stats->sent_out = cnt_sent_out;
	stats->pending_in = queue_in.Size();
	stats->pending_out = queue_out.Size();
	queue_in.GetStats(&stats->queue_in_stats);
	queue_out.GetStats(&stats->queue_out_stats);
	}