diff --git a/src/script_opt/ZAM/Compile.h b/src/script_opt/ZAM/Compile.h index 5cddb776bc..0ac519bf4a 100644 --- a/src/script_opt/ZAM/Compile.h +++ b/src/script_opt/ZAM/Compile.h @@ -228,8 +228,8 @@ private: const ZAMStmt CompileIndex(const NameExpr* n1, int n2_slot, const TypePtr& n2_type, const ListExpr* l, bool in_when); - const ZAMStmt BuildLambda(const NameExpr* n, LambdaExpr* le); - const ZAMStmt BuildLambda(int n_slot, LambdaExpr* le); + const ZAMStmt BuildLambda(const NameExpr* n, ExprPtr le); + const ZAMStmt BuildLambda(int n_slot, ExprPtr le); // Second argument is which instruction slot holds the branch target. const ZAMStmt GenCond(const Expr* e, int& branch_v); diff --git a/src/script_opt/ZAM/Expr.cc b/src/script_opt/ZAM/Expr.cc index 3f2f1207b3..f0f174f478 100644 --- a/src/script_opt/ZAM/Expr.cc +++ b/src/script_opt/ZAM/Expr.cc @@ -189,7 +189,7 @@ const ZAMStmt ZAMCompiler::CompileAssignExpr(const AssignExpr* e) { } if ( rhs->Tag() == EXPR_LAMBDA ) - return BuildLambda(lhs, rhs->AsLambdaExpr()); + return BuildLambda(lhs, op2); if ( rhs->Tag() == EXPR_COND && r1->GetType()->Tag() == TYPE_VECTOR ) return Bool_Vec_CondVVVV(lhs, r1->AsNameExpr(), r2->AsNameExpr(), r3->AsNameExpr()); @@ -807,18 +807,18 @@ const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, int n2_slot, const T return AddInst(z); } -const ZAMStmt ZAMCompiler::BuildLambda(const NameExpr* n, LambdaExpr* le) { - return BuildLambda(Frame1Slot(n, OP1_WRITE), le); +const ZAMStmt ZAMCompiler::BuildLambda(const NameExpr* n, ExprPtr e) { + return BuildLambda(Frame1Slot(n, OP1_WRITE), std::move(e)); } -const ZAMStmt ZAMCompiler::BuildLambda(int n_slot, LambdaExpr* le) { - auto& captures = le->GetCaptures(); +const ZAMStmt ZAMCompiler::BuildLambda(int n_slot, ExprPtr e) { + auto lambda = cast_intrusive(e); + auto& captures = lambda->GetCaptures(); int ncaptures = captures ? captures->size() : 0; auto aux = new ZInstAux(ncaptures); - aux->primary_func = le->PrimaryFunc(); - aux->lambda_name = le->Name(); - aux->id_val = le->Ingredients()->GetID(); + aux->lambda = cast_intrusive(std::move(e)); + aux->id_val = lambda->Ingredients()->GetID(); for ( int i = 0; i < ncaptures; ++i ) { auto& id_i = (*captures)[i].Id(); @@ -829,7 +829,7 @@ const ZAMStmt ZAMCompiler::BuildLambda(int n_slot, LambdaExpr* le) { aux->Add(i, FrameSlot(id_i), id_i->GetType()); } - auto z = ZInstI(OP_LAMBDA_Vi, n_slot, le->PrimaryFunc()->FrameSize()); + auto z = ZInstI(OP_LAMBDA_Vi, n_slot, lambda->PrimaryFunc()->FrameSize()); z.op_type = OP_VV_I2; z.aux = aux; @@ -1168,7 +1168,7 @@ const ZAMStmt ZAMCompiler::ConstructTable(const NameExpr* n, const Expr* e) { if ( ! def_attr || def_attr->GetExpr()->Tag() != EXPR_LAMBDA ) return zstmt; - auto def_lambda = def_attr->GetExpr()->AsLambdaExpr(); + auto def_lambda = cast_intrusive(def_attr->GetExpr()); auto dl_t = def_lambda->GetType()->AsFuncType(); auto& captures = dl_t->GetCaptures(); diff --git a/src/script_opt/ZAM/OPs/macros.op b/src/script_opt/ZAM/OPs/macros.op index 7e398299f2..91801b0d07 100644 --- a/src/script_opt/ZAM/OPs/macros.op +++ b/src/script_opt/ZAM/OPs/macros.op @@ -19,8 +19,8 @@ macro Z_TYPE2 z.GetType2() macro Z_AUX z.aux macro Z_AUX_ID z.aux->id_val macro Z_AUX_FUNC z.aux->func -macro Z_AUX_PRIMARY_FUNC z.aux->primary_func -macro Z_AUX_LAMBDA_NAME z.aux->lambda_name +macro Z_AUX_PRIMARY_FUNC z.aux->lambda->PrimaryFunc() +macro Z_AUX_LAMBDA_NAME z.aux->lambda->Name() # Location in the original script. macro Z_LOC z.loc diff --git a/src/script_opt/ZAM/Stmt.cc b/src/script_opt/ZAM/Stmt.cc index 8789447c45..003ddda84d 100644 --- a/src/script_opt/ZAM/Stmt.cc +++ b/src/script_opt/ZAM/Stmt.cc @@ -1062,7 +1062,7 @@ const ZAMStmt ZAMCompiler::CompileWhen(const WhenStmt* ws) { auto timeout = wi->TimeoutExpr(); auto lambda = NewSlot(true); - (void)BuildLambda(lambda, wi->Lambda().get()); + (void)BuildLambda(lambda, wi->Lambda()); std::vector local_aggr_slots; for ( auto& l : wi->WhenExprLocals() ) diff --git a/src/script_opt/ZAM/ZInst.cc b/src/script_opt/ZAM/ZInst.cc index 3b0ca5edb1..a20ce554ae 100644 --- a/src/script_opt/ZAM/ZInst.cc +++ b/src/script_opt/ZAM/ZInst.cc @@ -326,8 +326,11 @@ TraversalCode ZInstAux::Traverse(TraversalCallback* cb) const { HANDLE_TC_STMT_PRE(tc); } - if ( func ) { - tc = func->Traverse(cb); + // Don't traverse the "func" field, as if it's a recursive function + // we can wind up right back here. + + if ( lambda ) { + tc = lambda->Traverse(cb); HANDLE_TC_STMT_PRE(tc); } diff --git a/src/script_opt/ZAM/ZInst.h b/src/script_opt/ZAM/ZInst.h index c83c75e518..72465109c2 100644 --- a/src/script_opt/ZAM/ZInst.h +++ b/src/script_opt/ZAM/ZInst.h @@ -133,9 +133,9 @@ protected: // for 't2' but keep them together for consistency. // Type, usually for interpreting the constant. - TypePtr t = nullptr; + TypePtr t; - TypePtr t2 = nullptr; // just a few ops need two types + TypePtr t2; // just a few ops need two types public: const TypePtr& GetType() const { return t; } @@ -494,11 +494,8 @@ public: AuxElem* elems = nullptr; bool elems_has_slots = true; - // Ingredients associated with lambdas ... - ScriptFuncPtr primary_func; - - // ... and its name. - std::string lambda_name; + // Info for constructing lambdas. + LambdaExprPtr lambda; // For "when" statements. std::shared_ptr wi; @@ -507,11 +504,11 @@ public: std::unique_ptr* cat_args = nullptr; // Used for accessing function names. - IDPtr id_val = nullptr; + IDPtr id_val; // Interpreter call expression associated with this instruction, // for error reporting and stack backtraces. - CallExprPtr call_expr = nullptr; + CallExprPtr call_expr; // Used for direct calls. Func* func = nullptr; @@ -526,7 +523,7 @@ public: EventHandler* event_handler = nullptr; // Used for things like constructors. - AttributesPtr attrs = nullptr; + AttributesPtr attrs; // Whether the instruction can lead to globals/captures changing. // Currently only needed by the optimizer, but convenient to diff --git a/testing/btest/Baseline.zam/opt.validate-ZAM/output b/testing/btest/Baseline.zam/opt.validate-ZAM/output index b2acd53e11..e1c24ab2aa 100644 --- a/testing/btest/Baseline.zam/opt.validate-ZAM/output +++ b/testing/btest/Baseline.zam/opt.validate-ZAM/output @@ -1,2 +1,2 @@ ### BTest baseline data generated by btest-diff. Do not edit. Use "btest -U/-u" to update. Requires BTest >= 0.63. -1226 valid, 1830 tested, 413 skipped +1226 valid, 1834 tested, 423 skipped