diff --git a/src/Conn.cc b/src/Conn.cc index 8d47402825..fad1abf31b 100644 --- a/src/Conn.cc +++ b/src/Conn.cc @@ -28,11 +28,9 @@ uint64_t Connection::current_connections = 0; Connection::Connection(const detail::ConnIDKey& k, double t, const ConnID* id, uint32_t flow, const Packet* pkt) : Session(t, connection_timeout, connection_status_update, - detail::connection_status_update_interval) + detail::connection_status_update_interval), + key(k) { - key = k; - key_valid = true; - orig_addr = id->src_addr; resp_addr = id->dst_addr; orig_port = id->src_port; diff --git a/src/Conn.h b/src/Conn.h index b88ad3c7a3..ec0cdc1b88 100644 --- a/src/Conn.h +++ b/src/Conn.h @@ -116,8 +116,6 @@ public: const detail::ConnIDKey& Key() const { return key; } detail::SessionKey SessionKey(bool copy) const override { return detail::SessionKey{&key, sizeof(key), copy}; } - void ClearKey() override { key_valid = false; } - bool IsKeyValid() const override { return key_valid; } const IPAddr& OrigAddr() const { return orig_addr; } const IPAddr& RespAddr() const { return resp_addr; } @@ -271,7 +269,6 @@ protected: std::shared_ptr encapsulation; // tunnels detail::ConnIDKey key; - bool key_valid; unsigned int skip:1; unsigned int weird:1; diff --git a/src/IPAddr.cc b/src/IPAddr.cc index 365cef76ba..70809353f3 100644 --- a/src/IPAddr.cc +++ b/src/IPAddr.cc @@ -15,38 +15,43 @@ namespace zeek { const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{}); - const IPAddr IPAddr::v6_unspecified = IPAddr(); -detail::ConnIDKey detail::BuildConnIDKey(const ConnID& id) - { - ConnIDKey key; +namespace detail { - // Lookup up connection based on canonical ordering, which is +ConnIDKey::ConnIDKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, + uint16_t dst_port, TransportProto t, bool one_way) + : transport(t) + { + // Lookup up connection based on canonical ordering, which is // the smaller of and // followed by the other. - if ( id.is_one_way || - addr_port_canon_lt(id.src_addr, id.src_port, id.dst_addr, id.dst_port) + if ( one_way || + addr_port_canon_lt(src, src_port, dst, dst_port) ) { - key.ip1 = id.src_addr.in6; - key.ip2 = id.dst_addr.in6; - key.port1 = id.src_port; - key.port2 = id.dst_port; + ip1 = src.in6; + ip2 = dst.in6; + port1 = src_port; + port2 = dst_port; } else { - key.ip1 = id.dst_addr.in6; - key.ip2 = id.src_addr.in6; - key.port1 = id.dst_port; - key.port2 = id.src_port; + ip1 = dst.in6; + ip2 = src.in6; + port1 = dst_port; + port2 = src_port; } - - key.transport = id.proto; - - return key; } +detail::ConnIDKey::ConnIDKey(const ConnID& id) + : ConnIDKey(id.src_addr, id.dst_addr, id.src_port, id.dst_port, + id.proto, id.is_one_way) + { + } + +} // namespace detail + IPAddr::IPAddr(const String& s) { Init(s.CheckString()); diff --git a/src/IPAddr.h b/src/IPAddr.h index 569b5edcbe..ca9332a93e 100644 --- a/src/IPAddr.h +++ b/src/IPAddr.h @@ -28,12 +28,9 @@ struct ConnIDKey { uint16_t port2; TransportProto transport; - ConnIDKey() : port1(0), port2(0), transport(TRANSPORT_UNKNOWN) - { - memset(&ip1, 0, sizeof(in6_addr)); - memset(&ip2, 0, sizeof(in6_addr)); - } - + ConnIDKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, + uint16_t dst_port, TransportProto t, bool one_way); + ConnIDKey(const ConnID& conn); ConnIDKey(const ConnIDKey& rhs) { *this = rhs; @@ -55,11 +52,6 @@ struct ConnIDKey { } }; -/** - * Returns a map key for a given ConnID. - */ -ConnIDKey BuildConnIDKey(const ConnID& id); - } // namespace detail /** @@ -398,8 +390,6 @@ public: */ void ConvertToThreadingValue(threading::Value::addr_t* v) const; - friend detail::ConnIDKey detail::BuildConnIDKey(const ConnID& id); - unsigned int MemoryAllocation() const { return padded_sizeof(*this); } /** @@ -451,6 +441,7 @@ public: static const IPAddr v6_unspecified; private: + friend struct detail::ConnIDKey; friend class IPPrefix; /** diff --git a/src/Session.cc b/src/Session.cc index e0b84dbe71..9097344307 100644 --- a/src/Session.cc +++ b/src/Session.cc @@ -118,6 +118,7 @@ Session::Session(double t, session_status_update_event(status_update_event), session_status_update_interval(status_update_interval) { + in_session_table = true; record_contents = record_packets = 1; record_current_packet = record_current_content = 0; is_active = 1; @@ -229,7 +230,7 @@ void Session::AddTimer(timer_func timer, double t, bool do_expire, // If the key is cleared, the session isn't stored in the session table // anymore and will soon be deleted. We're not installed new timers // anymore then. - if ( ! IsKeyValid() ) + if ( ! IsInSessionTable() ) return; detail::Timer* conn_timer = new detail::SessionTimer(this, timer, t, do_expire, type); diff --git a/src/Session.h b/src/Session.h index 8e8d13cb82..59d9efafce 100644 --- a/src/Session.h +++ b/src/Session.h @@ -56,20 +56,20 @@ public: * Returns a key for the session. This is used as the key for storing * the session in SessionManager. * - * @param copy Flag to indicate that the key returned has a copy of the + * @param copy Flag to indicate that the key returned must have a copy of the * key data instead of just a pointer to it. */ virtual detail::SessionKey SessionKey(bool copy) const = 0; /** - * Set the key as invalid. + * Set whether this session is in the session table. */ - virtual void ClearKey() = 0; + void SetInSessionTable(bool in_table) { in_session_table = in_table; } /** - * Return whether the key is valid. + * Return whether this session is in the session table. */ - virtual bool IsKeyValid() const = 0; + bool IsInSessionTable() const { return in_session_table; } double StartTime() const { return start_time; } void SetStartTime(double t) { start_time = t; } @@ -245,6 +245,7 @@ protected: unsigned int is_active:1; unsigned int record_packets:1, record_contents:1; unsigned int record_current_packet:1, record_current_content:1; + bool in_session_table; }; namespace detail { @@ -274,12 +275,15 @@ public: ~SessionKey(); // Implement move semantics for SessionKey, since they're used as keys - // in a map and copying them would cause double-free issues. Adding this - // constructor and operator explicitly disables the equivalent copy - // operations. + // in a map. SessionKey(SessionKey&& rhs); SessionKey& operator=(SessionKey&& rhs); + // Explicitly delete the copy constructor and operator since copying + // may cause issues with double-freeing pointers. + SessionKey(const SessionKey& rhs) = delete; + SessionKey& operator=(const SessionKey& rhs) = delete; + void CopyData(); bool operator<(const SessionKey& rhs) const; diff --git a/src/SessionManager.cc b/src/SessionManager.cc index d16944705a..3dd0d9525a 100644 --- a/src/SessionManager.cc +++ b/src/SessionManager.cc @@ -35,10 +35,72 @@ zeek::SessionManager* zeek::session_mgr = nullptr; zeek::SessionManager*& zeek::sessions = zeek::session_mgr; namespace zeek { +namespace detail { + +class ProtocolStats { + +public: + + struct Protocol { + telemetry::IntGauge active; + telemetry::IntCounter total; + ssize_t max = 0; + + Protocol(telemetry::IntGaugeFamily active_family, + telemetry::IntCounterFamily total_family, + std::string protocol) : active(active_family.GetOrAdd({{"protocol", protocol}})), + total(total_family.GetOrAdd({{"protocol", protocol}})) + { + } + }; + + using ProtocolMap = std::map; + + ProtocolMap::iterator InitCounters(const std::string& protocol) + { + telemetry::IntGaugeFamily active_family = telemetry_mgr->GaugeFamily( + "zeek", "active-sessions", {"protocol"}, "Active Zeek Sessions"); + telemetry::IntCounterFamily total_family = telemetry_mgr->CounterFamily( + "zeek", "total-sessions", {"protocol"}, + "Total number of sessions", "1", true); + + auto [it, inserted] = entries.insert( + {protocol, Protocol{active_family, total_family, protocol}}); + + if ( inserted ) + return it; + + return entries.end(); + } + + Protocol* GetCounters(const std::string& protocol) + { + auto it = entries.find(protocol); + if ( it == entries.end() ) + it = InitCounters(protocol); + + if ( it != entries.end() ) + return &(it->second); + + return nullptr; + } + +private: + + ProtocolMap entries; +}; + +} // namespace detail + +SessionManager::SessionManager() + { + stats = new detail::ProtocolStats(); + } SessionManager::~SessionManager() { Clear(); + delete stats; } void SessionManager::Done() @@ -126,7 +188,7 @@ void SessionManager::ProcessTransportLayer(double t, const Packet* pkt, size_t r return; } - detail::ConnIDKey conn_key = detail::BuildConnIDKey(id); + detail::ConnIDKey conn_key(id); detail::SessionKey key(&conn_key, sizeof(conn_key), false); Connection* conn = nullptr; @@ -203,7 +265,7 @@ void SessionManager::ProcessTransportLayer(double t, const Packet* pkt, size_t r } int SessionManager::ParseIPPacket(int caplen, const u_char* const pkt, int proto, - IP_Hdr*& inner) + IP_Hdr*& inner) { if ( proto == IPPROTO_IPV6 ) { @@ -292,7 +354,6 @@ Connection* SessionManager::FindConnection(Val* v) resp_h = 2; resp_p = 3; } - else { // While it's not a conn_id, it may have equivalent fields. @@ -314,18 +375,11 @@ Connection* SessionManager::FindConnection(Val* v) auto orig_portv = vl->GetFieldAs(orig_p); auto resp_portv = vl->GetFieldAs(resp_p); - ConnID id; + detail::ConnIDKey conn_key(orig_addr, resp_addr, + htons((unsigned short) orig_portv->Port()), + htons((unsigned short) resp_portv->Port()), + orig_portv->PortType(), false); - id.src_addr = orig_addr; - id.dst_addr = resp_addr; - - id.src_port = htons((unsigned short) orig_portv->Port()); - id.dst_port = htons((unsigned short) resp_portv->Port()); - - id.is_one_way = false; // ### incorrect for ICMP connections - id.proto = orig_portv->PortType(); - - detail::ConnIDKey conn_key = detail::BuildConnIDKey(id); detail::SessionKey key(&conn_key, sizeof(conn_key), false); Connection* conn = nullptr; @@ -338,36 +392,34 @@ Connection* SessionManager::FindConnection(Val* v) void SessionManager::Remove(Session* s) { - Connection* c = static_cast(s); - - if ( s->IsKeyValid() ) + if ( s->IsInSessionTable() ) { s->CancelTimers(); s->Done(); s->RemovalEvent(); - // Clears out the session's copy of the key so that if the - // session has been Ref()'d somewhere, we know that on a future - // call to Remove() that it's no longer in the map. detail::SessionKey key = s->SessionKey(false); if ( session_map.erase(key) == 0 ) reporter->InternalWarning("connection missing"); else { - if ( auto* stat_block = stats.GetCounters(c->TransportIdentifier()) ) - stat_block->num.Dec(); + Connection* c = static_cast(s); + if ( auto* stat_block = stats->GetCounters(c->TransportIdentifier()) ) + stat_block->active.Dec(); } - s->ClearKey(); + // Mark that the session isn't in the table so that in case the + // session has been Ref()'d somewhere, we know that on a future + // call to Remove() that it's no longer in the map. + s->SetInSessionTable(false); + Unref(s); } } void SessionManager::Insert(Session* s) { - assert(s->IsKeyValid()); - Session* old = nullptr; detail::SessionKey key = s->SessionKey(true); @@ -383,7 +435,7 @@ void SessionManager::Insert(Session* s) // Some clean-ups similar to those in Remove() (but invisible // to the script layer). old->CancelTimers(); - old->ClearKey(); + old->SetInSessionTable(false); Unref(old); } } @@ -410,19 +462,19 @@ void SessionManager::Clear() void SessionManager::GetStats(SessionStats& s) { - auto* tcp_stats = stats.GetCounters("tcp"); + auto* tcp_stats = stats->GetCounters("tcp"); s.max_TCP_conns = tcp_stats->max; - s.num_TCP_conns = tcp_stats->num.Value(); + s.num_TCP_conns = tcp_stats->active.Value(); s.cumulative_TCP_conns = tcp_stats->total.Value(); - auto* udp_stats = stats.GetCounters("udp"); + auto* udp_stats = stats->GetCounters("udp"); s.max_UDP_conns = udp_stats->max; - s.num_UDP_conns = udp_stats->num.Value(); + s.num_UDP_conns = udp_stats->active.Value(); s.cumulative_UDP_conns = udp_stats->total.Value(); - auto* icmp_stats = stats.GetCounters("icmp"); + auto* icmp_stats = stats->GetCounters("icmp"); s.max_ICMP_conns = icmp_stats->max; - s.num_ICMP_conns = icmp_stats->num.Value(); + s.num_ICMP_conns = icmp_stats->active.Value(); s.cumulative_ICMP_conns = icmp_stats->total.Value(); s.num_fragments = detail::fragment_mgr->Size(); @@ -634,17 +686,18 @@ unsigned int SessionManager::MemoryAllocation() void SessionManager::InsertSession(detail::SessionKey key, Session* session) { + session->SetInSessionTable(true); key.CopyData(); session_map.insert_or_assign(std::move(key), session); std::string protocol = session->TransportIdentifier(); - if ( auto* stat_block = stats.GetCounters(protocol) ) + if ( auto* stat_block = stats->GetCounters(protocol) ) { - stat_block->num.Inc(); + stat_block->active.Inc(); stat_block->total.Inc(); - if ( stat_block->num.Value() > stat_block->max ) + if ( stat_block->active.Value() > stat_block->max ) stat_block->max++; } } diff --git a/src/SessionManager.h b/src/SessionManager.h index b6f05dca53..e0b19af880 100644 --- a/src/SessionManager.h +++ b/src/SessionManager.h @@ -15,12 +15,18 @@ namespace zeek { -namespace detail { class PacketFilter; } +namespace detail { + +class PacketFilter; +class ProtocolStats; + +} // namespace detail class EncapsulationStack; class Packet; class Connection; struct ConnID; +class StatBlocks; struct SessionStats { size_t num_TCP_conns; @@ -42,7 +48,7 @@ struct SessionStats { class SessionManager final { public: - SessionManager() = default; + SessionManager(); ~SessionManager(); void Done(); // call to drain events before destructing @@ -141,59 +147,6 @@ public: private: - class StatBlocks { - - public: - - struct Block { - telemetry::IntGauge num; - telemetry::IntCounter total; - size_t max = 0; - - Block(telemetry::IntGaugeFamily num_family, - telemetry::IntCounterFamily total_family, - std::string protocol) : num(num_family.GetOrAdd({{"protocol", protocol}})), - total(total_family.GetOrAdd({{"protocol", protocol}})) - { - } - }; - - using BlockMap = std::map; - - BlockMap::iterator InitCounters(std::string protocol) - { - telemetry::IntGaugeFamily num_family = telemetry_mgr->GaugeFamily( - "zeek", "open-sessions", {"protocol"}, "Active Zeek Sessions"); - telemetry::IntCounterFamily total_family = telemetry_mgr->CounterFamily( - "zeek", "sessions", {"protocol"}, - "Total number of sessions", "1", true); - - auto [it, inserted] = entries.insert( - {protocol, Block{num_family, total_family, protocol}}); - - if ( inserted ) - return it; - - return entries.end(); - } - - Block* GetCounters(std::string protocol) - { - auto it = entries.find(protocol); - if ( it == entries.end() ) - it = InitCounters(protocol); - - if ( it != entries.end() ) - return &(it->second); - - return nullptr; - } - - private: - - BlockMap entries; - }; - using SessionMap = std::map; Connection* NewConn(const detail::ConnIDKey& k, double t, const ConnID* id, @@ -230,7 +183,7 @@ private: void InsertSession(detail::SessionKey key, Session* session); SessionMap session_map; - StatBlocks stats; + detail::ProtocolStats* stats; }; // Manager for the currently active sessions. diff --git a/src/fuzzers/pop3-fuzzer.cc b/src/fuzzers/pop3-fuzzer.cc index 3354b4954c..ad19ff38a9 100644 --- a/src/fuzzers/pop3-fuzzer.cc +++ b/src/fuzzers/pop3-fuzzer.cc @@ -26,7 +26,7 @@ static zeek::Connection* add_connection() conn_id.dst_port = htons(80); conn_id.is_one_way = false; conn_id.proto = TRANSPORT_TCP; - zeek::detail::ConnIDKey key = zeek::detail::BuildConnIDKey(conn_id); + zeek::detail::ConnIDKey key(conn_id); zeek::Connection* conn = new zeek::Connection(key, network_time_start, &conn_id, 1, &p); conn->SetTransport(TRANSPORT_TCP);