diff --git a/src/Val.cc b/src/Val.cc index 340cef6bb5..fe83c8d583 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -72,31 +72,59 @@ Val::~Val() #endif } -Val* Val::Clone() const +Val* Val::Clone() { - SerializationFormat* form = new BinarySerializationFormat(); - form->StartWrite(); - - CloneSerializer ss(form); - SerialInfo sinfo(&ss); - sinfo.cache = false; - sinfo.include_locations = false; - - if ( ! this->Serialize(&sinfo) ) - return 0; - - char* data; - uint32 len = form->EndWrite(&data); - form->StartRead(data, len); - - UnserialInfo uinfo(&ss); - uinfo.cache = false; - Val* clone = Unserialize(&uinfo, type); - - free(data); - return clone; + Val::CloneState state; + return Clone(&state); } +Val* Val::Clone(CloneState* state) + { + auto i = state->clones.find(this); + + if ( i != state->clones.end() ) + return i->second->Ref(); + + auto c = DoClone(state); + assert(c); + + state->clones.insert(std::make_pair(this, c)); + return c; + } + +Val* Val::DoClone(CloneState* state) + { + switch ( type->InternalType() ) { + case TYPE_INTERNAL_INT: + case TYPE_INTERNAL_UNSIGNED: + case TYPE_INTERNAL_DOUBLE: + // Immutable. + return Ref(); + + case TYPE_INTERNAL_OTHER: + // Derived classes are responsible for this. Exception: + // Functions and files. There aren't any derived classes. + if ( type->Tag() == TYPE_FUNC ) + // Immutable. + return Ref(); + + if ( type->Tag() == TYPE_FILE ) + { + auto f = AsFile(); + ::Ref(f); + return new Val(f); + } + + // Fall-through. + + default: + reporter->InternalError("cloning illegal base type"); + } + + reporter->InternalError("cannot be reached"); + return nullptr; + } + bool Val::Serialize(SerialInfo* info) const { return SerialObj::Serialize(info); @@ -862,6 +890,12 @@ void PortVal::ValDescribe(ODesc* d) const d->Add("/unknown"); } +Val* PortVal::DoClone(CloneState* state) + { + // Immutable. + return Ref(); + } + IMPLEMENT_SERIAL(PortVal, SER_PORT_VAL); bool PortVal::DoSerialize(SerialInfo* info) const @@ -920,6 +954,12 @@ Val* AddrVal::SizeVal() const return val_mgr->GetCount(128); } +Val* AddrVal::DoClone(CloneState* state) + { + // Immutable. + return Ref(); + } + IMPLEMENT_SERIAL(AddrVal, SER_ADDR_VAL); bool AddrVal::DoSerialize(SerialInfo* info) const @@ -1044,6 +1084,12 @@ bool SubNetVal::Contains(const IPAddr& addr) const return val.subnet_val->Contains(a); } +Val* SubNetVal::DoClone(CloneState* state) + { + // Immutable. + return Ref(); + } + IMPLEMENT_SERIAL(SubNetVal, SER_SUBNET_VAL); bool SubNetVal::DoSerialize(SerialInfo* info) const @@ -1100,6 +1146,11 @@ unsigned int StringVal::MemoryAllocation() const return padded_sizeof(*this) + val.string_val->MemoryAllocation(); } +Val* StringVal::DoClone(CloneState* state) + { + return new StringVal(new BroString((u_char*) val.string_val->Bytes(), val.string_val->Len(), 1)); + } + IMPLEMENT_SERIAL(StringVal, SER_STRING_VAL); bool StringVal::DoSerialize(SerialInfo* info) const @@ -1162,6 +1213,13 @@ unsigned int PatternVal::MemoryAllocation() const return padded_sizeof(*this) + val.re_val->MemoryAllocation(); } +Val* PatternVal::DoClone(CloneState* state) + { + // TODO: Double-check + auto re = new RE_Matcher(val.re_val->PatternText(), val.re_val->AnywherePatternText()); + return new PatternVal(re); + } + IMPLEMENT_SERIAL(PatternVal, SER_PATTERN_VAL); bool PatternVal::DoSerialize(SerialInfo* info) const @@ -1260,6 +1318,16 @@ void ListVal::Describe(ODesc* d) const } } +Val* ListVal::DoClone(CloneState* state) + { + auto lv = new ListVal(tag); + + loop_over_list(vals, i) + lv->Append(vals[i]->Clone(state)); + + return lv; + } + IMPLEMENT_SERIAL(ListVal, SER_LIST_VAL); bool ListVal::DoSerialize(SerialInfo* info) const @@ -2498,6 +2566,55 @@ void TableVal::ReadOperation(Val* index, TableEntryVal* v) } } +Val* TableVal::DoClone(CloneState* state) + { + auto tv = new TableVal(table_type); + + const PDict(TableEntryVal)* tbl = AsTable(); + IterCookie* cookie = tbl->InitForIteration(); + + HashKey* key; + TableEntryVal* val; + while ( (val = tbl->NextEntry(key, cookie)) ) + { + Val* idx = RecoverIndex(key); + TableEntryVal* nval = val ? new TableEntryVal(*val) : nullptr; + tv->AsNonConstTable()->Insert(key, nval); + + if ( subnets ) + { + tv->subnets->Insert(idx, nval); + Unref(idx); + } + + delete key; + } + + if ( attrs ) + { + ::Ref(attrs); + tv->attrs = attrs; + } + + if ( expire_time ) + { + tv->expire_time = expire_time->Ref(); + + // As network_time is not necessarily initialized yet, we set + // a timer which fires immediately. + timer = new TableValTimer(this, 1); + timer_mgr->Add(timer); + } + + if ( expire_func ) + tv->expire_func = expire_func->Ref(); + + if ( def_val ) + tv->def_val = def_val->Ref(); + + return tv; + } + IMPLEMENT_SERIAL(TableVal, SER_TABLE_VAL); // This is getting rather complex due to the ability to suspend even within @@ -3052,7 +3169,7 @@ void RecordVal::Describe(ODesc* d) const void RecordVal::DescribeReST(ODesc* d) const { const val_list* vl = AsRecord(); - int n = vl->length(); + int n = vl->length(); d->Add("{"); d->PushIndent(); @@ -3077,6 +3194,21 @@ void RecordVal::DescribeReST(ODesc* d) const d->Add("}"); } +Val* RecordVal::DoClone(CloneState* state) + { + // TODO: We leave origin unset, ok? + ::Ref(record_type); + auto rv = new RecordVal(record_type); + + loop_over_list(*val.val_list_val, i) + { + Val* v = (*val.val_list_val)[i]->Clone(state); + rv->val.val_list_val->append(v); + } + + return nullptr; + } + IMPLEMENT_SERIAL(RecordVal, SER_RECORD_VAL); bool RecordVal::DoSerialize(SerialInfo* info) const @@ -3193,6 +3325,12 @@ void EnumVal::ValDescribe(ODesc* d) const d->Add(ename); } +Val* EnumVal::DoClone(CloneState* state) + { + // Immutable. + return Ref(); + } + IMPLEMENT_SERIAL(EnumVal, SER_ENUM_VAL); bool EnumVal::DoSerialize(SerialInfo* info) const @@ -3378,6 +3516,19 @@ bool VectorVal::RemoveProperties(Properties arg_props) return true; } + +Val* VectorVal::DoClone(CloneState* state) + { + auto vv = new VectorVal(vector_type); + for ( unsigned int i = 0; i < val.vector_val->size(); ++i ) + { + auto v = (*val.vector_val)[i]->Clone(state); + vv->val.vector_val->push_back(v); + } + + return vv; + } + IMPLEMENT_SERIAL(VectorVal, SER_VECTOR_VAL); bool VectorVal::DoSerialize(SerialInfo* info) const @@ -3450,6 +3601,12 @@ OpaqueVal::~OpaqueVal() { } +Val* OpaqueVal::DoClone(CloneState* state) + { + // TODO + return nullptr; + } + IMPLEMENT_SERIAL(OpaqueVal, SER_OPAQUE_VAL); bool OpaqueVal::DoSerialize(SerialInfo* info) const diff --git a/src/Val.h b/src/Val.h index 63e790848d..5104c1933e 100644 --- a/src/Val.h +++ b/src/Val.h @@ -172,7 +172,7 @@ public: ~Val() override; Val* Ref() { ::Ref(this); return this; } - virtual Val* Clone() const; + Val* Clone(); int IsZero() const; int IsOne() const; @@ -370,6 +370,9 @@ public: protected: friend class EnumType; + friend class ListVal; + friend class RecordVal; + friend class VectorVal; friend class ValManager; virtual void ValDescribe(ODesc* d) const; @@ -419,6 +422,14 @@ protected: static Val* Unserialize(UnserialInfo* info, TypeTag type, const BroType* exact_type); + // For internal use by the Val::Clone() methods. + struct CloneState { + std::unordered_map clones; + }; + + Val* Clone(CloneState* state); + virtual Val* DoClone(CloneState* state); + BroValUnion val; BroType* type; @@ -639,6 +650,7 @@ protected: PortVal(uint32 p, bool unused); void ValDescribe(ODesc* d) const override; + Val* DoClone(CloneState* state) override; DECLARE_SERIAL(PortVal); }; @@ -664,6 +676,8 @@ protected: explicit AddrVal(TypeTag t) : Val(t) { } explicit AddrVal(BroType* t) : Val(t) { } + Val* DoClone(CloneState* state) override; + DECLARE_SERIAL(AddrVal); }; @@ -692,6 +706,7 @@ protected: SubNetVal() {} void ValDescribe(ODesc* d) const override; + Val* DoClone(CloneState* state) override; DECLARE_SERIAL(SubNetVal); }; @@ -724,6 +739,7 @@ protected: StringVal() {} void ValDescribe(ODesc* d) const override; + Val* DoClone(CloneState* state) override; DECLARE_SERIAL(StringVal); }; @@ -744,6 +760,7 @@ protected: PatternVal() {} void ValDescribe(ODesc* d) const override; + Val* DoClone(CloneState* state) override; DECLARE_SERIAL(PatternVal); }; @@ -789,6 +806,8 @@ protected: friend class Val; ListVal() {} + Val* DoClone(CloneState* state) override; + DECLARE_SERIAL(ListVal); val_list vals; @@ -806,6 +825,15 @@ public: expire_access_time = last_read_update = int(network_time - bro_start_network_time); } + + TableEntryVal(const TableEntryVal& other) + { + val = other.val->Ref(); + last_access_time = other.last_access_time; + expire_access_time = other.expire_access_time; + last_read_update = other.last_read_update; + } + ~TableEntryVal() { } Val* Value() { return val; } @@ -997,6 +1025,8 @@ protected: // Propagates a read operation if necessary. void ReadOperation(Val* index, TableEntryVal *v); + Val* DoClone(CloneState* state) override; + DECLARE_SERIAL(TableVal); TableType* table_type; @@ -1069,6 +1099,8 @@ protected: bool AddProperties(Properties arg_state) override; bool RemoveProperties(Properties arg_state) override; + Val* DoClone(CloneState* state) override; + DECLARE_SERIAL(RecordVal); RecordType* record_type; @@ -1100,6 +1132,7 @@ protected: EnumVal() {} void ValDescribe(ODesc* d) const override; + Val* DoClone(CloneState* state) override; DECLARE_SERIAL(EnumVal); }; @@ -1160,6 +1193,7 @@ protected: bool AddProperties(Properties arg_state) override; bool RemoveProperties(Properties arg_state) override; void ValDescribe(ODesc* d) const override; + Val* DoClone(CloneState* state) override; DECLARE_SERIAL(VectorVal); @@ -1178,6 +1212,8 @@ protected: friend class Val; OpaqueVal() { } + Val* DoClone(CloneState* state) override; + DECLARE_SERIAL(OpaqueVal); }; diff --git a/testing/btest/language/copy-all-types.zeek b/testing/btest/language/copy-all-types.zeek new file mode 100644 index 0000000000..e39308b8e0 --- /dev/null +++ b/testing/btest/language/copy-all-types.zeek @@ -0,0 +1,27 @@ +# @TEST-EXEC: bro -b %INPUT >out +# @TEST-EXEC: btest-diff out + +function check(o1: any, o2: any, equal: bool, expect_same: bool) + { + local expect_msg = (equal ? "ok" : "FAIL0"); + local same = same_object(o1, o2); + + if ( expect_same && ! same ) + expect_msg = "FAIL1"; + + if ( ! expect_same && same ) + expect_msg = "FAIL2"; + + print fmt("orig=%s (%s) clone=%s (%s) equal=%s same_object=%s (%s)", o1, type_name(o1), o2, type_name(o2), equal, same, expect_msg); + } + +event zeek_init() + { + local i1 = -42; + local i2 = copy(i1); + check(i1, i2, i1 == i2, T); + + local s1 = "Foo"; + local s2 = copy(s1); + check(s1, s2, s1 == s2, F); + }