Przeglądaj źródła

Allow `union`s to be comparable if all their variants are comparable

gingerBill 4 lat temu
rodzic
commit
518ecaf9c9
3 zmienionych plików z 85 dodań i 7 usunięć
  1. 3 0
      core/runtime/core.odin
  2. 65 6
      src/llvm_backend.cpp
  3. 17 1
      src/types.cpp

+ 3 - 0
core/runtime/core.odin

@@ -121,6 +121,9 @@ Type_Info_Union :: struct {
 	variants:     []^Type_Info,
 	tag_offset:   uintptr,
 	tag_type:     ^Type_Info,
+
+	equal: Equal_Proc, // set only when the struct has .Comparable set but does not have .Simple_Compare set
+
 	custom_align: bool,
 	no_nil:       bool,
 	maybe:        bool,

+ 65 - 6
src/llvm_backend.cpp

@@ -10259,7 +10259,7 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) {
 
 	lb_start_block(p, block_diff_ptr);
 
-	if (type->kind == Type_Struct)  {
+	if (type->kind == Type_Struct) {
 		type_set_offsets(type);
 
 		lbBlock *block_false = lb_create_block(p, "bfalse");
@@ -10285,6 +10285,56 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) {
 		lb_start_block(p, block_false);
 
 		LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 0, false));
+	} else if (type->kind == Type_Union) {
+		if (is_type_union_maybe_pointer(type)) {
+			Type *v = type->Union.variants[0];
+			Type *pv = alloc_type_pointer(v);
+
+			lbValue left = lb_emit_load(p, lb_emit_conv(p, lhs, pv));
+			lbValue right = lb_emit_load(p, lb_emit_conv(p, rhs, pv));
+
+			lbValue ok = lb_emit_comp(p, Token_CmpEq, left, right);
+			ok = lb_emit_conv(p, ok, t_bool);
+			LLVMBuildRet(p->builder, ok.value);
+		} else {
+			lbBlock *block_false = lb_create_block(p, "bfalse");
+			lbBlock *block_switch = lb_create_block(p, "bswitch");
+
+			lbValue left_tag  = lb_emit_load(p, lb_emit_union_tag_ptr(p, lhs));
+			lbValue right_tag = lb_emit_load(p, lb_emit_union_tag_ptr(p, rhs));
+
+			lbValue tag_eq = lb_emit_comp(p, Token_CmpEq, left_tag, right_tag);
+			lb_emit_if(p, tag_eq, block_switch, block_false);
+
+			lb_start_block(p, block_switch);
+			LLVMValueRef v_switch = LLVMBuildSwitch(p->builder, left_tag.value, block_false->block, cast(unsigned)type->Union.variants.count);
+
+
+			for_array(i, type->Union.variants) {
+				lbBlock *case_block = lb_create_block(p, "bcase");
+				lb_start_block(p, case_block);
+
+				Type *v = type->Union.variants[i];
+				lbValue tag = lb_const_union_tag(p->module, type, v);
+
+				Type *vp = alloc_type_pointer(v);
+
+				lbValue left  = lb_emit_load(p, lb_emit_conv(p, lhs, vp));
+				lbValue right = lb_emit_load(p, lb_emit_conv(p, rhs, vp));
+				lbValue ok = lb_emit_comp(p, Token_CmpEq, left, right);
+				ok = lb_emit_conv(p, ok, t_bool);
+
+				LLVMBuildRet(p->builder, ok.value);
+
+
+				LLVMAddCase(v_switch, tag.value, case_block->block);
+			}
+
+			lb_start_block(p, block_false);
+
+			LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 0, false));
+		}
+
 	} else {
 		lbValue left = lb_emit_load(p, lhs);
 		lbValue right = lb_emit_load(p, rhs);
@@ -10565,7 +10615,7 @@ lbValue lb_emit_comp(lbProcedure *p, TokenKind op_kind, lbValue left, lbValue ri
 	}
 
 
-	if (is_type_struct(a) && is_type_comparable(a)) {
+	if ((is_type_struct(a) || is_type_union(b)) && is_type_comparable(a)) {
 		lbValue left_ptr  = lb_address_from_load_or_generate_local(p, left);
 		lbValue right_ptr = lb_address_from_load_or_generate_local(p, right);
 		lbValue res = {};
@@ -13467,7 +13517,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 			tag = lb_const_ptr_cast(m, variant_ptr, t_type_info_union_ptr);
 
 			{
-				LLVMValueRef vals[6] = {};
+				LLVMValueRef vals[7] = {};
 
 				isize variant_count = gb_max(0, t->Union.variants.count);
 				lbValue memory_types = lb_type_info_member_types_offset(p, variant_count);
@@ -13496,10 +13546,19 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 					vals[2] = LLVMConstNull(lb_type(m, t_type_info_ptr));
 				}
 
-				vals[3] = lb_const_bool(m, t_bool, t->Union.custom_align != 0).value;
-				vals[4] = lb_const_bool(m, t_bool, t->Union.no_nil).value;
-				vals[5] = lb_const_bool(m, t_bool, t->Union.maybe).value;
+				if (is_type_comparable(t) && !is_type_simple_compare(t)) {
+					vals[3] = lb_get_equal_proc_for_type(m, t).value;
+				}
+
+				vals[4] = lb_const_bool(m, t_bool, t->Union.custom_align != 0).value;
+				vals[5] = lb_const_bool(m, t_bool, t->Union.no_nil).value;
+				vals[6] = lb_const_bool(m, t_bool, t->Union.maybe).value;
 
+				for (isize i = 0; i < gb_count_of(vals); i++) {
+					if (vals[i] == nullptr) {
+						vals[i]  = LLVMConstNull(lb_type(m, get_struct_field_type(tag.type, i)));
+					}
+				}
 
 				lbValue res = {};
 				res.type = type_deref(tag.type);

+ 17 - 1
src/types.cpp

@@ -1544,6 +1544,9 @@ bool is_type_valid_for_keys(Type *t) {
 	if (is_type_untyped(t)) {
 		return false;
 	}
+	if (t->kind == Type_Union) {
+		return false;
+	}
 	return is_type_comparable(t);
 }
 
@@ -1915,6 +1918,18 @@ bool is_type_comparable(Type *t) {
 			}
 		}
 		return true;
+
+	case Type_Union:
+		if (type_size_of(t) == 0) {
+			return false;
+		}
+		for_array(i, t->Union.variants) {
+			Type *v = t->Union.variants[i];
+			if (!is_type_comparable(v)) {
+				return false;
+			}
+		}
+		return true;
 	}
 	return false;
 }
@@ -1959,7 +1974,8 @@ bool is_type_simple_compare(Type *t) {
 				return false;
 			}
 		}
-		return true;
+		// make it dumb on purpose
+		return t->Union.variants.count == 1;
 
 	case Type_SimdVector:
 		return is_type_simple_compare(t->SimdVector.elem);