diff --git a/src/Expr.cc b/src/Expr.cc index afd9b7a695..0690bf74b6 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -3018,7 +3018,10 @@ IntrusivePtr RecordConstructorExpr::InitVal(const BroType* t, IntrusivePtr< if ( v ) { RecordVal* rv = v->AsRecordVal(); - auto ar = rv->CoerceTo(t->AsRecordType(), aggr.release()); + auto bt = const_cast(t); + IntrusivePtr rt{NewRef{}, bt->AsRecordType()}; + auto aggr_rec = cast_intrusive(std::move(aggr)); + auto ar = rv->CoerceTo(std::move(rt), std::move(aggr_rec)); if ( ar ) return ar; @@ -3633,7 +3636,11 @@ IntrusivePtr RecordCoerceExpr::InitVal(const BroType* t, IntrusivePtr if ( auto v = Eval(nullptr) ) { RecordVal* rv = v->AsRecordVal(); - if ( auto ar = rv->CoerceTo(t->AsRecordType(), aggr.release()) ) + auto bt = const_cast(t); + IntrusivePtr rt{NewRef{}, bt->AsRecordType()}; + auto aggr_rec = cast_intrusive(std::move(aggr)); + + if ( auto ar = rv->CoerceTo(std::move(rt), std::move(aggr_rec)) ) return ar; } @@ -3673,19 +3680,19 @@ IntrusivePtr RecordCoerceExpr::Fold(Val* v) const } BroType* rhs_type = rhs->GetType().get(); - BroType* field_type = val_type->GetFieldType(i).get(); + const auto& field_type = val_type->GetFieldType(i); if ( rhs_type->Tag() == TYPE_RECORD && field_type->Tag() == TYPE_RECORD && - ! same_type(rhs_type, field_type) ) + ! same_type(rhs_type, field_type.get()) ) { - if ( auto new_val = rhs->AsRecordVal()->CoerceTo(field_type->AsRecordType()) ) + if ( auto new_val = rhs->AsRecordVal()->CoerceTo(cast_intrusive(field_type)) ) rhs = std::move(new_val); } else if ( BothArithmetic(rhs_type->Tag(), field_type->Tag()) && - ! same_type(rhs_type, field_type) ) + ! same_type(rhs_type, field_type.get()) ) { - if ( auto new_val = check_and_promote(rhs, field_type, false, op->GetLocationInfo()) ) + if ( auto new_val = check_and_promote(rhs, field_type.get(), false, op->GetLocationInfo()) ) rhs = std::move(new_val); else RuntimeError("Failed type conversion"); @@ -3706,7 +3713,7 @@ IntrusivePtr RecordCoerceExpr::Fold(Val* v) const ! same_type(def_type.get(), field_type.get()) ) { auto tmp = def_val->AsRecordVal()->CoerceTo( - field_type->AsRecordType()); + cast_intrusive(field_type)); if ( tmp ) def_val = std::move(tmp); diff --git a/src/Val.cc b/src/Val.cc index 62ffb512ce..2d642fc057 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -2694,7 +2694,7 @@ RecordVal::RecordVal(IntrusivePtr t, bool init_fields) : Val(std::mo def->GetType()->Tag() == TYPE_RECORD && ! same_type(def->GetType().get(), type.get()) ) { - auto tmp = def->AsRecordVal()->CoerceTo(type->AsRecordType()); + auto tmp = def->AsRecordVal()->CoerceTo(cast_intrusive(type)); if ( tmp ) def = std::move(tmp); @@ -2796,17 +2796,17 @@ IntrusivePtr RecordVal::Lookup(const char* field, bool with_default) const return with_default ? LookupWithDefault(idx) : IntrusivePtr{NewRef{}, Lookup(idx)}; } -IntrusivePtr RecordVal::CoerceTo(const RecordType* t, Val* aggr, bool allow_orphaning) const +IntrusivePtr RecordVal::CoerceTo(IntrusivePtr t, + IntrusivePtr aggr, + bool allow_orphaning) const { - if ( ! record_promotion_compatible(t->AsRecordType(), GetType()->AsRecordType()) ) + if ( ! record_promotion_compatible(t.get(), GetType()->AsRecordType()) ) return nullptr; if ( ! aggr ) - aggr = new RecordVal({NewRef{}, const_cast(t)}); + aggr = make_intrusive(std::move(t)); - RecordVal* ar = aggr->AsRecordVal(); RecordType* ar_t = aggr->GetType()->AsRecordType(); - const RecordType* rv_t = GetType()->AsRecordType(); int i; @@ -2840,15 +2840,15 @@ IntrusivePtr RecordVal::CoerceTo(const RecordType* t, Val* aggr, bool auto rhs = make_intrusive(IntrusivePtr{NewRef{}, v}); auto e = make_intrusive(std::move(rhs), cast_intrusive(ft)); - ar->Assign(t_i, e->Eval(nullptr)); + aggr->Assign(t_i, e->Eval(nullptr)); continue; } - ar->Assign(t_i, {NewRef{}, v}); + aggr->Assign(t_i, {NewRef{}, v}); } for ( i = 0; i < ar_t->NumFields(); ++i ) - if ( ! ar->Lookup(i) && + if ( ! aggr->Lookup(i) && ! ar_t->FieldDecl(i)->FindAttr(ATTR_OPTIONAL) ) { char buf[512]; @@ -2857,15 +2857,16 @@ IntrusivePtr RecordVal::CoerceTo(const RecordType* t, Val* aggr, bool Error(buf); } - return {AdoptRef{}, ar}; + return aggr; } -IntrusivePtr RecordVal::CoerceTo(RecordType* t, bool allow_orphaning) +IntrusivePtr RecordVal::CoerceTo(IntrusivePtr t, + bool allow_orphaning) { - if ( same_type(GetType().get(), t) ) + if ( same_type(GetType().get(), t.get()) ) return {NewRef{}, this}; - return CoerceTo(t, nullptr, allow_orphaning); + return CoerceTo(std::move(t), nullptr, allow_orphaning); } IntrusivePtr RecordVal::GetRecordFieldsVal() const diff --git a/src/Val.h b/src/Val.h index d272dd4787..37323c9a2f 100644 --- a/src/Val.h +++ b/src/Val.h @@ -983,8 +983,11 @@ public: // // The *allow_orphaning* parameter allows for a record to be demoted // down to a record type that contains less fields. - IntrusivePtr CoerceTo(const RecordType* other, Val* aggr, bool allow_orphaning = false) const; - IntrusivePtr CoerceTo(RecordType* other, bool allow_orphaning = false); + IntrusivePtr CoerceTo(IntrusivePtr other, + IntrusivePtr aggr, + bool allow_orphaning = false) const; + IntrusivePtr CoerceTo(IntrusivePtr other, + bool allow_orphaning = false); unsigned int MemoryAllocation() const override; void DescribeReST(ODesc* d) const override; diff --git a/src/file_analysis/analyzer/extract/functions.bif b/src/file_analysis/analyzer/extract/functions.bif index 6d0ac3435d..13cc904c4f 100644 --- a/src/file_analysis/analyzer/extract/functions.bif +++ b/src/file_analysis/analyzer/extract/functions.bif @@ -11,7 +11,7 @@ module FileExtract; function FileExtract::__set_limit%(file_id: string, args: any, n: count%): bool %{ using zeek::BifType::Record::Files::AnalyzerArgs; - auto rv = args->AsRecordVal()->CoerceTo(AnalyzerArgs.get()); + auto rv = args->AsRecordVal()->CoerceTo(AnalyzerArgs); bool result = file_mgr->SetExtractionLimit(file_id->CheckString(), rv.get(), n); return val_mgr->Bool(result); %} diff --git a/src/file_analysis/file_analysis.bif b/src/file_analysis/file_analysis.bif index e15163dd83..1f74668dd4 100644 --- a/src/file_analysis/file_analysis.bif +++ b/src/file_analysis/file_analysis.bif @@ -42,7 +42,7 @@ function Files::__set_reassembly_buffer%(file_id: string, max: count%): bool function Files::__add_analyzer%(file_id: string, tag: Files::Tag, args: any%): bool %{ using zeek::BifType::Record::Files::AnalyzerArgs; - auto rv = args->AsRecordVal()->CoerceTo(AnalyzerArgs.get()); + auto rv = args->AsRecordVal()->CoerceTo(AnalyzerArgs); bool result = file_mgr->AddAnalyzer(file_id->CheckString(), file_mgr->GetComponentTag(tag), rv.get()); return val_mgr->Bool(result); @@ -52,7 +52,7 @@ function Files::__add_analyzer%(file_id: string, tag: Files::Tag, args: any%): b function Files::__remove_analyzer%(file_id: string, tag: Files::Tag, args: any%): bool %{ using zeek::BifType::Record::Files::AnalyzerArgs; - auto rv = args->AsRecordVal()->CoerceTo(AnalyzerArgs.get()); + auto rv = args->AsRecordVal()->CoerceTo(AnalyzerArgs); bool result = file_mgr->RemoveAnalyzer(file_id->CheckString(), file_mgr->GetComponentTag(tag) , rv.get()); return val_mgr->Bool(result); diff --git a/src/logging/Manager.cc b/src/logging/Manager.cc index 04704fc8bf..a2773c7c22 100644 --- a/src/logging/Manager.cc +++ b/src/logging/Manager.cc @@ -701,7 +701,7 @@ bool Manager::Write(EnumVal* id, RecordVal* columns_arg) if ( ! stream->enabled ) return true; - auto columns = columns_arg->CoerceTo(stream->columns); + auto columns = columns_arg->CoerceTo({NewRef{}, stream->columns}); if ( ! columns ) { @@ -747,7 +747,7 @@ bool Manager::Write(EnumVal* id, RecordVal* columns_arg) const auto& rt = filter->path_func->GetType()->Params()->GetFieldType("rec"); if ( rt->Tag() == TYPE_RECORD ) - rec_arg = columns->CoerceTo(rt->AsRecordType(), true); + rec_arg = columns->CoerceTo(cast_intrusive(rt), true); else // Can be TYPE_ANY here. rec_arg = columns;