Browse Source

Add `#no_type_assert` and `#type_assert` to disable implicit type assertions with `x.(T)`

gingerBill 3 years ago
parent
commit
24e7356825
8 changed files with 142 additions and 46 deletions
  1. 8 0
      src/check_expr.cpp
  2. 8 0
      src/check_stmt.cpp
  3. 12 0
      src/checker.cpp
  4. 41 31
      src/llvm_backend_expr.cpp
  5. 7 0
      src/llvm_backend_stmt.cpp
  6. 26 15
      src/llvm_backend_utility.cpp
  7. 36 0
      src/parser.cpp
  8. 4 0
      src/parser.hpp

+ 8 - 0
src/check_expr.cpp

@@ -6883,6 +6883,14 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 			out &= ~StateFlag_no_bounds_check;
 		}
 
+		if (in & StateFlag_no_type_assert) {
+			out |= StateFlag_no_type_assert;
+			out &= ~StateFlag_type_assert;
+		} else if (in & StateFlag_type_assert) {
+			out |= StateFlag_type_assert;
+			out &= ~StateFlag_no_type_assert;
+		}
+
 		c->state_flags = out;
 	}
 

+ 8 - 0
src/check_stmt.cpp

@@ -490,6 +490,14 @@ void check_stmt(CheckerContext *ctx, Ast *node, u32 flags) {
 			out &= ~StateFlag_no_bounds_check;
 		}
 
+		if (in & StateFlag_no_type_assert) {
+			out |= StateFlag_no_type_assert;
+			out &= ~StateFlag_type_assert;
+		} else if (in & StateFlag_type_assert) {
+			out |= StateFlag_type_assert;
+			out &= ~StateFlag_no_type_assert;
+		}
+
 		ctx->state_flags = out;
 	}
 

+ 12 - 0
src/checker.cpp

