Browse Source

Improve type inference rules for implicit selector expressions

New improvements:
`(.A == x)`
`a_union_containing_many_enums = .A;`
gingerBill 4 years ago
parent
commit
da7a9a3584
2 changed files with 139 additions and 89 deletions
  1. 110 68
      src/check_expr.cpp
  2. 29 21
      src/llvm_backend.cpp

+ 110 - 68
src/check_expr.cpp

@@ -2504,6 +2504,10 @@ bool check_binary_array_expr(CheckerContext *c, Token op, Operand *x, Operand *y
 	return false;
 }
 
+bool is_ise_expr(Ast *node) {
+	node = unparen_expr(node);
+	return node->kind == Ast_ImplicitSelectorExpr;
+}
 
 void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint, bool use_lhs_as_type_hint=false) {
 	GB_ASSERT(node->kind == Ast_BinaryExpr);
@@ -2521,8 +2525,14 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint
 	case Token_CmpEq:
 	case Token_NotEq: {
 		// NOTE(bill): Allow comparisons between types
-		check_expr_or_type(c, x, be->left, type_hint);
-		check_expr_or_type(c, y, be->right, x->type);
+		if (is_ise_expr(be->left)) {
+			// Evalute the right before the left for an '.X' expression
+			check_expr_or_type(c, y, be->right, type_hint);
+			check_expr_or_type(c, x, be->left, y->type);
+		} else {
+			check_expr_or_type(c, x, be->left, type_hint);
+			check_expr_or_type(c, y, be->right, x->type);
+		}
 		bool xt = x->mode == Addressing_Type;
 		bool yt = y->mode == Addressing_Type;
 		// If only one is a type, this is an error
@@ -2629,11 +2639,22 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint
 		return;
 
 	default:
-		check_expr_with_type_hint(c, x, be->left, type_hint);
-		if (use_lhs_as_type_hint) {
-			check_expr_with_type_hint(c, y, be->right, x->type);
+		if (is_ise_expr(be->left)) {
+			// Evalute the right before the left for an '.X' expression
+			check_expr_or_type(c, y, be->right, type_hint);
+
+			if (use_lhs_as_type_hint) { // RHS in this case
+				check_expr_or_type(c, x, be->left, y->type);
+			} else {
+				check_expr_with_type_hint(c, x, be->left, type_hint);
+			}
 		} else {
-			check_expr_with_type_hint(c, y, be->right, type_hint);
+			check_expr_with_type_hint(c, x, be->left, type_hint);
+			if (use_lhs_as_type_hint) {
+				check_expr_with_type_hint(c, y, be->right, x->type);
+			} else {
+				check_expr_with_type_hint(c, y, be->right, type_hint);
+			}
 		}
 		break;
 	}
@@ -5929,6 +5950,88 @@ bool check_is_operand_compound_lit_constant(CheckerContext *c, Operand *o) {
 }
 
 
+bool attempt_implicit_selector_expr(CheckerContext *c, Operand *o, AstImplicitSelectorExpr *ise, Type *th) {
+	if (is_type_enum(th)) {
+		Type *enum_type = base_type(th);
+		GB_ASSERT(enum_type->kind == Type_Enum);
+
+		String name = ise->selector->Ident.token.string;
+
+		Entity *e = scope_lookup_current(enum_type->Enum.scope, name);
+		if (e == nullptr) {
+			return false;
+		}
+		GB_ASSERT(are_types_identical(base_type(e->type), enum_type));
+		GB_ASSERT(e->kind == Entity_Constant);
+		o->value = e->Constant.value;
+		o->mode = Addressing_Constant;
+		o->type = e->type;
+		return true;
+	}
+	bool show_error = true;
+	if (is_type_union(th)) {
+		Type *union_type = base_type(th);
+		isize enum_count = 0;
+		Type *et = nullptr;
+
+		auto operands = array_make<Operand>(temporary_allocator(), 0, union_type->Union.variants.count);
+
+		for_array(i, union_type->Union.variants) {
+			Type *vt = union_type->Union.variants[i];
+
+			Operand x = {};
+			if (attempt_implicit_selector_expr(c, &x, ise, vt)) {
+				array_add(&operands, x);
+			}
+		}
+
+		if (operands.count == 1) {
+			*o = operands[0];
+			return true;
+		}
+	}
+	return false;
+}
+
+ExprKind check_implicit_selector_expr(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) {
+	ast_node(ise, ImplicitSelectorExpr, node);
+
+	o->type = t_invalid;
+	o->expr = node;
+	o->mode = Addressing_Invalid;
+
+	Type *th = type_hint;
+
+	if (th == nullptr) {
+		gbString str = expr_to_string(node);
+		error(node, "Cannot determine type for implicit selector expression '%s'", str);
+		gb_string_free(str);
+		return Expr_Expr;
+	}
+	o->type = th;
+	Type *enum_type = th;
+
+	bool ok = attempt_implicit_selector_expr(c, o, ise, th);
+	if (!ok) {
+		String name = ise->selector->Ident.token.string;
+
+		if (is_type_enum(th)) {
+			gbString typ = type_to_string(th);
+			error(node, "Undeclared name %.*s for type '%s'", LIT(name), typ);
+			gb_string_free(typ);
+		} else {
+			gbString typ = type_to_string(th);
+			gbString str = expr_to_string(node);
+			error(node, "Invalid type '%s' for implicit selector expression '%s'", typ, str);
+			gb_string_free(str);
+			gb_string_free(typ);
+		}
+	}
+
+	o->expr = node;
+	return Expr_Expr;
+}
+
 ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) {
 	u32 prev_state_flags = c->state_flags;
 	defer (c->state_flags = prev_state_flags);
@@ -7395,68 +7498,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 
 
 	case_ast_node(ise, ImplicitSelectorExpr, node);
-		o->type = t_invalid;
-		o->expr = node;
-		o->mode = Addressing_Invalid;
-
-		Type *th = type_hint;
-
-		if (th == nullptr) {
-			gbString str = expr_to_string(node);
-			error(node, "Cannot determine type for implicit selector expression '%s'", str);
-			gb_string_free(str);
-			return Expr_Expr;
-		}
-		o->type = th;
-		Type *enum_type = th;
-
-		if (!is_type_enum(th)) {
-			bool show_error = true;
-			if (is_type_union(th)) {
-				Type *union_type = base_type(th);
-				isize enum_count = 0;
-				Type *et = nullptr;
-				for_array(i, union_type->Union.variants) {
-					Type *vt = union_type->Union.variants[i];
-					if (is_type_enum(vt)) {
-						enum_count += 1;
-						et = vt;
-					}
-				}
-				if (enum_count == 1) {
-					show_error = false;
-					enum_type = et;
-				}
-			}
-
-			if (show_error) {
-				gbString typ = type_to_string(th);
-				gbString str = expr_to_string(node);
-				error(node, "Invalid type '%s' for implicit selector expression '%s'", typ, str);
-				gb_string_free(str);
-				gb_string_free(typ);
-				return Expr_Expr;
-			}
-		}
-		GB_ASSERT(ise->selector->kind == Ast_Ident);
-		String name = ise->selector->Ident.token.string;
-
-		enum_type = base_type(enum_type);
-		GB_ASSERT(enum_type->kind == Type_Enum);
-		Entity *e = scope_lookup_current(enum_type->Enum.scope, name);
-		if (e == nullptr) {
-			gbString typ = type_to_string(th);
-			error(node, "Undeclared name %.*s for type '%s'", LIT(name), typ);
-			gb_string_free(typ);
-			return Expr_Expr;
-		}
-		GB_ASSERT(are_types_identical(base_type(e->type), enum_type));
-		GB_ASSERT(e->kind == Entity_Constant);
-		o->value = e->Constant.value;
-		o->mode = Addressing_Constant;
-		o->type = e->type;
-
-		return Expr_Expr;
+		return check_implicit_selector_expr(c, o, node, type_hint);
 	case_end;
 
 	case_ast_node(ie, IndexExpr, node);

+ 29 - 21
src/llvm_backend.cpp

@@ -10667,6 +10667,30 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 	return {p->value, p->type};
 }
 
+lbValue lb_compare_records(lbProcedure *p, TokenKind op_kind, lbValue left, lbValue right, Type *type) {
+	GB_ASSERT((is_type_struct(type) || is_type_union(type)) && is_type_comparable(type));
+	lbValue left_ptr  = lb_address_from_load_or_generate_local(p, left);
+	lbValue right_ptr = lb_address_from_load_or_generate_local(p, right);
+	lbValue res = {};
+	if (is_type_simple_compare(type)) {
+		// TODO(bill): Test to see if this is actually faster!!!!
+		auto args = array_make<lbValue>(permanent_allocator(), 3);
+		args[0] = lb_emit_conv(p, left_ptr, t_rawptr);
+		args[1] = lb_emit_conv(p, right_ptr, t_rawptr);
+		args[2] = lb_const_int(p->module, t_int, type_size_of(type));
+		res = lb_emit_runtime_call(p, "memory_equal", args);
+	} else {
+		lbValue value = lb_get_equal_proc_for_type(p->module, type);
+		auto args = array_make<lbValue>(permanent_allocator(), 2);
+		args[0] = lb_emit_conv(p, left_ptr, t_rawptr);
+		args[1] = lb_emit_conv(p, right_ptr, t_rawptr);
+		res = lb_emit_call(p, value, args);
+	}
+	if (op_kind == Token_NotEq) {
+		res = lb_emit_unary_arith(p, Token_Not, res, res.type);
+	}
+	return res;
+}
 
 
 lbValue lb_emit_comp(lbProcedure *p, TokenKind op_kind, lbValue left, lbValue right) {
@@ -10797,27 +10821,11 @@ lbValue lb_emit_comp(lbProcedure *p, TokenKind op_kind, lbValue left, lbValue ri
 
 
 	if ((is_type_struct(a) || is_type_union(a)) && is_type_comparable(a)) {
-		lbValue left_ptr  = lb_address_from_load_or_generate_local(p, left);
-		lbValue right_ptr = lb_address_from_load_or_generate_local(p, right);
-		lbValue res = {};
-		if (is_type_simple_compare(a)) {
-			// TODO(bill): Test to see if this is actually faster!!!!
-			auto args = array_make<lbValue>(permanent_allocator(), 3);
-			args[0] = lb_emit_conv(p, left_ptr, t_rawptr);
-			args[1] = lb_emit_conv(p, right_ptr, t_rawptr);
-			args[2] = lb_const_int(p->module, t_int, type_size_of(a));
-			res = lb_emit_runtime_call(p, "memory_equal", args);
-		} else {
-			lbValue value = lb_get_equal_proc_for_type(p->module, a);
-			auto args = array_make<lbValue>(permanent_allocator(), 2);
-			args[0] = lb_emit_conv(p, left_ptr, t_rawptr);
-			args[1] = lb_emit_conv(p, right_ptr, t_rawptr);
-			res = lb_emit_call(p, value, args);
-		}
-		if (op_kind == Token_NotEq) {
-			res = lb_emit_unary_arith(p, Token_Not, res, res.type);
-		}
-		return res;
+		return lb_compare_records(p, op_kind, left, right, a);
+	}
+
+	if ((is_type_struct(b) || is_type_union(b)) && is_type_comparable(b)) {
+		return lb_compare_records(p, op_kind, left, right, b);
 	}
 
 	if (is_type_string(a)) {