diff --git a/src/script_opt/CPP/GenFunc.cc b/src/script_opt/CPP/GenFunc.cc index 7a36421766..598df3e0fb 100644 --- a/src/script_opt/CPP/GenFunc.cc +++ b/src/script_opt/CPP/GenFunc.cc @@ -149,6 +149,7 @@ void CPPCompile::DeclareLocals(const ProfileFunc* pf, const IDPList* lambda_ids) capture_names.insert(CaptureName(li)); const auto& ls = pf->Locals(); + int num_params = static_cast(pf->Params().size()); // Track whether we generated a declaration. This is just for // tidiness in the output. @@ -162,7 +163,7 @@ void CPPCompile::DeclareLocals(const ProfileFunc* pf, const IDPList* lambda_ids) // No need to declare these, they're passed in as parameters. ln = cn; - else if ( params.count(l) == 0 ) { // Not a parameter, so must be a local. + else if ( params.count(l) == 0 && l->Offset() >= num_params ) { // Not a parameter, so must be a local. Emit("%s %s;", FullTypeName(l->GetType()), ln); did_decl = true; } diff --git a/src/script_opt/ProfileFunc.cc b/src/script_opt/ProfileFunc.cc index 96b1fa8961..e844bfb566 100644 --- a/src/script_opt/ProfileFunc.cc +++ b/src/script_opt/ProfileFunc.cc @@ -41,7 +41,20 @@ ProfileFunc::ProfileFunc(const Func* func, const StmtPtr& body, bool _abs_rec_fi } } - Profile(profiled_func_t.get(), body); + TrackType(profiled_func_t); + body->Traverse(this); + + // Examine the locals and identify the parameters based on their offsets + // (being careful not to be fooled by captures that incidentally have low + // offsets). This approach allows us to accommodate function definitions + // that use different parameter names than appear in the original + // declaration. + num_params = profiled_func_t->Params()->NumFields(); + + for ( auto l : locals ) { + if ( captures.count(l) == 0 && l->Offset() < num_params ) + params.insert(l); + } } ProfileFunc::ProfileFunc(const Stmt* s, bool _abs_rec_fields) { @@ -68,7 +81,17 @@ ProfileFunc::ProfileFunc(const Expr* e, bool _abs_rec_fields) { captures_offsets[oid] = offset++; } - Profile(func->GetType()->AsFuncType(), func->Ingredients()->Body()); + auto ft = func->GetType()->AsFuncType(); + auto& body = func->Ingredients()->Body(); + + num_params = ft->Params()->NumFields(); + + auto& ov = profiled_scope->OrderedVars(); + for ( int i = 0; i < num_params; ++i ) + params.insert(ov[i].get()); + + TrackType(ft); + body->Traverse(this); } else @@ -77,19 +100,6 @@ ProfileFunc::ProfileFunc(const Expr* e, bool _abs_rec_fields) { e->Traverse(this); } -void ProfileFunc::Profile(const FuncType* ft, const StmtPtr& body) { - num_params = ft->Params()->NumFields(); - - assert(profiled_scope != nullptr); - - auto& ov = profiled_scope->OrderedVars(); - for ( int i = 0; i < num_params; ++i ) - params.insert(ov[i].get()); - - TrackType(ft); - body->Traverse(this); -} - TraversalCode ProfileFunc::PreStmt(const Stmt* s) { stmts.push_back(s);