diff --git a/src/Val.cc b/src/Val.cc index e76a62d0f8..0fbbc20979 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -2788,13 +2788,17 @@ RecordVal::RecordTypeValMap RecordVal::parse_time_records; RecordVal::RecordVal(RecordTypePtr t, bool init_fields) : Val(std::move(t)) { origin = nullptr; - auto rt = GetType()->AsRecordType(); + rt = {NewRef{}, GetType()->AsRecordType()}; + int n = rt->NumFields(); - auto vl = record_val = new std::vector; - vl->reserve(n); + + record_val = new std::vector; + record_val->reserve(n); + + is_in_record = new std::vector(n, false); if ( run_state::is_parsing ) - parse_time_records[rt].emplace_back(NewRef{}, this); + parse_time_records[rt.get()].emplace_back(NewRef{}, this); if ( ! init_fields ) return; @@ -2815,9 +2819,10 @@ RecordVal::RecordVal(RecordTypePtr t, bool init_fields) : Val(std::move(t)) catch ( InterpreterException& ) { if ( run_state::is_parsing ) - parse_time_records[rt].pop_back(); + parse_time_records[rt.get()].pop_back(); delete record_val; + delete is_in_record; throw; } @@ -2848,13 +2853,23 @@ RecordVal::RecordVal(RecordTypePtr t, bool init_fields) : Val(std::move(t)) def = make_intrusive(cast_intrusive(type)); } - vl->emplace_back(std::move(def)); + if ( def ) + { + record_val->emplace_back(ZVal(def, def->GetType())); + (*is_in_record)[i] = true; + } + else + { + record_val->emplace_back(ZVal()); + (*is_in_record)[i] = false; + } } } RecordVal::~RecordVal() { delete record_val; + delete is_in_record; } ValPtr RecordVal::SizeVal() const @@ -2864,13 +2879,28 @@ ValPtr RecordVal::SizeVal() const void RecordVal::Assign(int field, ValPtr new_val) { - (*record_val)[field] = std::move(new_val); + auto t = rt->GetFieldType(field); + + if ( new_val ) + { + (*record_val)[field] = ZVal(new_val, t); + (*is_in_record)[field] = true; + } + else + { + if ( HasField(field) ) + DeleteIfManaged((*record_val)[field], t); + + (*record_val)[field] = ZVal(); + (*is_in_record)[field] = false; + } + Modified(); } ValPtr RecordVal::GetFieldOrDefault(int field) const { - const auto& val = (*record_val)[field]; + auto val = GetField(field); if ( val ) return val; @@ -2878,9 +2908,9 @@ ValPtr RecordVal::GetFieldOrDefault(int field) const return GetType()->AsRecordType()->FieldDefault(field); } -void RecordVal::ResizeParseTimeRecords(RecordType* rt) +void RecordVal::ResizeParseTimeRecords(RecordType* revised_rt) { - auto it = parse_time_records.find(rt); + auto it = parse_time_records.find(revised_rt); if ( it == parse_time_records.end() ) return; @@ -2890,14 +2920,14 @@ void RecordVal::ResizeParseTimeRecords(RecordType* rt) for ( auto& rv : rvs ) { int current_length = rv->NumFields(); - auto required_length = rt->NumFields(); + auto required_length = revised_rt->NumFields(); if ( required_length > current_length ) { rv->Reserve(required_length); for ( auto i = current_length; i < required_length; ++i ) - rv->AppendField(rt->FieldDefault(i)); + rv->AppendField(revised_rt->FieldDefault(i)); } } } @@ -2907,7 +2937,7 @@ void RecordVal::DoneParsing() parse_time_records.clear(); } -const ValPtr& RecordVal::GetField(const char* field) const +ValPtr RecordVal::GetField(const char* field) const { int idx = GetType()->AsRecordType()->FieldOffset(field); @@ -3007,11 +3037,10 @@ TableValPtr RecordVal::GetRecordFieldsVal() const void RecordVal::Describe(ODesc* d) const { auto n = record_val->size(); - auto record_type = GetType()->AsRecordType(); if ( d->IsBinary() || d->IsPortable() ) { - record_type->Describe(d); + rt->Describe(d); d->SP(); d->Add(static_cast(n)); d->SP(); @@ -3024,12 +3053,12 @@ void RecordVal::Describe(ODesc* d) const if ( ! d->IsBinary() && i > 0 ) d->Add(", "); - d->Add(record_type->FieldName(i)); + d->Add(rt->FieldName(i)); if ( ! d->IsBinary() ) d->Add("="); - const auto& v = (*record_val)[i]; + auto v = GetField(i); if ( v ) v->Describe(d); @@ -3044,7 +3073,7 @@ void RecordVal::Describe(ODesc* d) const void RecordVal::DescribeReST(ODesc* d) const { auto n = record_val->size(); - auto record_type = GetType()->AsRecordType(); + auto rt = GetType()->AsRecordType(); d->Add("{"); d->PushIndent(); @@ -3054,10 +3083,10 @@ void RecordVal::DescribeReST(ODesc* d) const if ( i > 0 ) d->NL(); - d->Add(record_type->FieldName(i)); + d->Add(rt->FieldName(i)); d->Add("="); - const auto& v = (*record_val)[i]; + auto v = GetField(i); if ( v ) v->Describe(d); @@ -3080,9 +3109,11 @@ ValPtr RecordVal::DoClone(CloneState* state) rv->origin = nullptr; state->NewClone(this, rv); - for ( const auto& vlv : *record_val) + int n = NumFields(); + for ( auto i = 0; i < n; ++i ) { - auto v = vlv ? vlv->Clone(state) : nullptr; + auto f_i = GetField(i); + auto v = f_i ? f_i->Clone(state) : nullptr; rv->AppendField(std::move(v)); } @@ -3092,16 +3123,25 @@ ValPtr RecordVal::DoClone(CloneState* state) unsigned int RecordVal::MemoryAllocation() const { unsigned int size = 0; - const auto& vl = *record_val; - for ( const auto& v : vl ) + int n = NumFields(); + for ( auto i = 0; i < n; ++i ) { - if ( v ) - size += v->MemoryAllocation(); + auto f_i = GetField(i); + if ( f_i ) + size += f_i->MemoryAllocation(); } - size += util::pad_size(vl.capacity() * sizeof(ValPtr)); - size += padded_sizeof(vl); + size += util::pad_size(record_val->capacity() * sizeof(ZVal)); + size += padded_sizeof(*record_val); + + // It's tricky sizing is_in_record since it's a std::vector + // specialization. We approximate this by not scaling capacity() + // by sizeof(bool) but just using its raw value. That's still + // presumably going to be an overestimate. + size += util::pad_size(is_in_record->capacity()); + size += padded_sizeof(*is_in_record); + return size + padded_sizeof(*this); } diff --git a/src/Val.h b/src/Val.h index 0efe90f3db..b242b5ba3a 100644 --- a/src/Val.h +++ b/src/Val.h @@ -15,6 +15,7 @@ #include "zeek/Reporter.h" #include "zeek/net_util.h" #include "zeek/Dict.h" +#include "zeek/ZVal.h" // We have four different port name spaces: TCP, UDP, ICMP, and UNKNOWN. // We distinguish between them based on the bits specified in the *_PORT_MASK @@ -1023,12 +1024,23 @@ public: /** * Appends a value to the record's fields. The caller is responsible - * for ensuring that fields are appended in the correct orer and + * for ensuring that fields are appended in the correct order and * with the correct type. * @param v The value to append. */ void AppendField(ValPtr v) - { record_val->emplace_back(std::move(v)); } + { + if ( v ) + { + (*is_in_record)[record_val->size()] = true; + record_val->emplace_back(ZVal(v, v->GetType())); + } + else + { + (*is_in_record)[record_val->size()] = false; + record_val->emplace_back(ZVal()); + } + } /** * Ensures that the record has enough internal storage for the @@ -1036,22 +1048,44 @@ public: * @param n The number of fields. */ void Reserve(unsigned int n) - { record_val->reserve(n); } + { + record_val->reserve(n); + is_in_record->reserve(n); + + for ( auto i = is_in_record->size(); i < n; ++i ) + is_in_record->emplace_back(false); + } /** * Returns the number of fields in the record. * @return The number of fields in the record. */ - unsigned int NumFields() + unsigned int NumFields() const { return record_val->size(); } + /** + * Returns true if the given field is in the record, false if + * it's missing. + * @param field The field index to retrieve. + * @return Whether there's a value for the given field index. + */ + bool HasField(int field) const + { + return (*is_in_record)[field]; + } + /** * Returns the value of a given field index. * @param field The field index to retrieve. * @return The value at the given field index. */ - const ValPtr& GetField(int field) const - { return (*record_val)[field]; } + ValPtr GetField(int field) const + { + if ( ! HasField(field) ) + return nullptr; + + return (*record_val)[field].ToVal(rt->GetFieldType(field)); + } /** * Returns the value of a given field index as cast to type @c T. @@ -1078,7 +1112,7 @@ public: * @return The value of the given field. If no such field name exists, * a fatal error occurs. */ - const ValPtr& GetField(const char* field) const; + ValPtr GetField(const char* field) const; /** * Returns the value of a given field name as cast to type @c T. @@ -1119,7 +1153,7 @@ public: template auto GetFieldAs(int field) const -> std::invoke_result_t { - auto& field_ptr = GetField(field); + auto field_ptr = GetField(field); auto field_val_ptr = static_cast(field_ptr.get()); return field_val_ptr->Get(); } @@ -1127,7 +1161,7 @@ public: template auto GetFieldAs(const char* field) const -> std::invoke_result_t { - auto& field_ptr = GetField(field); + auto field_ptr = GetField(field); auto field_val_ptr = static_cast(field_ptr.get()); return field_val_ptr->Get(); } @@ -1182,7 +1216,16 @@ protected: static RecordTypeValMap parse_time_records; private: - std::vector* record_val; + // Keep this handy for quick access during low-level operations. + RecordTypePtr rt; + + // Low-level values of each of the fields. + std::vector* record_val; + + // Whether a given field exists - for optional fields, and because + // Zeek does not enforce that non-optional fields are actually + // present. + std::vector* is_in_record; }; class EnumVal final : public detail::IntValImplementation { diff --git a/src/ZVal.cc b/src/ZVal.cc index 937dc27aae..17f94ef825 100644 --- a/src/ZVal.cc +++ b/src/ZVal.cc @@ -1,6 +1,5 @@ // See the file "COPYING" in the main distribution directory for copyright. -#include "zeek/ZVal.h" #include "zeek/ZeekString.h" #include "zeek/File.h" #include "zeek/Func.h" diff --git a/src/file_analysis/analyzer/extract/Extract.cc b/src/file_analysis/analyzer/extract/Extract.cc index 001c49ba01..6520a6dc13 100644 --- a/src/file_analysis/analyzer/extract/Extract.cc +++ b/src/file_analysis/analyzer/extract/Extract.cc @@ -34,7 +34,7 @@ Extract::~Extract() util::safe_close(fd); } -static const ValPtr& get_extract_field_val(const RecordValPtr& args, +static ValPtr get_extract_field_val(const RecordValPtr& args, const char* name) { const auto& rval = args->GetField(name);