Browse Source

Correct implicit union cast

gingerBill 3 years ago
parent
commit
445ca70521
2 changed files with 20 additions and 0 deletions
  1. 11 0
      src/check_expr.cpp
  2. 9 0
      src/llvm_backend_expr.cpp

+ 11 - 0
src/check_expr.cpp

@@ -508,6 +508,10 @@ bool check_cast_internal(CheckerContext *c, Operand *x, Type *type);
 #define MAXIMUM_TYPE_DISTANCE 10
 #define MAXIMUM_TYPE_DISTANCE 10
 
 
 i64 check_distance_between_types(CheckerContext *c, Operand *operand, Type *type) {
 i64 check_distance_between_types(CheckerContext *c, Operand *operand, Type *type) {
+	if (c == nullptr) {
+		GB_ASSERT(operand->mode == Addressing_Value);
+		GB_ASSERT(is_type_typed(operand->type));
+	}
 	if (operand->mode == Addressing_Invalid ||
 	if (operand->mode == Addressing_Invalid ||
 	    type == t_invalid) {
 	    type == t_invalid) {
 		return -1;
 		return -1;
@@ -818,6 +822,13 @@ bool check_is_assignable_to(CheckerContext *c, Operand *operand, Type *type) {
 	return check_is_assignable_to_with_score(c, operand, type, &score);
 	return check_is_assignable_to_with_score(c, operand, type, &score);
 }
 }
 
 
+bool internal_check_is_assignable_to(Type *src, Type *dst) {
+	Operand x = {};
+	x.type = src;
+	x.mode = Addressing_Value;
+	return check_is_assignable_to(nullptr, &x, dst);
+}
+
 AstPackage *get_package_of_type(Type *type) {
 AstPackage *get_package_of_type(Type *type) {
 	for (;;) {
 	for (;;) {
 		if (type == nullptr) {
 		if (type == nullptr) {

+ 9 - 0
src/llvm_backend_expr.cpp

@@ -1834,6 +1834,15 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) {
 				return lb_addr_load(p, parent);
 				return lb_addr_load(p, parent);
 			}
 			}
 		}
 		}
+		if (dst->Union.variants.count == 1) {
+			Type *vt = dst->Union.variants[0];
+			if (internal_check_is_assignable_to(src, vt)) {
+				value = lb_emit_conv(p, value, vt);
+				lbAddr parent = lb_add_local_generated(p, t, true);
+				lb_emit_store_union_variant(p, parent.addr, value, vt);
+				return lb_addr_load(p, parent);
+			}
+		}
 	}
 	}
 
 
 	// NOTE(bill): This has to be done before 'Pointer <-> Pointer' as it's
 	// NOTE(bill): This has to be done before 'Pointer <-> Pointer' as it's