@@ -4875,6 +4875,9 @@ bool check_proc_info(Checker *c, ProcInfo *pi, UntypedExprInfoMap *untyped, Proc
 	bool bounds_check    = (pi->tags & ProcTag_bounds_check)    != 0;
 	bool no_bounds_check = (pi->tags & ProcTag_no_bounds_check) != 0;
 
+	bool type_assert    = (pi->tags & ProcTag_type_assert)    != 0;
+	bool no_type_assert = (pi->tags & ProcTag_no_type_assert) != 0;
+
 	if (bounds_check) {
 		ctx.state_flags |= StateFlag_bounds_check;
 		ctx.state_flags &= ~StateFlag_no_bounds_check;
@@ -4882,6 +4885,15 @@ bool check_proc_info(Checker *c, ProcInfo *pi, UntypedExprInfoMap *untyped, Proc
 		ctx.state_flags |= StateFlag_no_bounds_check;
 		ctx.state_flags &= ~StateFlag_bounds_check;
 	}
+
+	if (type_assert) {
+		ctx.state_flags |= StateFlag_type_assert;
+		ctx.state_flags &= ~StateFlag_no_type_assert;
+	} else if (no_type_assert) {
+		ctx.state_flags |= StateFlag_no_type_assert;
+		ctx.state_flags &= ~StateFlag_type_assert;
+	}
+
 	if (pi->body != nullptr && e != nullptr) {
 		GB_ASSERT((e->flags & EntityFlag_ProcBodyChecked) == 0);
 	}

+ 41 - 31
src/llvm_backend_expr.cpp

@@ -2768,27 +2768,29 @@ lbValue lb_build_unary_and(lbProcedure *p, Ast *expr) {
 				Type *src_type = type_deref(v.type);
 				Type *dst_type = type;
 
-				lbValue src_tag = {};
-				lbValue dst_tag = {};
-				if (is_type_union_maybe_pointer(src_type)) {
-					src_tag = lb_emit_comp_against_nil(p, Token_NotEq, v);
-					dst_tag = lb_const_bool(p->module, t_bool, true);
-				} else {
-					src_tag = lb_emit_load(p, lb_emit_union_tag_ptr(p, v));
-					dst_tag = lb_const_union_tag(p->module, src_type, dst_type);
-				}
 
-				lbValue ok = lb_emit_comp(p, Token_CmpEq, src_tag, dst_tag);
-				auto args = array_make<lbValue>(permanent_allocator(), 6);
-				args[0] = ok;
+				if ((p->state_flags & StateFlag_no_type_assert) == 0) {
+					lbValue src_tag = {};
+					lbValue dst_tag = {};
+					if (is_type_union_maybe_pointer(src_type)) {
+						src_tag = lb_emit_comp_against_nil(p, Token_NotEq, v);
+						dst_tag = lb_const_bool(p->module, t_bool, true);
+					} else {
+						src_tag = lb_emit_load(p, lb_emit_union_tag_ptr(p, v));
+						dst_tag = lb_const_union_tag(p->module, src_type, dst_type);
+					}
+					lbValue ok = lb_emit_comp(p, Token_CmpEq, src_tag, dst_tag);
+					auto args = array_make<lbValue>(permanent_allocator(), 6);
+					args[0] = ok;
 
-				args[1] = lb_find_or_add_entity_string(p->module, get_file_path_string(pos.file_id));
-				args[2] = lb_const_int(p->module, t_i32, pos.line);
-				args[3] = lb_const_int(p->module, t_i32, pos.column);
+					args[1] = lb_find_or_add_entity_string(p->module, get_file_path_string(pos.file_id));
+					args[2] = lb_const_int(p->module, t_i32, pos.line);
+					args[3] = lb_const_int(p->module, t_i32, pos.column);
 
-				args[4] = lb_typeid(p->module, src_type);
-				args[5] = lb_typeid(p->module, dst_type);
-				lb_emit_runtime_call(p, "type_assertion_check", args);
+					args[4] = lb_typeid(p->module, src_type);
+					args[5] = lb_typeid(p->module, dst_type);
+					lb_emit_runtime_call(p, "type_assertion_check", args);
+				}
 
 				lbValue data_ptr = v;
 				return lb_emit_conv(p, data_ptr, tv.type);
@@ -2797,23 +2799,23 @@ lbValue lb_build_unary_and(lbProcedure *p, Ast *expr) {
 				if (is_type_pointer(v.type)) {
 					v = lb_emit_load(p, v);
 				}
-
 				lbValue data_ptr = lb_emit_struct_ev(p, v, 0);
-				lbValue any_id = lb_emit_struct_ev(p, v, 1);
-				lbValue id = lb_typeid(p->module, type);
+				if ((p->state_flags & StateFlag_no_type_assert) == 0) {
+					lbValue any_id = lb_emit_struct_ev(p, v, 1);
 
+					lbValue id = lb_typeid(p->module, type);
+					lbValue ok = lb_emit_comp(p, Token_CmpEq, any_id, id);
+					auto args = array_make<lbValue>(permanent_allocator(), 6);
+					args[0] = ok;
 
-				lbValue ok = lb_emit_comp(p, Token_CmpEq, any_id, id);
-				auto args = array_make<lbValue>(permanent_allocator(), 6);
-				args[0] = ok;
-
-				args[1] = lb_find_or_add_entity_string(p->module, get_file_path_string(pos.file_id));
-				args[2] = lb_const_int(p->module, t_i32, pos.line);
-				args[3] = lb_const_int(p->module, t_i32, pos.column);
+					args[1] = lb_find_or_add_entity_string(p->module, get_file_path_string(pos.file_id));
+					args[2] = lb_const_int(p->module, t_i32, pos.line);
+					args[3] = lb_const_int(p->module, t_i32, pos.column);
 
-				args[4] = any_id;
-				args[5] = id;
-				lb_emit_runtime_call(p, "type_assertion_check", args);
+					args[4] = any_id;
+					args[5] = id;
+					lb_emit_runtime_call(p, "type_assertion_check", args);
+				}
 
 				return lb_emit_conv(p, data_ptr, tv.type);
 			} else {
@@ -2843,6 +2845,14 @@ lbValue lb_build_expr(lbProcedure *p, Ast *expr) {
 			out &= ~StateFlag_bounds_check;
 		}
 
+		if (in & StateFlag_type_assert) {
+			out |= StateFlag_type_assert;
+			out &= ~StateFlag_no_type_assert;
+		} else if (in & StateFlag_no_type_assert) {
+			out |= StateFlag_no_type_assert;
+			out &= ~StateFlag_type_assert;
+		}
+
 		p->state_flags = out;
 	}
 

+ 7 - 0
src/llvm_backend_stmt.cpp

@@ -1991,6 +1991,13 @@ void lb_build_stmt(lbProcedure *p, Ast *node) {
 			out |= StateFlag_no_bounds_check;
 			out &= ~StateFlag_bounds_check;
 		}
+		if (in & StateFlag_no_type_assert) {
+			out |= StateFlag_no_type_assert;
+			out &= ~StateFlag_type_assert;
+		} else if (in & StateFlag_type_assert) {
+			out |= StateFlag_type_assert;
+			out &= ~StateFlag_no_type_assert;
+		}
 
 		p->state_flags = out;
 	}

+ 26 - 15
src/llvm_backend_utility.cpp

@@ -626,6 +626,12 @@ lbValue lb_emit_union_cast(lbProcedure *p, lbValue value, Type *type, TokenPos p
 
 	lbValue value_  = lb_address_from_load_or_generate_local(p, value);
 
+	if ((p->state_flags & StateFlag_no_type_assert) != 0 && !is_tuple) {
+		// just do a bit cast of the data at the front
+		lbValue ptr = lb_emit_conv(p, value_, alloc_type_pointer(type));
+		return lb_emit_load(p, ptr);
+	}
+
 	lbValue tag = {};
 	lbValue dst_tag = {};
 	lbValue cond = {};
@@ -666,23 +672,22 @@ lbValue lb_emit_union_cast(lbProcedure *p, lbValue value, Type *type, TokenPos p
 	lb_start_block(p, end_block);
 
 	if (!is_tuple) {
-		{
-			// NOTE(bill): Panic on invalid conversion
-			Type *dst_type = tuple->Tuple.variables[0]->type;
+		GB_ASSERT((p->state_flags & StateFlag_no_type_assert) == 0);
+		// NOTE(bill): Panic on invalid conversion
+		Type *dst_type = tuple->Tuple.variables[0]->type;
 
-			lbValue ok = lb_emit_load(p, lb_emit_struct_ep(p, v.addr, 1));
-			auto args = array_make<lbValue>(permanent_allocator(), 7);
-			args[0] = ok;
+		lbValue ok = lb_emit_load(p, lb_emit_struct_ep(p, v.addr, 1));
+		auto args = array_make<lbValue>(permanent_allocator(), 7);
+		args[0] = ok;
 
-			args[1] = lb_const_string(m, get_file_path_string(pos.file_id));
-			args[2] = lb_const_int(m, t_i32, pos.line);
-			args[3] = lb_const_int(m, t_i32, pos.column);
+		args[1] = lb_const_string(m, get_file_path_string(pos.file_id));
+		args[2] = lb_const_int(m, t_i32, pos.line);
+		args[3] = lb_const_int(m, t_i32, pos.column);
 
-			args[4] = lb_typeid(m, src_type);
-			args[5] = lb_typeid(m, dst_type);
-			args[6] = lb_emit_conv(p, value_, t_rawptr);
-			lb_emit_runtime_call(p, "type_assertion_check2", args);
-		}
+		args[4] = lb_typeid(m, src_type);
+		args[5] = lb_typeid(m, dst_type);
+		args[6] = lb_emit_conv(p, value_, t_rawptr);
+		lb_emit_runtime_call(p, "type_assertion_check2", args);
 
 		return lb_emit_load(p, lb_emit_struct_ep(p, v.addr, 0));
 	}
@@ -706,6 +711,13 @@ lbAddr lb_emit_any_cast_addr(lbProcedure *p, lbValue value, Type *type, TokenPos
 	}
 	Type *dst_type = tuple->Tuple.variables[0]->type;
 
+	if ((p->state_flags & StateFlag_no_type_assert) != 0 && !is_tuple) {
+		// just do a bit cast of the data at the front
+		lbValue ptr = lb_emit_struct_ev(p, value, 0);
+		ptr = lb_emit_conv(p, ptr, alloc_type_pointer(type));
+		return lb_addr(ptr);
+	}
+
 	lbAddr v = lb_add_local_generated(p, tuple, true);
 
 	lbValue dst_typeid = lb_typeid(m, dst_type);
@@ -731,7 +743,6 @@ lbAddr lb_emit_any_cast_addr(lbProcedure *p, lbValue value, Type *type, TokenPos
 
 	if (!is_tuple) {
 		// NOTE(bill): Panic on invalid conversion
-
 		lbValue ok = lb_emit_load(p, lb_emit_struct_ep(p, v.addr, 1));
 		auto args = array_make<lbValue>(permanent_allocator(), 7);
 		args[0] = ok;

+ 36 - 0
src/parser.cpp

@@ -1843,6 +1843,8 @@ void parse_proc_tags(AstFile *f, u64 *tags) {
 		ELSE_IF_ADD_TAG(require_results)
 		ELSE_IF_ADD_TAG(bounds_check)
 		ELSE_IF_ADD_TAG(no_bounds_check)
+		ELSE_IF_ADD_TAG(type_assert)
+		ELSE_IF_ADD_TAG(no_type_assert)
 		else {
 			syntax_error(tag_expr, "Unknown procedure type tag #%.*s", LIT(tag_name));
 		}
@@ -1853,6 +1855,10 @@ void parse_proc_tags(AstFile *f, u64 *tags) {
 	if ((*tags & ProcTag_bounds_check) && (*tags & ProcTag_no_bounds_check)) {
 		syntax_error(f->curr_token, "You cannot apply both #bounds_check and #no_bounds_check to a procedure");
 	}
+
+	if ((*tags & ProcTag_type_assert) && (*tags & ProcTag_no_type_assert)) {
+		syntax_error(f->curr_token, "You cannot apply both #type_assert and #no_type_assert to a procedure");
+	}
 }
 
 
@@ -2000,11 +2006,23 @@ Ast *parse_check_directive_for_statement(Ast *s, Token const &tag_token, u16 sta
 			syntax_error(tag_token, "#bounds_check and #no_bounds_check cannot be applied together");
 		}
 		break;
+	case StateFlag_type_assert:
+		if ((s->state_flags & StateFlag_no_type_assert) != 0) {
+			syntax_error(tag_token, "#type_assert and #no_type_assert cannot be applied together");
+		}
+		break;
+	case StateFlag_no_type_assert:
+		if ((s->state_flags & StateFlag_type_assert) != 0) {
+			syntax_error(tag_token, "#type_assert and #no_type_assert cannot be applied together");
+		}
+		break;
 	}
 
 	switch (state_flag) {
 	case StateFlag_bounds_check:
 	case StateFlag_no_bounds_check:
+	case StateFlag_type_assert:
+	case StateFlag_no_type_assert:
 		switch (s->kind) {
 		case Ast_BlockStmt:
 		case Ast_IfStmt:
@@ -2128,6 +2146,12 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 		} else if (name.string == "no_bounds_check") {
 			Ast *operand = parse_expr(f, lhs);
 			return parse_check_directive_for_statement(operand, name, StateFlag_no_bounds_check);
+		} else if (name.string == "type_assert") {
+			Ast *operand = parse_expr(f, lhs);
+			return parse_check_directive_for_statement(operand, name, StateFlag_type_assert);
+		} else if (name.string == "no_type_assert") {
+			Ast *operand = parse_expr(f, lhs);
+			return parse_check_directive_for_statement(operand, name, StateFlag_no_type_assert);
 		} else if (name.string == "relative") {
 			Ast *tag = ast_basic_directive(f, token, name);
 			tag = parse_call_expr(f, tag);
@@ -2224,6 +2248,12 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 			if (tags & ProcTag_bounds_check) {
 				body->state_flags |= StateFlag_bounds_check;
 			}
+			if (tags & ProcTag_no_type_assert) {
+				body->state_flags |= StateFlag_no_type_assert;
+			}
+			if (tags & ProcTag_type_assert) {
+				body->state_flags |= StateFlag_type_assert;
+			}
 
 			return ast_proc_lit(f, type, body, tags, where_token, where_clauses);
 		} else if (allow_token(f, Token_do)) {
@@ -4611,6 +4641,12 @@ Ast *parse_stmt(AstFile *f) {
 		} else if (tag == "no_bounds_check") {
 			s = parse_stmt(f);
 			return parse_check_directive_for_statement(s, name, StateFlag_no_bounds_check);
+		} else if (tag == "type_assert") {
+			s = parse_stmt(f);
+			return parse_check_directive_for_statement(s, name, StateFlag_type_assert);
+		} else if (tag == "no_type_assert") {
+			s = parse_stmt(f);
+			return parse_check_directive_for_statement(s, name, StateFlag_no_type_assert);
 		} else if (tag == "partial") {
 			s = parse_stmt(f);
 			switch (s->kind) {

+ 4 - 0
src/parser.hpp

@@ -226,6 +226,8 @@ enum ProcInlining {
 enum ProcTag {
 	ProcTag_bounds_check    = 1<<0,
 	ProcTag_no_bounds_check = 1<<1,
+	ProcTag_type_assert     = 1<<2,
+	ProcTag_no_type_assert  = 1<<3,
 
 	ProcTag_require_results = 1<<4,
 	ProcTag_optional_ok     = 1<<5,
@@ -258,6 +260,8 @@ ProcCallingConvention default_calling_convention(void) {
 enum StateFlag : u8 {
 	StateFlag_bounds_check    = 1<<0,
 	StateFlag_no_bounds_check = 1<<1,
+	StateFlag_type_assert     = 1<<2,
+	StateFlag_no_type_assert  = 1<<3,
 
 	StateFlag_BeenHandled = 1<<7,
 };