diff --git a/aux/broker b/aux/broker index 331966d1f3..1e8d675790 160000 --- a/aux/broker +++ b/aux/broker @@ -1 +1 @@ -Subproject commit 331966d1f3d24c63bedbda79e477f759c4d267f9 +Subproject commit 1e8d6757909750524c15f8eaf3c297243bc55425 diff --git a/scripts/base/frameworks/comm/main.bro b/scripts/base/frameworks/comm/main.bro index af4225f5dd..c69d36db52 100644 --- a/scripts/base/frameworks/comm/main.bro +++ b/scripts/base/frameworks/comm/main.bro @@ -4,4 +4,10 @@ module Comm; export { const endpoint_name = "" &redef; + + type SendFlags: record { + self: bool &default = F; + peers: bool &default = T; + unsolicited: bool &default = F; + }; } diff --git a/src/comm/Manager.cc b/src/comm/Manager.cc index 29ff71d7e0..7027daa79e 100644 --- a/src/comm/Manager.cc +++ b/src/comm/Manager.cc @@ -7,13 +7,31 @@ #include "Reporter.h" #include "comm/comm.bif.h" +using namespace std; + bool comm::Manager::InitPreScript() { return true; } +static int require_field(const RecordType* rt, const char* name) + { + auto rval = rt->FieldOffset(name); + + if ( rval < 0 ) + reporter->InternalError("no field named '%s' in record type '%s'", name, + rt->GetName().data()); + + return rval; + } + bool comm::Manager::InitPostScript() { + auto send_flags_type = internal_type("Comm::SendFlags")->AsRecordType(); + send_flags_self_idx = require_field(send_flags_type, "self"); + send_flags_peers_idx = require_field(send_flags_type, "peers"); + send_flags_unsolicited_idx = require_field(send_flags_type, "unsolicited"); + auto res = broker::init(); if ( res ) @@ -37,7 +55,7 @@ bool comm::Manager::InitPostScript() name = fmt("bro@.%ld", static_cast(getpid())); } - endpoint = std::unique_ptr(new broker::endpoint(name)); + endpoint = unique_ptr(new broker::endpoint(name)); return true; } @@ -56,31 +74,81 @@ bool comm::Manager::Listen(uint16_t port, const char* addr) } bool comm::Manager::Connect(string addr, uint16_t port, - std::chrono::duration retry_interval) + chrono::duration retry_interval) { - auto& peer = peers[std::make_pair(addr, port)]; + auto& peer = peers[make_pair(addr, port)]; if ( peer ) return false; - peer = endpoint->peer(std::move(addr), port, retry_interval); + peer = endpoint->peer(move(addr), port, retry_interval); return true; } bool comm::Manager::Disconnect(const string& addr, uint16_t port) { - auto it = peers.find(std::make_pair(addr, port)); + auto it = peers.find(make_pair(addr, port)); if ( it == peers.end() ) return false; - return endpoint->unpeer(it->second); + auto rval = endpoint->unpeer(it->second); + peers.erase(it); + return rval; + } + +bool comm::Manager::Print(string topic, string msg, const Val* flags) + { + endpoint->send(move(topic), broker::message{move(msg)}, get_flags(flags)); + return true; + } + +bool comm::Manager::SubscribeToPrints(string topic_prefix) + { + auto& q = print_subscriptions[topic_prefix]; + + if ( q ) + return false; + + q = broker::message_queue(move(topic_prefix), *endpoint); + return true; + } + +bool comm::Manager::UnsubscribeToPrints(const string& topic_prefix) + { + return print_subscriptions.erase(topic_prefix); + } + +int comm::Manager::get_flags(const Val* flags) + { + auto r = flags->AsRecordVal(); + int rval = 0; + Val* self_flag = r->LookupWithDefault(send_flags_self_idx); + Val* peers_flag = r->LookupWithDefault(send_flags_peers_idx); + Val* unsolicited_flag = r->LookupWithDefault(send_flags_unsolicited_idx); + + if ( self_flag->AsBool() ) + rval |= broker::SELF; + + if ( peers_flag->AsBool() ) + rval |= broker::PEERS; + + if ( unsolicited_flag->AsBool() ) + rval |= broker::UNSOLICITED; + + Unref(self_flag); + Unref(peers_flag); + Unref(unsolicited_flag); + return rval; } void comm::Manager::GetFds(iosource::FD_Set* read, iosource::FD_Set* write, iosource::FD_Set* except) { read->Insert(endpoint->peer_status().fd()); + + for ( const auto& ps : print_subscriptions ) + read->Insert(ps.second.fd()); } double comm::Manager::NextTimestamp(double* local_network_time) @@ -147,5 +215,41 @@ void comm::Manager::Process() } } + for ( const auto& ps : print_subscriptions ) + { + auto print_messages = ps.second.want_pop(); + + if ( print_messages.empty() ) + continue; + + idle = false; + + if ( ! Comm::print_handler ) + continue; + + for ( auto& pm : print_messages ) + { + if ( pm.size() != 1 ) + { + reporter->Warning("got print message of invalid size: %zd", + pm.size()); + continue; + } + + std::string* msg = broker::get(pm[0]); + + if ( ! msg ) + { + reporter->Warning("got print message of invalid type: %d", + static_cast(broker::which(pm[0]))); + continue; + } + + val_list* vl = new val_list; + vl->append(new StringVal(move(*msg))); + mgr.QueueEvent(Comm::print_handler, vl); + } + } + SetIdle(idle); } diff --git a/src/comm/Manager.h b/src/comm/Manager.h index 412c125d14..0f7d5a4a1c 100644 --- a/src/comm/Manager.h +++ b/src/comm/Manager.h @@ -2,6 +2,7 @@ #define BRO_COMM_MANAGER_H #include +#include #include #include #include @@ -28,8 +29,16 @@ public: bool Disconnect(const std::string& addr, uint16_t port); + bool Print(std::string topic, std::string msg, const Val* flags); + + bool SubscribeToPrints(std::string topic_prefix); + + bool UnsubscribeToPrints(const std::string& topic_prefix); + private: + int get_flags(const Val* flags); + // IOSource interface overrides: void GetFds(iosource::FD_Set* read, iosource::FD_Set* write, iosource::FD_Set* except) override; @@ -43,6 +52,11 @@ private: std::unique_ptr endpoint; std::map, broker::peering> peers; + std::map print_subscriptions; + + int send_flags_self_idx; + int send_flags_peers_idx; + int send_flags_unsolicited_idx; }; } // namespace comm diff --git a/src/comm/comm.bif b/src/comm/comm.bif index 67933df20e..6294864bba 100644 --- a/src/comm/comm.bif +++ b/src/comm/comm.bif @@ -1,10 +1,12 @@ -module Comm; - %%{ #include "comm/Manager.h" %%} +module Comm; + +type Comm::SendFlags: record; + event Comm::remote_connection_established%(peer_address: string, peer_port: port, peer_name: string%); @@ -13,7 +15,7 @@ event Comm::remote_connection_broken%(peer_address: string, event Comm::remote_connection_incompatible%(peer_address: string, peer_port: port%); -function Comm::listen%(p: port, a: string &default=""%): bool +function Comm::listen%(p: port, a: string &default = ""%): bool %{ if ( ! p->IsTCP() ) { @@ -49,3 +51,25 @@ function Comm::disconnect%(a: string, p: port%): bool auto rval = comm_mgr->Disconnect(a->CheckString(), p->Port()); return new Val(rval, TYPE_BOOL); %} + +event Comm::print_handler%(msg: string%); + +function Comm::print%(topic: string, msg: string, + flags: SendFlags &default = SendFlags()%): bool + %{ + auto rval = comm_mgr->Print(topic->CheckString(), msg->CheckString(), + flags); + return new Val(rval, TYPE_BOOL); + %} + +function Comm::subscribe_to_prints%(topic_prefix: string &default = ""%): bool + %{ + auto rval = comm_mgr->SubscribeToPrints(topic_prefix->CheckString()); + return new Val(rval, TYPE_BOOL); + %} + +function Comm::unsubscribe_to_prints%(topic_prefix: string &default = ""%): bool + %{ + auto rval = comm_mgr->UnsubscribeToPrints(topic_prefix->CheckString()); + return new Val(rval, TYPE_BOOL); + %}