diff --git a/src/Attr.cc b/src/Attr.cc index 3001438bc3..c7d146a821 100644 --- a/src/Attr.cc +++ b/src/Attr.cc @@ -434,24 +434,43 @@ void Attributes::CheckAttr(Attr* a) Error("expiration only applicable to tables"); break; } - + const Expr* expire_func = a->AttrExpr(); + + if ( expire_func->Type()->Tag() != TYPE_FUNC ) + Error("&expire_func attribute is not a function"); + const FuncType* e_ft = expire_func->Type()->AsFuncType(); - if ( ((const BroType*) e_ft)->YieldType()->Tag() != TYPE_INTERVAL ) + if ( e_ft->YieldType()->Tag() != TYPE_INTERVAL ) { Error("&expire_func must yield a value of type interval"); break; } - if ( e_ft->Args()->NumFields() != 2 ) - { - Error("&expire_func function must take exactly two arguments"); + const TableType* the_table = type->AsTableType(); + + if (the_table->IsUnspecifiedTable()) break; + + const type_list* func_index_types = e_ft->ArgTypes()->Types(); + // Keep backwards compatability wth idx: any idiom. + if ( func_index_types->length() == 2 ) + { + if ((*func_index_types)[1]->Tag() == TYPE_ANY) + break; } - // ### Should type-check arguments to make sure first is - // table type and second is table index type. + const type_list* table_index_types = the_table->IndexTypes(); + + type_list expected_args; + expected_args.push_back(type->AsTableType()); + for (const auto& t : *table_index_types) + { + expected_args.push_back(t); + } + if ( ! e_ft->CheckArgs(&expected_args) ) + Error("&expire_func argument type clash"); } break; diff --git a/src/Type.cc b/src/Type.cc index 498763a0f8..af3d9be3c1 100644 --- a/src/Type.cc +++ b/src/Type.cc @@ -587,11 +587,17 @@ int FuncType::CheckArgs(const type_list* args, bool is_init) const const type_list* my_args = arg_types->Types(); if ( my_args->length() != args->length() ) + { + Warn(fmt("Wrong number of arguments for function. Expected %d, got %d.", + args->length(), my_args->length())); return 0; - + } for ( int i = 0; i < my_args->length(); ++i ) if ( ! same_type((*args)[i], (*my_args)[i], is_init) ) - return 0; + { + Warn(fmt("Type mismatch in function arguments. Expected %s, got %s.", + type_name((*args)[i]->Tag()), type_name((*my_args)[i]->Tag()))); + } return 1; } diff --git a/src/Val.cc b/src/Val.cc index 7372c12f66..87a7309dea 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -2306,24 +2306,46 @@ double TableVal::CallExpireFunc(Val* idx) return 0; } + const Func* f = vf->AsFunc(); + val_list vl { Ref() }; - // Flatten lists of a single element. - if ( idx->Type()->Tag() == TYPE_LIST && - idx->AsListVal()->Length() == 1 ) + const auto func_args = f->FType()->ArgTypes()->Types(); + + // backwards compatability with idx: any idiom + bool any_idiom = func_args->length() == 2 && func_args->back()->Tag() == TYPE_ANY; + + if ( idx->Type()->Tag() == TYPE_LIST ) { - Val* old = idx; - idx = idx->AsListVal()->Index(0); - idx->Ref(); - Unref(old); + if ( ! any_idiom ) + { + const val_list* vl0 = idx->AsListVal()->Vals(); + for ( const auto& v : *idx->AsListVal()->Vals() ) + vl.append(v->Ref()); + } + else + { + ListVal* idx_list = idx->AsListVal(); + // Flatten if only one element + if (idx_list->Length() == 1) + { + idx = idx_list->Index(0); + } + vl.append(idx->Ref()); + } + } + else + { + vl.append(idx->Ref()); } - val_list vl{Ref(), idx}; - Val* vs = vf->AsFunc()->Call(&vl); + Val* result = 0; - if ( vs ) + result = f->Call(&vl); + + if ( result ) { - secs = vs->AsInterval(); - Unref(vs); + secs = result->AsInterval(); + Unref(result); } Unref(vf); diff --git a/testing/btest/Baseline/language.expire_func_mod/out b/testing/btest/Baseline/language.expire_func_mod/out index 8790608ec1..a11dad96aa 100644 --- a/testing/btest/Baseline/language.expire_func_mod/out +++ b/testing/btest/Baseline/language.expire_func_mod/out @@ -1,10 +1,20 @@ starting: ashish, 1 +starting: ashish, 1 inside table_expire_func: ashish, 2 +inside table_expire_func: [ashish, ashish], 2 inside table_expire_func: ashish, 3 +inside table_expire_func: [ashish, ashish], 3 inside table_expire_func: ashish, 4 +inside table_expire_func: [ashish, ashish], 4 inside table_expire_func: ashish, 5 +inside table_expire_func: [ashish, ashish], 5 inside table_expire_func: ashish, 6 +inside table_expire_func: [ashish, ashish], 6 inside table_expire_func: ashish, 7 +inside table_expire_func: [ashish, ashish], 7 inside table_expire_func: ashish, 8 +inside table_expire_func: [ashish, ashish], 8 inside table_expire_func: ashish, 9 +inside table_expire_func: [ashish, ashish], 9 inside table_expire_func: ashish, 10 +inside table_expire_func: [ashish, ashish], 10 diff --git a/testing/btest/language/expire_func_mod.zeek b/testing/btest/language/expire_func_mod.zeek index 4e64edc968..bdb6d19cca 100644 --- a/testing/btest/language/expire_func_mod.zeek +++ b/testing/btest/language/expire_func_mod.zeek @@ -8,21 +8,34 @@ redef table_expire_interval = .1 secs ; export { global table_expire_func: function(t: table[string] of count, - s: string): interval; + s: string): interval; + global table_expire_func2: function(t: table[string, string, string] of count, + s: string, s2: string, s3: string): interval; global t: table[string] of count &write_expire=0 secs &expire_func=table_expire_func; + + global tt: table[string, string, string] of count + &write_expire=0 secs + &expire_func=table_expire_func2; } +global die_count = 0; + event die() { + if (die_count < 1) + { + ++die_count; + return; + } terminate(); } function table_expire_func(t: table[string] of count, s: string): interval { - t[s] += 1 ; + t[s] = t[s] + 1 ; print fmt("inside table_expire_func: %s, %s", s, t[s]); @@ -33,9 +46,24 @@ function table_expire_func(t: table[string] of count, s: string): interval return 0 secs; } +function table_expire_func2 (tt: table[string, string, string] of count, s: string, s2: string, s3: string): interval + { + tt[s, s2, s3] += 1; + + print fmt("inside table_expire_func: [%s, %s], %s", s, s2, tt[s, s2, s3]); + + if ( tt[s, s2, s3] < 10 ) + return .1 secs ; + + schedule .1sec { die() }; + return 0 secs; + } + event zeek_init() { - local s="ashish"; + local s = "ashish"; t[s] = 1 ; + tt[s, s, s] = 1; print fmt("starting: %s, %s", s, t[s]); - } + print fmt("starting: %s, %s", s, tt[s, s, s]); + } \ No newline at end of file