&expire_func(table, arg1, arg2, ...) + type checking.

This commit is contained in:
Zeke Medley 2019-07-18 12:11:27 -07:00
parent c22edc28a5
commit 0ca6b3e013
5 changed files with 110 additions and 25 deletions

View file

@ -436,22 +436,41 @@ void Attributes::CheckAttr(Attr* a)
} }
const Expr* expire_func = a->AttrExpr(); const Expr* expire_func = a->AttrExpr();
if ( expire_func->Type()->Tag() != TYPE_FUNC )
Error("&expire_func attribute is not a function");
const FuncType* e_ft = expire_func->Type()->AsFuncType(); const FuncType* e_ft = expire_func->Type()->AsFuncType();
if ( ((const BroType*) e_ft)->YieldType()->Tag() != TYPE_INTERVAL ) if ( e_ft->YieldType()->Tag() != TYPE_INTERVAL )
{ {
Error("&expire_func must yield a value of type interval"); Error("&expire_func must yield a value of type interval");
break; break;
} }
if ( e_ft->Args()->NumFields() != 2 ) const TableType* the_table = type->AsTableType();
if (the_table->IsUnspecifiedTable())
break;
const type_list* func_index_types = e_ft->ArgTypes()->Types();
// Keep backwards compatability wth idx: any idiom.
if ( func_index_types->length() == 2 )
{ {
Error("&expire_func function must take exactly two arguments"); if ((*func_index_types)[1]->Tag() == TYPE_ANY)
break; break;
} }
// ### Should type-check arguments to make sure first is const type_list* table_index_types = the_table->IndexTypes();
// table type and second is table index type.
type_list expected_args;
expected_args.push_back(type->AsTableType());
for (const auto& t : *table_index_types)
{
expected_args.push_back(t);
}
if ( ! e_ft->CheckArgs(&expected_args) )
Error("&expire_func argument type clash");
} }
break; break;

View file

@ -587,11 +587,17 @@ int FuncType::CheckArgs(const type_list* args, bool is_init) const
const type_list* my_args = arg_types->Types(); const type_list* my_args = arg_types->Types();
if ( my_args->length() != args->length() ) if ( my_args->length() != args->length() )
{
Warn(fmt("Wrong number of arguments for function. Expected %d, got %d.",
args->length(), my_args->length()));
return 0; return 0;
}
for ( int i = 0; i < my_args->length(); ++i ) for ( int i = 0; i < my_args->length(); ++i )
if ( ! same_type((*args)[i], (*my_args)[i], is_init) ) if ( ! same_type((*args)[i], (*my_args)[i], is_init) )
return 0; {
Warn(fmt("Type mismatch in function arguments. Expected %s, got %s.",
type_name((*args)[i]->Tag()), type_name((*my_args)[i]->Tag())));
}
return 1; return 1;
} }

View file

@ -2306,24 +2306,46 @@ double TableVal::CallExpireFunc(Val* idx)
return 0; return 0;
} }
const Func* f = vf->AsFunc();
val_list vl { Ref() };
// Flatten lists of a single element. const auto func_args = f->FType()->ArgTypes()->Types();
if ( idx->Type()->Tag() == TYPE_LIST &&
idx->AsListVal()->Length() == 1 ) // backwards compatability with idx: any idiom
bool any_idiom = func_args->length() == 2 && func_args->back()->Tag() == TYPE_ANY;
if ( idx->Type()->Tag() == TYPE_LIST )
{ {
Val* old = idx; if ( ! any_idiom )
idx = idx->AsListVal()->Index(0); {
idx->Ref(); const val_list* vl0 = idx->AsListVal()->Vals();
Unref(old); for ( const auto& v : *idx->AsListVal()->Vals() )
vl.append(v->Ref());
}
else
{
ListVal* idx_list = idx->AsListVal();
// Flatten if only one element
if (idx_list->Length() == 1)
{
idx = idx_list->Index(0);
}
vl.append(idx->Ref());
}
}
else
{
vl.append(idx->Ref());
} }
val_list vl{Ref(), idx}; Val* result = 0;
Val* vs = vf->AsFunc()->Call(&vl);
if ( vs ) result = f->Call(&vl);
if ( result )
{ {
secs = vs->AsInterval(); secs = result->AsInterval();
Unref(vs); Unref(result);
} }
Unref(vf); Unref(vf);

View file

@ -1,10 +1,20 @@
starting: ashish, 1 starting: ashish, 1
starting: ashish, 1
inside table_expire_func: ashish, 2 inside table_expire_func: ashish, 2
inside table_expire_func: [ashish, ashish], 2
inside table_expire_func: ashish, 3 inside table_expire_func: ashish, 3
inside table_expire_func: [ashish, ashish], 3
inside table_expire_func: ashish, 4 inside table_expire_func: ashish, 4
inside table_expire_func: [ashish, ashish], 4
inside table_expire_func: ashish, 5 inside table_expire_func: ashish, 5
inside table_expire_func: [ashish, ashish], 5
inside table_expire_func: ashish, 6 inside table_expire_func: ashish, 6
inside table_expire_func: [ashish, ashish], 6
inside table_expire_func: ashish, 7 inside table_expire_func: ashish, 7
inside table_expire_func: [ashish, ashish], 7
inside table_expire_func: ashish, 8 inside table_expire_func: ashish, 8
inside table_expire_func: [ashish, ashish], 8
inside table_expire_func: ashish, 9 inside table_expire_func: ashish, 9
inside table_expire_func: [ashish, ashish], 9
inside table_expire_func: ashish, 10 inside table_expire_func: ashish, 10
inside table_expire_func: [ashish, ashish], 10

View file

@ -9,20 +9,33 @@ redef table_expire_interval = .1 secs ;
export { export {
global table_expire_func: function(t: table[string] of count, global table_expire_func: function(t: table[string] of count,
s: string): interval; s: string): interval;
global table_expire_func2: function(t: table[string, string, string] of count,
s: string, s2: string, s3: string): interval;
global t: table[string] of count global t: table[string] of count
&write_expire=0 secs &write_expire=0 secs
&expire_func=table_expire_func; &expire_func=table_expire_func;
global tt: table[string, string, string] of count
&write_expire=0 secs
&expire_func=table_expire_func2;
} }
global die_count = 0;
event die() event die()
{ {
if (die_count < 1)
{
++die_count;
return;
}
terminate(); terminate();
} }
function table_expire_func(t: table[string] of count, s: string): interval function table_expire_func(t: table[string] of count, s: string): interval
{ {
t[s] += 1 ; t[s] = t[s] + 1 ;
print fmt("inside table_expire_func: %s, %s", s, t[s]); print fmt("inside table_expire_func: %s, %s", s, t[s]);
@ -33,9 +46,24 @@ function table_expire_func(t: table[string] of count, s: string): interval
return 0 secs; return 0 secs;
} }
function table_expire_func2 (tt: table[string, string, string] of count, s: string, s2: string, s3: string): interval
{
tt[s, s2, s3] += 1;
print fmt("inside table_expire_func: [%s, %s], %s", s, s2, tt[s, s2, s3]);
if ( tt[s, s2, s3] < 10 )
return .1 secs ;
schedule .1sec { die() };
return 0 secs;
}
event zeek_init() event zeek_init()
{ {
local s="ashish"; local s = "ashish";
t[s] = 1 ; t[s] = 1 ;
tt[s, s, s] = 1;
print fmt("starting: %s, %s", s, t[s]); print fmt("starting: %s, %s", s, t[s]);
print fmt("starting: %s, %s", s, tt[s, s, s]);
} }