Browse Source

Disallow procedure calls with an associated deferred procedure to be used in logical binary expressions (short-circuiting)

gingerBill 5 years ago
parent
commit
7f5021c8e9
5 changed files with 104 additions and 33 deletions
  1. 56 6
      src/check_expr.cpp
  2. 1 1
      src/common.cpp
  3. 40 11
      src/ir.cpp
  4. 0 13
      src/parser.cpp
  5. 7 2
      src/parser.hpp

+ 56 - 6
src/check_expr.cpp

@@ -2364,6 +2364,11 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint
 
 	ast_node(be, BinaryExpr, node);
 
+	defer({
+		node->viral_state_flags |= be->left->viral_state_flags;
+		node->viral_state_flags |= be->right->viral_state_flags;
+	});
+
 	Token op = be->op;
 	switch (op.kind) {
 	case Token_CmpEq:
@@ -2605,6 +2610,18 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint
 				return;
 			}
 		}
+		break;
+
+	case Token_CmpAnd:
+	case Token_CmpOr:
+		if (be->left->viral_state_flags & ViralStateFlag_ContainsDeferredProcedure) {
+			error(be->left, "Procedure calls that have an associated deferred procedure are not allowed within logical binary expressions");
+		}
+		if (be->right->viral_state_flags & ViralStateFlag_ContainsDeferredProcedure) {
+			error(be->right, "Procedure calls that have an associated deferred procedure are not allowed within logical binary expressions");
+		}
+		break;
+
 	}
 
 	if (x->mode == Addressing_Constant &&
@@ -5440,6 +5457,8 @@ CALL_ARGUMENT_CHECKER(check_call_arguments_internal) {
 				Entity *e = sig_params[operand_index];
 				Type *t = e->type;
 				Operand o = operands[operand_index];
+				call->viral_state_flags |= o.expr->viral_state_flags;
+
 				if (e->kind == Entity_TypeName) {
 					// GB_ASSERT(!variadic);
 					if (o.mode == Addressing_Invalid) {
@@ -5596,6 +5615,12 @@ CALL_ARGUMENT_CHECKER(check_named_call_arguments) {
 	defer (gb_free(c->allocator, visited));
 	auto ordered_operands = array_make<Operand>(c->allocator, param_count);
 	defer (array_free(&ordered_operands));
+	defer ({
+		for_array(i, ordered_operands) {
+			Operand const &o = ordered_operands[i];
+			call->viral_state_flags |= o.expr->viral_state_flags;
+		}
+	});
 
 	for_array(i, ce->args) {
 		Ast *arg = ce->args[i];
@@ -6668,6 +6693,14 @@ ExprKind check_call_expr(CheckerContext *c, Operand *operand, Ast *call, Type *t
 		return builtin_procs[id].kind;
 	}
 
+	Entity *e = entity_of_ident(operand->expr);
+
+	if (e != nullptr && e->kind == Entity_Procedure) {
+		if (e->Procedure.deferred_procedure.entity != nullptr) {
+			call->viral_state_flags |= ViralStateFlag_ContainsDeferredProcedure;
+		}
+	}
+
 	Type *proc_type = base_type(operand->type);
 	if (operand->mode != Addressing_ProcGroup) {
 		bool valid_type = (proc_type != nullptr) && is_type_proc(proc_type);
@@ -7097,6 +7130,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 	case_ast_node(te, TernaryExpr, node);
 		Operand cond = {Addressing_Invalid};
 		check_expr(c, &cond, te->cond);
+		node->viral_state_flags |= te->cond->viral_state_flags;
+
 		if (cond.mode != Addressing_Invalid && !is_type_boolean(cond.type)) {
 			error(te->cond, "Non-boolean condition in if expression");
 		}
@@ -7104,9 +7139,11 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 		Operand x = {Addressing_Invalid};
 		Operand y = {Addressing_Invalid};
 		check_expr_or_type(c, &x, te->x, type_hint);
+		node->viral_state_flags |= te->x->viral_state_flags;
 
 		if (te->y != nullptr) {
 			check_expr_or_type(c, &y, te->y, type_hint);
+			node->viral_state_flags |= te->y->viral_state_flags;
 		} else {
 			error(node, "A ternary expression must have an else clause");
 			return kind;
@@ -7761,6 +7798,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 
 	case_ast_node(pe, ParenExpr, node);
 		kind = check_expr_base(c, o, pe->expr, type_hint);
+		node->viral_state_flags |= pe->expr->viral_state_flags;
 		o->expr = node;
 	case_end;
 
@@ -7769,18 +7807,15 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 		error(node, "Unknown tag expression, #%.*s", LIT(name));
 		if (te->expr) {
 			kind = check_expr_base(c, o, te->expr, type_hint);
+			node->viral_state_flags |= te->expr->viral_state_flags;
 		}
 		o->expr = node;
 	case_end;
 
-	case_ast_node(re, RunExpr, node);
-		// TODO(bill): Tag expressions
-		kind = check_expr_base(c, o, re->expr, type_hint);
-		o->expr = node;
-	case_end;
-
 	case_ast_node(ta, TypeAssertion, node);
 		check_expr(c, o, ta->expr);
+		node->viral_state_flags |= ta->expr->viral_state_flags;
+
 		if (o->mode == Addressing_Invalid) {
 			o->expr = node;
 			return kind;
@@ -7874,6 +7909,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 		}
 		Type *type = o->type;
 		check_expr_base(c, o, tc->expr, type);
+		node->viral_state_flags |= tc->expr->viral_state_flags;
+
 		if (o->mode != Addressing_Invalid) {
 			switch (tc->token.kind) {
 			case Token_transmute:
@@ -7893,6 +7930,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 
 	case_ast_node(ac, AutoCast, node);
 		check_expr_base(c, o, ac->expr, type_hint);
+		node->viral_state_flags |= ac->expr->viral_state_flags;
+
 		if (o->mode == Addressing_Invalid) {
 			o->expr = node;
 			return kind;
@@ -7906,6 +7945,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 
 	case_ast_node(ue, UnaryExpr, node);
 		check_expr_base(c, o, ue->expr, type_hint);
+		node->viral_state_flags |= ue->expr->viral_state_flags;
+
 		if (o->mode == Addressing_Invalid) {
 			o->expr = node;
 			return kind;
@@ -7930,6 +7971,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 
 	case_ast_node(se, SelectorExpr, node);
 		check_selector(c, o, node, type_hint);
+		node->viral_state_flags |= se->expr->viral_state_flags;
 	case_end;
 
 
@@ -8000,6 +8042,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 
 	case_ast_node(ie, IndexExpr, node);
 		check_expr(c, o, ie->expr);
+		node->viral_state_flags |= ie->expr->viral_state_flags;
 		if (o->mode == Addressing_Invalid) {
 			o->expr = node;
 			return kind;
@@ -8061,12 +8104,15 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 		i64 index = 0;
 		bool ok = check_index_value(c, false, ie->index, max_count, &index);
 
+		node->viral_state_flags |= ie->index->viral_state_flags;
 	case_end;
 
 
 
 	case_ast_node(se, SliceExpr, node);
 		check_expr(c, o, se->expr);
+		node->viral_state_flags |= se->expr->viral_state_flags;
+
 		if (o->mode == Addressing_Invalid) {
 			o->mode = Addressing_Invalid;
 			o->expr = node;
@@ -8148,6 +8194,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 				if (check_index_value(c, false, nodes[i], capacity, &j)) {
 					index = j;
 				}
+
+				node->viral_state_flags |= nodes[i]->viral_state_flags;
 			} else if (i == 0) {
 				index = 0;
 			}
@@ -8173,6 +8221,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 
 	case_ast_node(de, DerefExpr, node);
 		check_expr_or_type(c, o, de->expr);
+		node->viral_state_flags |= de->expr->viral_state_flags;
+
 		if (o->mode == Addressing_Invalid) {
 			o->mode = Addressing_Invalid;
 			o->expr = node;

+ 1 - 1
src/common.cpp

@@ -443,7 +443,7 @@ GB_ALLOCATOR_PROC(arena_allocator_proc) {
 		ptr = arena_alloc(arena, size, alignment);
 		break;
 	case gbAllocation_Free:
-		GB_PANIC("gbAllocation_Free not supported");
+		// GB_PANIC("gbAllocation_Free not supported");
 		break;
 	case gbAllocation_Resize:
 		GB_PANIC("gbAllocation_Resize: not supported");

+ 40 - 11
src/ir.cpp

@@ -91,6 +91,7 @@ enum irDeferExitKind {
 enum irDeferKind {
 	irDefer_Node,
 	irDefer_Instr,
+	irDefer_Proc,
 };
 
 struct irDefer {
@@ -102,6 +103,10 @@ struct irDefer {
 		Ast *stmt;
 		// NOTE(bill): 'instr' will be copied every time to create a new one
 		irValue *instr;
+		struct {
+			irValue *deferred;
+			Array<irValue *> result_as_args;
+		} proc;
 	};
 };
 
@@ -1554,6 +1559,17 @@ irDefer ir_add_defer_instr(irProcedure *proc, isize scope_index, irValue *instr)
 	return d;
 }
 
+irDefer ir_add_defer_proc(irProcedure *proc, isize scope_index, irValue *deferred, Array<irValue *> const &result_as_args) {
+	irDefer d = {irDefer_Proc};
+	d.scope_index = proc->scope_index;
+	d.block = proc->curr_block;
+	d.proc.deferred = deferred;
+	d.proc.result_as_args = result_as_args;
+	array_add(&proc->defer_stmts, d);
+	return d;
+}
+
+
 
 
 irValue *ir_add_module_constant(irModule *m, Type *type, ExactValue value) {
@@ -3202,8 +3218,7 @@ irValue *ir_emit_call(irProcedure *p, irValue *value, Array<irValue *> const &ar
 				break;
 			}
 
-			irValue *deferred_call = ir_de_emit(p, ir_emit_call(p, deferred, result_as_args));
-			ir_add_defer_instr(p, p->scope_index, deferred_call);
+			ir_add_defer_proc(p, p->scope_index, deferred, result_as_args);
 		}
 	}
 
@@ -3271,21 +3286,34 @@ void ir_open_scope(irProcedure *proc) {
 	proc->scope_index++;
 }
 
-void ir_close_scope(irProcedure *proc, irDeferExitKind kind, irBlock *block) {
+void ir_close_scope(irProcedure *proc, irDeferExitKind kind, irBlock *block, bool pop_stack=true) {
 	ir_emit_defer_stmts(proc, kind, block);
 	GB_ASSERT(proc->scope_index > 0);
 
 
 	// NOTE(bill): Remove `context`s made in that scope
+
+	isize end_idx = proc->context_stack.count-1;
+	isize pop_count = 0;
+
 	for (;;) {
-		irContextData *end = array_end_ptr(&proc->context_stack);
+		if (end_idx < 0) {
+			break;
+		}
+		irContextData *end = &proc->context_stack[end_idx];
 		if (end == nullptr) {
 			break;
 		}
 		if (end->scope_index != proc->scope_index) {
 			break;
 		}
-		array_pop(&proc->context_stack);
+		end_idx -= 1;
+		pop_count += 1;
+	}
+	if (pop_stack) {
+		for (isize i = 0; i < pop_count; i++) {
+			array_pop(&proc->context_stack);
+		}
 	}
 
 
@@ -6256,6 +6284,8 @@ void ir_build_defer_stmt(irProcedure *proc, irDefer d) {
 		// NOTE(bill): Need to make a new copy
 		irValue *instr = cast(irValue *)gb_alloc_copy(ir_allocator(), d.instr, gb_size_of(irValue));
 		ir_emit(proc, instr);
+	} else if (d.kind == irDefer_Proc) {
+		ir_emit_call(proc, d.proc.deferred, d.proc.result_as_args);
 	}
 }
 
@@ -6959,11 +6989,6 @@ irValue *ir_build_expr_internal(irProcedure *proc, Ast *expr) {
 		return nullptr;
 	case_end;
 
-	case_ast_node(re, RunExpr, expr);
-		// TODO(bill): Run Expression
-		return ir_build_expr(proc, re->expr);
-	case_end;
-
 	case_ast_node(de, DerefExpr, expr);
 		return ir_addr_load(proc, ir_build_addr(proc, expr));
 	case_end;
@@ -9499,7 +9524,6 @@ void ir_build_stmt_internal(irProcedure *proc, Ast *node) {
 	case_ast_node(is, IfStmt, node);
 		ir_emit_comment(proc, str_lit("IfStmt"));
 		ir_open_scope(proc); // Scope #1
-		defer (ir_close_scope(proc, irDeferExit_Default, nullptr));
 
 		if (is->init != nullptr) {
 			// TODO(bill): Should this have a separate block to begin with?
@@ -9525,7 +9549,9 @@ void ir_build_stmt_internal(irProcedure *proc, Ast *node) {
 			tl->is_block = true;
 		}
 
+		// ir_open_scope(proc);
 		ir_build_stmt(proc, is->body);
+		// ir_close_scope(proc, irDeferExit_Default, nullptr);
 
 		ir_emit_jump(proc, done);
 
@@ -9539,7 +9565,10 @@ void ir_build_stmt_internal(irProcedure *proc, Ast *node) {
 			ir_emit_jump(proc, done);
 		}
 
+
 		ir_start_block(proc, done);
+		ir_close_scope(proc, irDeferExit_Default, nullptr);
+
 	case_end;
 
 	case_ast_node(fs, ForStmt, node);

+ 0 - 13
src/parser.cpp

@@ -14,7 +14,6 @@ Token ast_token(Ast *node) {
 		return node->CompoundLit.open;
 
 	case Ast_TagExpr:       return node->TagExpr.token;
-	case Ast_RunExpr:       return node->RunExpr.token;
 	case Ast_BadExpr:       return node->BadExpr.begin;
 	case Ast_UnaryExpr:     return node->UnaryExpr.op;
 	case Ast_BinaryExpr:    return ast_token(node->BinaryExpr.left);
@@ -155,9 +154,6 @@ Ast *clone_ast(Ast *node) {
 	case Ast_TagExpr:
 		n->TagExpr.expr = clone_ast(n->TagExpr.expr);
 		break;
-	case Ast_RunExpr:
-		n->RunExpr.expr = clone_ast(n->RunExpr.expr);
-		break;
 	case Ast_UnaryExpr:
 		n->UnaryExpr.expr = clone_ast(n->UnaryExpr.expr);
 		break;
@@ -458,15 +454,6 @@ Ast *ast_tag_expr(AstFile *f, Token token, Token name, Ast *expr) {
 	return result;
 }
 
-Ast *ast_run_expr(AstFile *f, Token token, Token name, Ast *expr) {
-	Ast *result = alloc_ast_node(f, Ast_RunExpr);
-	result->RunExpr.token = token;
-	result->RunExpr.name = name;
-	result->RunExpr.expr = expr;
-	return result;
-}
-
-
 Ast *ast_tag_stmt(AstFile *f, Token token, Token name, Ast *stmt) {
 	Ast *result = alloc_ast_node(f, Ast_TagStmt);
 	result->TagStmt.token = token;

+ 7 - 2
src/parser.hpp

@@ -191,6 +191,11 @@ enum StmtStateFlag {
 	StmtStateFlag_no_deferred = 1<<5,
 };
 
+enum ViralStateFlag {
+	ViralStateFlag_ContainsDeferredProcedure = 1<<0,
+};
+
+
 enum FieldFlag {
 	FieldFlag_NONE      = 0,
 	FieldFlag_ellipsis  = 1<<0,
@@ -255,7 +260,6 @@ enum StmtAllowFlag {
 AST_KIND(_ExprBegin,  "",  bool) \
 	AST_KIND(BadExpr,      "bad expression",         struct { Token begin, end; }) \
 	AST_KIND(TagExpr,      "tag expression",         struct { Token token, name; Ast *expr; }) \
-	AST_KIND(RunExpr,      "run expression",         struct { Token token, name; Ast *expr; }) \
 	AST_KIND(UnaryExpr,    "unary expression",       struct { Token op; Ast *expr; }) \
 	AST_KIND(BinaryExpr,   "binary expression",      struct { Token op; Ast *left, *right; } ) \
 	AST_KIND(ParenExpr,    "parentheses expression", struct { Ast *expr; Token open, close; }) \
@@ -570,9 +574,10 @@ isize const ast_variant_sizes[] = {
 struct Ast {
 	AstKind      kind;
 	u32          stmt_state_flags;
+	u32          viral_state_flags;
+	bool         been_handled;
 	AstFile *    file;
 	Scope *      scope;
-	bool         been_handled;
 	TypeAndValue tav;
 
 	union {