diff --git a/src/Conn.cc b/src/Conn.cc index c00e86b337..97a2564d1a 100644 --- a/src/Conn.cc +++ b/src/Conn.cc @@ -23,8 +23,27 @@ namespace zeek { uint64_t Connection::total_connections = 0; uint64_t Connection::current_connections = 0; +Connection::Connection(zeek::IPBasedConnKeyPtr k, const zeek::ConnTuple& ct, double t, uint32_t flow, const Packet* pkt) + : Session(t, connection_timeout, connection_status_update, detail::connection_status_update_interval), + key(std::move(k)) { + orig_addr = ct.src_addr; + resp_addr = ct.dst_addr; + orig_port = ct.src_port; + resp_port = ct.dst_port; + + switch ( ct.proto ) { + case IPPROTO_TCP: proto = TRANSPORT_TCP; break; + case IPPROTO_UDP: proto = TRANSPORT_UDP; break; + case IPPROTO_ICMP: + case IPPROTO_ICMPV6: proto = TRANSPORT_ICMP; break; + default: proto = TRANSPORT_UNKNOWN; break; + } + + Init(flow, pkt); +} + Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt) - : Session(t, connection_timeout, connection_status_update, detail::connection_status_update_interval), key(k) { + : Session(t, connection_timeout, connection_status_update, detail::connection_status_update_interval) { orig_addr = id->src_addr; resp_addr = id->dst_addr; orig_port = id->src_port; @@ -38,6 +57,28 @@ Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, default: proto = TRANSPORT_UNKNOWN; break; } + key = std::make_unique(); + key->InitTuple(*id); + key->Init(*pkt); + + Init(flow, pkt); +} + +Connection::~Connection() { + if ( ! finished ) + reporter->InternalError("Done() not called before destruction of Connection"); + + CancelTimers(); + + if ( conn_val ) + conn_val->SetOrigin(nullptr); + + delete adapter; + + --current_connections; +} + +void Connection::Init(uint32_t flow, const Packet* pkt) { orig_flow_label = flow; resp_flow_label = 0; saw_first_orig_packet = 1; @@ -71,20 +112,6 @@ Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, encapsulation = pkt->encap; } -Connection::~Connection() { - if ( ! finished ) - reporter->InternalError("Done() not called before destruction of Connection"); - - CancelTimers(); - - if ( conn_val ) - conn_val->SetOrigin(nullptr); - - delete adapter; - - --current_connections; -} - void Connection::CheckEncapsulation(const std::shared_ptr& arg_encap) { if ( encapsulation && arg_encap ) { if ( *encapsulation != *arg_encap ) { @@ -157,6 +184,13 @@ void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, i run_state::current_pkt = nullptr; } + +const ConnKey& Connection::Key() const { return *key; } + +session::detail::Key Connection::SessionKey(bool copy) const { return key->SessionKey(); } + +uint8_t Connection::KeyProto() const { return key->PackedTuple().proto; } + bool Connection::IsReuse(double t, const u_char* pkt) { return adapter && adapter->IsReuse(t, pkt); } namespace { @@ -186,6 +220,7 @@ const RecordValPtr& Connection::GetVal() { TransportProto prot_type = ConnTransport(); + // XXX this could technically move into IPBasedConnKey. auto id_val = make_intrusive(id::conn_id); id_val->Assign(0, make_intrusive(orig_addr)); id_val->Assign(1, val_mgr->Port(ntohs(orig_port), prot_type)); @@ -193,6 +228,9 @@ const RecordValPtr& Connection::GetVal() { id_val->Assign(3, val_mgr->Port(ntohs(resp_port), prot_type)); id_val->Assign(4, KeyProto()); + // Allow customized ConnKeys to augment the conn_id: + key->PopulateConnIdVal(*id_val); + auto orig_endp = make_intrusive(id::endpoint); orig_endp->Assign(0, 0); orig_endp->Assign(1, 0); diff --git a/src/Conn.h b/src/Conn.h index 95883b9b88..ca634a821f 100644 --- a/src/Conn.h +++ b/src/Conn.h @@ -5,6 +5,7 @@ #include #include +#include "zeek/ConnKey.h" #include "zeek/IPAddr.h" #include "zeek/IntrusivePtr.h" #include "zeek/Rule.h" @@ -27,6 +28,9 @@ class RecordVal; using ValPtr = IntrusivePtr; using RecordValPtr = IntrusivePtr; +class IPBasedConnKey; +using IPBasedConnKeyPtr = std::unique_ptr; + namespace detail { class Specific_RE_Matcher; @@ -64,6 +68,7 @@ static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPA class Connection final : public session::Session { public: + Connection(zeek::IPBasedConnKeyPtr k, const zeek::ConnTuple& ct, double t, uint32_t flow, const Packet* pkt); Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt); ~Connection() override; @@ -101,10 +106,12 @@ public: // Keys are only considered valid for a connection when a // connection is in the session map. If it is removed, the key // should be marked invalid. - const detail::ConnKey& Key() const { return key; } - session::detail::Key SessionKey(bool copy) const override { - return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE, copy}; - } + // + // These touch the key, which we forward-declared above. Therefore this + // hides the implementation, which has the full class definition. + const ConnKey& Key() const; + session::detail::Key SessionKey(bool copy) const override; + uint8_t KeyProto() const; const IPAddr& OrigAddr() const { return orig_addr; } const IPAddr& RespAddr() const { return resp_addr; } @@ -130,8 +137,6 @@ public: return "unknown"; } - uint8_t KeyProto() const { return key.transport; } - // Returns true if the packet reflects a reuse of this // connection (i.e., not a continuation but the beginning of // a new connection). @@ -196,6 +201,9 @@ public: bool IsFinished() { return finished; } private: + // Common initialization for the constructors. + void Init(uint32_t flow, const Packet* pkt); + friend class session::detail::Timer; IPAddr orig_addr; @@ -211,7 +219,7 @@ private: std::shared_ptr encapsulation; // tunnels uint8_t tunnel_changes = 0; - detail::ConnKey key; + IPBasedConnKeyPtr key; unsigned int weird : 1; unsigned int finished : 1; diff --git a/src/analyzer/Analyzer.cc b/src/analyzer/Analyzer.cc index f1ef9ed459..fc73a468f7 100644 --- a/src/analyzer/Analyzer.cc +++ b/src/analyzer/Analyzer.cc @@ -8,6 +8,7 @@ #include "zeek/Conn.h" #include "zeek/Event.h" #include "zeek/analyzer/Manager.h" +#include "zeek/packet_analysis/protocol/ip/conn_key/IPBasedConnKey.h" #include "zeek/packet_analysis/protocol/tcp/TCPSessionAdapter.h" #include "zeek/3rdparty/doctest.h" @@ -806,8 +807,9 @@ TEST_SUITE("Analyzer management") { REQUIRE(zeek::analyzer_mgr); zeek::Packet p; - zeek::ConnTuple t; - auto conn = std::make_unique(zeek::detail::ConnKey(t), 0, &t, 0, &p); + zeek::ConnTuple ct; + zeek::IPBasedConnKeyPtr kp = std::make_unique(); + auto conn = std::make_unique(std::move(kp), ct, 0, 0, &p); auto* tcp = new zeek::packet_analysis::TCP::TCPSessionAdapter(conn.get()); conn->SetSessionAdapter(tcp, nullptr); @@ -838,8 +840,9 @@ TEST_SUITE("Analyzer management") { REQUIRE(zeek::analyzer_mgr); zeek::Packet p; - zeek::ConnTuple t; - auto conn = std::make_unique(zeek::detail::ConnKey(t), 0, &t, 0, &p); + zeek::ConnTuple ct; + zeek::IPBasedConnKeyPtr kp = std::make_unique(); + auto conn = std::make_unique(std::move(kp), ct, 0, 0, &p); auto ssh = zeek::analyzer_mgr->InstantiateAnalyzer("SSH", conn.get()); REQUIRE(ssh); diff --git a/src/analyzer/protocol/smtp/BDAT.cc b/src/analyzer/protocol/smtp/BDAT.cc index f83d011592..56fe42f4bb 100644 --- a/src/analyzer/protocol/smtp/BDAT.cc +++ b/src/analyzer/protocol/smtp/BDAT.cc @@ -5,6 +5,7 @@ #include "zeek/Conn.h" #include "zeek/DebugLogger.h" #include "zeek/analyzer/protocol/mime/MIME.h" +#include "zeek/packet_analysis/protocol/ip/conn_key/IPBasedConnKey.h" #include "zeek/util.h" #include "zeek/3rdparty/doctest.h" @@ -327,8 +328,9 @@ private: TEST_CASE("line forward testing") { zeek::Packet p; - zeek::ConnTuple t; - auto conn = std::make_unique(zeek::detail::ConnKey(t), 0, &t, 0, &p); + zeek::ConnTuple ct; + zeek::IPBasedConnKeyPtr kp = std::make_unique(); + auto conn = std::make_unique(std::move(kp), ct, 0, 0, &p); auto smtp_analyzer = std::unique_ptr(zeek::analyzer_mgr->InstantiateAnalyzer("SMTP", conn.get())); auto mail = std::make_unique(smtp_analyzer.get()); diff --git a/src/packet_analysis/protocol/gtpv1/GTPv1.cc b/src/packet_analysis/protocol/gtpv1/GTPv1.cc index 9dbfccdb80..547e145ba3 100644 --- a/src/packet_analysis/protocol/gtpv1/GTPv1.cc +++ b/src/packet_analysis/protocol/gtpv1/GTPv1.cc @@ -19,11 +19,15 @@ bool GTPv1_Analyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* pack } auto conn = static_cast(packet->session); - zeek::detail::ConnKey conn_key = conn->Key(); + const auto& key = conn->Key(); + auto sk = key.SessionKey(); - auto cm_it = conn_map.find(conn_key); + auto cm_it = conn_map.find(sk); if ( cm_it == conn_map.end() ) { - cm_it = conn_map.insert(cm_it, {conn_key, std::make_unique(this)}); + sk.CopyData(); // Copy key data to store in map. + auto [it, inserted] = conn_map.emplace(std::move(sk), std::make_unique(this)); + assert(inserted); + cm_it = it; // Let script land know about the state we created, so it will // register a conn removal hook for cleanup. diff --git a/src/packet_analysis/protocol/gtpv1/GTPv1.h b/src/packet_analysis/protocol/gtpv1/GTPv1.h index f48707dea1..846f550942 100644 --- a/src/packet_analysis/protocol/gtpv1/GTPv1.h +++ b/src/packet_analysis/protocol/gtpv1/GTPv1.h @@ -3,6 +3,7 @@ #pragma once #include "zeek/packet_analysis/Analyzer.h" +#include "zeek/session/Key.h" #include "packet_analysis/protocol/gtpv1/gtpv1_pac.h" @@ -27,11 +28,10 @@ public: gtp_hdr_val = std::move(val); } - void RemoveConnection(const zeek::detail::ConnKey& conn_key) { conn_map.erase(conn_key); } + void RemoveConnection(const zeek::session::detail::Key& conn_key) { conn_map.erase(conn_key); } protected: - using ConnMap = std::map>; - ConnMap conn_map; + std::map> conn_map; int inner_packet_offset = -1; uint8_t next_header = 0; diff --git a/src/packet_analysis/protocol/gtpv1/functions.bif b/src/packet_analysis/protocol/gtpv1/functions.bif index 05376a920e..d48cf8acce 100644 --- a/src/packet_analysis/protocol/gtpv1/functions.bif +++ b/src/packet_analysis/protocol/gtpv1/functions.bif @@ -2,6 +2,7 @@ module PacketAnalyzer::GTPV1; %%{ #include "zeek/Conn.h" +#include "zeek/conn_key/Manager.h" #include "zeek/session/Manager.h" #include "zeek/packet_analysis/Manager.h" #include "zeek/packet_analysis/protocol/gtpv1/GTPv1.h" @@ -12,8 +13,12 @@ function remove_gtpv1_connection%(cid: conn_id%) : bool zeek::packet_analysis::AnalyzerPtr gtpv1 = zeek::packet_mgr->GetAnalyzer("GTPv1"); if ( gtpv1 ) { - zeek::detail::ConnKey conn_key(cid); - static_cast(gtpv1.get())->RemoveConnection(conn_key); + auto r = zeek::conn_key_mgr->GetFactory().ConnKeyFromVal(*cid); + if ( ! r.has_value() ) + return zeek::val_mgr->False(); + + auto sk = r.value()->SessionKey(); + static_cast(gtpv1.get())->RemoveConnection(sk); } return zeek::val_mgr->True(); diff --git a/src/packet_analysis/protocol/ip/IPBasedAnalyzer.cc b/src/packet_analysis/protocol/ip/IPBasedAnalyzer.cc index 1b816a11b7..bfee8c3a04 100644 --- a/src/packet_analysis/protocol/ip/IPBasedAnalyzer.cc +++ b/src/packet_analysis/protocol/ip/IPBasedAnalyzer.cc @@ -7,6 +7,8 @@ #include "zeek/Val.h" #include "zeek/analyzer/Manager.h" #include "zeek/analyzer/protocol/pia/PIA.h" +#include "zeek/conn_key/Manager.h" +#include "zeek/packet_analysis/protocol/ip/conn_key/IPBasedConnKey.h" #include "zeek/plugin/Manager.h" #include "zeek/session/Manager.h" @@ -26,13 +28,30 @@ bool IPBasedAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* pkt if ( ! BuildConnTuple(len, data, pkt, tuple) ) return false; - const std::shared_ptr& ip_hdr = pkt->ip_hdr; - zeek::detail::ConnKey key(tuple); + static IPBasedConnKeyPtr key; // Note, this is static for reuse: + if ( ! key ) { + ConnKeyPtr ck = conn_key_mgr->GetFactory().NewConnKey(); - Connection* conn = session_mgr->FindConnection(key); + // The IPBasedAnalyzer requires a factory that produces IPBasedConnKey instances. + // We could check with dynamic_cast, but that's probably slow, so assume plugin + // providers know what they're doing here and anyhow, we don't really have analyzers + // that instantiate non-IP connections today and definitely not here! + key = IPBasedConnKeyPtr(static_cast(ck.release())); + } + + // Initialize the key with the IP conn tuple and the packet as additional context. + // + // Custom IPConnKey implementations can fiddle with the Key through + // the DoInit(const Packet& pkt) hook called at this point. + key->InitTuple(tuple); + key->Init(*pkt); + + const std::shared_ptr& ip_hdr = pkt->ip_hdr; + + Connection* conn = session_mgr->FindConnection(*key); if ( ! conn ) { - conn = NewConn(&tuple, key, pkt); + conn = NewConn(tuple, std::move(key), pkt); if ( conn ) session_mgr->Insert(conn, false); } @@ -41,7 +60,7 @@ bool IPBasedAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* pkt conn->Event(connection_reused, nullptr); session_mgr->Remove(conn); - conn = NewConn(&tuple, key, pkt); + conn = NewConn(tuple, std::move(key), pkt); if ( conn ) session_mgr->Insert(conn, false); } @@ -140,18 +159,19 @@ bool IPBasedAnalyzer::IsLikelyServerPort(uint32_t port) const { return port_cache.find(port) != port_cache.end(); } -zeek::Connection* IPBasedAnalyzer::NewConn(const ConnTuple* id, const zeek::detail::ConnKey& key, const Packet* pkt) { - int src_h = ntohs(id->src_port); - int dst_h = ntohs(id->dst_port); +zeek::Connection* IPBasedAnalyzer::NewConn(const ConnTuple& id, IPBasedConnKeyPtr key, const Packet* pkt) { + int src_h = ntohs(id.src_port); + int dst_h = ntohs(id.dst_port); bool flip = false; if ( ! WantConnection(src_h, dst_h, pkt->ip_hdr->Payload(), flip) ) return nullptr; - Connection* conn = new Connection(key, run_state::processing_start_time, id, pkt->ip_hdr->FlowLabel(), pkt); + Connection* conn = + new Connection(std::move(key), id, run_state::processing_start_time, pkt->ip_hdr->FlowLabel(), pkt); conn->SetTransport(transport); - if ( flip && ! id->dst_addr.IsBroadcast() ) + if ( flip && ! id.dst_addr.IsBroadcast() ) conn->FlipRoles(); BuildSessionAnalyzerTree(conn); diff --git a/src/packet_analysis/protocol/ip/IPBasedAnalyzer.h b/src/packet_analysis/protocol/ip/IPBasedAnalyzer.h index bc6d0d08e8..ecaaa0970a 100644 --- a/src/packet_analysis/protocol/ip/IPBasedAnalyzer.h +++ b/src/packet_analysis/protocol/ip/IPBasedAnalyzer.h @@ -7,6 +7,7 @@ #include "zeek/Tag.h" #include "zeek/packet_analysis/Analyzer.h" +#include "zeek/packet_analysis/protocol/ip/conn_key/IPBasedConnKey.h" namespace zeek::analyzer::pia { class PIA; @@ -184,7 +185,7 @@ private: * @param key A connection ID key generated from the ID. * @param pkt The packet associated with the new connection. */ - zeek::Connection* NewConn(const ConnTuple* id, const zeek::detail::ConnKey& key, const Packet* pkt); + zeek::Connection* NewConn(const ConnTuple& id, IPBasedConnKeyPtr key, const Packet* pkt); void BuildSessionAnalyzerTree(Connection* conn); diff --git a/src/packet_analysis/protocol/teredo/Teredo.cc b/src/packet_analysis/protocol/teredo/Teredo.cc index 68acf554bf..a2a0781b71 100644 --- a/src/packet_analysis/protocol/teredo/Teredo.cc +++ b/src/packet_analysis/protocol/teredo/Teredo.cc @@ -185,15 +185,19 @@ bool TeredoAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* pack return false; } - zeek::detail::ConnKey conn_key = conn->Key(); - OrigRespMap::iterator or_it = orig_resp_map.find(conn_key); + const auto& k = conn->Key(); + auto sk = k.SessionKey(); + OrigRespMap::iterator or_it = orig_resp_map.find(sk); // The first time a teredo packet is parsed successfully, insert // state into orig_resp_map so we can confirm when both sides // see valid Teredo packets. Further, raise an event so that script // layer can install a connection removal hooks to cleanup later. if ( or_it == orig_resp_map.end() ) { - or_it = orig_resp_map.insert(or_it, {conn_key, {}}); + sk.CopyData(); // Copy key data to store in map. + auto [it, inserted] = orig_resp_map.emplace(std::move(sk), OrigResp{}); + assert(inserted); + or_it = it; packet->session->EnqueueEvent(new_teredo_state, nullptr, packet->session->GetVal()); } diff --git a/src/packet_analysis/protocol/teredo/Teredo.h b/src/packet_analysis/protocol/teredo/Teredo.h index 7e2ca41948..e29879f668 100644 --- a/src/packet_analysis/protocol/teredo/Teredo.h +++ b/src/packet_analysis/protocol/teredo/Teredo.h @@ -8,6 +8,7 @@ #include "zeek/RE.h" #include "zeek/Reporter.h" #include "zeek/packet_analysis/Analyzer.h" +#include "zeek/session/Key.h" namespace zeek::packet_analysis::teredo { @@ -44,7 +45,7 @@ public: bool DetectProtocol(size_t len, const uint8_t* data, Packet* packet) override; - void RemoveConnection(const zeek::detail::ConnKey& conn_key) { orig_resp_map.erase(conn_key); } + void RemoveConnection(const zeek::session::detail::Key& conn_key) { orig_resp_map.erase(conn_key); } protected: struct OrigResp { @@ -52,7 +53,7 @@ protected: bool valid_resp = false; bool confirmed = false; }; - using OrigRespMap = std::map; + using OrigRespMap = std::map; OrigRespMap orig_resp_map; std::unique_ptr pattern_re; diff --git a/src/packet_analysis/protocol/teredo/functions.bif b/src/packet_analysis/protocol/teredo/functions.bif index 8607712ca5..8b1a5eb48c 100644 --- a/src/packet_analysis/protocol/teredo/functions.bif +++ b/src/packet_analysis/protocol/teredo/functions.bif @@ -2,6 +2,7 @@ module PacketAnalyzer::TEREDO; %%{ #include "zeek/Conn.h" +#include "zeek/conn_key/Manager.h" #include "zeek/session/Manager.h" #include "zeek/packet_analysis/Manager.h" #include "zeek/packet_analysis/protocol/teredo/Teredo.h" @@ -12,8 +13,12 @@ function remove_teredo_connection%(cid: conn_id%) : bool zeek::packet_analysis::AnalyzerPtr teredo = zeek::packet_mgr->GetAnalyzer("Teredo"); if ( teredo ) { - zeek::detail::ConnKey conn_key(cid); - static_cast(teredo.get())->RemoveConnection(conn_key); + auto r = zeek::conn_key_mgr->GetFactory().ConnKeyFromVal(*cid); + if ( ! r.has_value() ) + return zeek::val_mgr->False(); + + auto sk = r.value()->SessionKey(); + static_cast(teredo.get())->RemoveConnection(sk); } return zeek::val_mgr->True(); diff --git a/src/session/Manager.cc b/src/session/Manager.cc index 32dfa4f226..ad1fc49b0d 100644 --- a/src/session/Manager.cc +++ b/src/session/Manager.cc @@ -17,6 +17,7 @@ #include "zeek/Stats.h" #include "zeek/Timer.h" #include "zeek/TunnelEncapsulation.h" +#include "zeek/conn_key/Manager.h" #include "zeek/packet_analysis/Manager.h" #include "zeek/session/Session.h" #include "zeek/telemetry/Manager.h" @@ -88,23 +89,23 @@ Manager::~Manager() { } Connection* Manager::FindConnection(Val* v) { - zeek::detail::ConnKey conn_key(v); + // XXX: This could in the future dispatch to different factories for + // different kinds of Vals. ``v`` will usually be a conn_id instance, which + // is IP-specific. If ``v`` is something else, maybe we'd like to use a + // different builder. + auto r = conn_key_mgr->GetFactory().ConnKeyFromVal(*v); - if ( ! conn_key.Valid() ) { + if ( ! r.has_value() ) { // Produce a loud error for invalid script-layer conn_id records. - const char* extra = ""; - if ( conn_key.transport == UNKNOWN_IP_PROTO ) - extra = ": the proto field has the \"unknown\" 65535 value. Did you forget to set it?"; - - zeek::emit_builtin_error(zeek::util::fmt("invalid connection ID record encountered%s", extra)); + zeek::emit_builtin_error(r.error().c_str()); return nullptr; } - return FindConnection(conn_key); + return FindConnection(*r.value()); } -Connection* Manager::FindConnection(const zeek::detail::ConnKey& conn_key) { - detail::Key key(&conn_key, sizeof(conn_key), detail::Key::CONNECTION_KEY_TYPE, false); +Connection* Manager::FindConnection(const zeek::ConnKey& conn_key) { + auto key = conn_key.SessionKey(); auto it = session_map.find(key); if ( it != session_map.end() ) diff --git a/src/session/Manager.h b/src/session/Manager.h index 5803673cdf..c4e1ef4f73 100644 --- a/src/session/Manager.h +++ b/src/session/Manager.h @@ -5,6 +5,7 @@ #include // for u_char #include +#include "zeek/ConnKey.h" #include "zeek/Frag.h" #include "zeek/session/Session.h" @@ -70,7 +71,7 @@ public: * @param conn_key The key for the connection to search for. * @return The connection, or nullptr if one doesn't exist. */ - Connection* FindConnection(const zeek::detail::ConnKey& conn_key); + Connection* FindConnection(const zeek::ConnKey& conn_key); void Remove(Session* s); void Insert(Session* c, bool remove_existing = true);