diff --git a/src/Sessions.cc b/src/Sessions.cc index 62f0714aa8..35481f30aa 100644 --- a/src/Sessions.cc +++ b/src/Sessions.cc @@ -344,8 +344,8 @@ Connection* NetSessions::FindConnection(Val* v) const IPAddr& orig_addr = vl->GetFieldAs(orig_h); const IPAddr& resp_addr = vl->GetFieldAs(resp_h); - const PortVal* orig_portv = vl->GetFieldAs(orig_p); - const PortVal* resp_portv = vl->GetFieldAs(resp_p); + auto orig_portv = vl->GetFieldAs(orig_p); + auto resp_portv = vl->GetFieldAs(resp_p); ConnID id; diff --git a/src/Val.h b/src/Val.h index 9eb639f10d..e7998bded6 100644 --- a/src/Val.h +++ b/src/Val.h @@ -73,6 +73,7 @@ class EnumVal; class OpaqueVal; class VectorVal; class TableEntryVal; +class TypeVal; using AddrValPtr = IntrusivePtr; using EnumValPtr = IntrusivePtr; @@ -446,14 +447,19 @@ public: // Returns a masked port number static uint32_t Mask(uint32_t port_num, TransportProto port_type); - const PortVal* Get() const { return AsPortVal(); } - protected: friend class ValManager; PortVal(uint32_t p); void ValDescribe(ODesc* d) const override; ValPtr DoClone(CloneState* state) override; + +private: + // This method is just here to trick the interface in + // `RecordVal::GetFieldAs` into returning the right type. + // It shouldn't actually be used for anything. + friend class RecordVal; + PortValPtr Get() { return {NewRef{}, this}; } }; class AddrVal final : public Val { @@ -996,6 +1002,39 @@ private: PDict* table_val; }; +// This would be way easier with is_convertible_v, but sadly that won't +// work here because Obj has deleted copy constructors (and for good +// reason). Instead we make up our own type trait here that basically +// combines a bunch of is_same traits into a single trait to make life +// easier in the definitions of GetFieldAs(). +template +struct is_zeek_val + { + static const bool value = std::disjunction_v< + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same, + std::is_same>; + }; +template +inline constexpr bool is_zeek_val_v = is_zeek_val::value; + class RecordVal final : public Val, public notifier::detail::Modifiable { public: explicit RecordVal(RecordTypePtr t, bool init_fields = true); @@ -1148,22 +1187,77 @@ public: // The following return the given field converted to a particular // underlying value. We provide these to enable efficient - // access to record fields (without requiring an intermediary Val) - // if we change the underlying representation of records. - template - auto GetFieldAs(int field) const -> std::invoke_result_t + // access to record fields (without requiring an intermediary Val). + // It is up to the caller to ensure that the field exists in the + // record (using HasField(), if necessary). + template , bool> = true> + auto GetFieldAs(int field) const -> std::invoke_result_t { - auto field_ptr = GetField(field); - auto field_val_ptr = static_cast(field_ptr.get()); - return field_val_ptr->Get(); + if constexpr ( std::is_same_v || + std::is_same_v || + std::is_same_v ) + return record_val->at(field).int_val; + else if constexpr ( std::is_same_v ) + return record_val->at(field).uint_val; + else if constexpr ( std::is_same_v || + std::is_same_v || + std::is_same_v ) + return record_val->at(field).double_val; + else if constexpr ( std::is_same_v ) + return val_mgr->Port(record_val->at(field).uint_val); + else if constexpr ( std::is_same_v ) + return record_val->at(field).string_val->Get(); + else if constexpr ( std::is_same_v ) + return record_val->at(field).addr_val->Get(); + else if constexpr ( std::is_same_v ) + return record_val->at(field).subnet_val->Get(); + else if constexpr ( std::is_same_v ) + return *(record_val->at(field).file_val); + else if constexpr ( std::is_same_v ) + return *(record_val->at(field).func_val); + else if constexpr ( std::is_same_v ) + return record_val->at(field).re_val->Get(); + else if constexpr ( std::is_same_v ) + return record_val->at(field).record_val; + else if constexpr ( std::is_same_v ) + return record_val->at(field).vector_val; + else if constexpr ( std::is_same_v ) + return record_val->at(field).table_val->Get(); + else + { + // TODO: error here, although because of the + // type trait it really shouldn't ever get here. + } + } + + template , bool> = true> + T GetFieldAs(int field) const + { + if constexpr ( std::is_integral_v && std::is_signed_v ) + return record_val->at(field).int_val; + else if constexpr ( std::is_integral_v && + std::is_unsigned_v ) + return record_val->at(field).uint_val; + else if constexpr ( std::is_floating_point_v ) + return record_val->at(field).double_val; + + // Could add other types here using type traits, + // such as is_same_v, etc. + + return T{}; } template - auto GetFieldAs(const char* field) const -> std::invoke_result_t + auto GetFieldAs(const char* field) const { - auto field_ptr = GetField(field); - auto field_val_ptr = static_cast(field_ptr.get()); - return field_val_ptr->Get(); + int idx = GetType()->AsRecordType()->FieldOffset(field); + + if ( idx < 0 ) + reporter->InternalError("missing record field: %s", field); + + return GetFieldAs(idx); } void Describe(ODesc* d) const override;