Browse Source

Allow unions which are comparable to also be valid map keys (i.e. hashable)

gingerBill 4 years ago
parent
commit
1a3784c4df
3 changed files with 43 additions and 5 deletions
  1. 6 0
      src/check_type.cpp
  2. 37 2
      src/llvm_backend.cpp
  3. 0 3
      src/types.cpp

+ 6 - 0
src/check_type.cpp

@@ -2109,6 +2109,12 @@ void add_map_key_type_dependencies(CheckerContext *ctx, Type *key) {
 				Entity *field = key->Struct.fields[i];
 				add_map_key_type_dependencies(ctx, field->type);
 			}
+		} else if (key->kind == Type_Union) {
+			add_package_dependency(ctx, "runtime", "default_hasher_n");
+			for_array(i, key->Union.variants) {
+				Type *v = key->Union.variants[i];
+				add_map_key_type_dependencies(ctx, v);
+			}
 		} else if (key->kind == Type_EnumeratedArray) {
 			add_package_dependency(ctx, "runtime", "default_hasher_n");
 			add_map_key_type_dependencies(ctx, key->EnumeratedArray.elem);

+ 37 - 2
src/llvm_backend.cpp

@@ -10315,7 +10315,7 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) {
 				lb_start_block(p, case_block);
 
 				Type *v = type->Union.variants[i];
-				lbValue tag = lb_const_union_tag(p->module, type, v);
+				lbValue case_tag = lb_const_union_tag(p->module, type, v);
 
 				Type *vp = alloc_type_pointer(v);
 
@@ -10327,7 +10327,7 @@ lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) {
 				LLVMBuildRet(p->builder, ok.value);
 
 
-				LLVMAddCase(v_switch, tag.value, case_block->block);
+				LLVMAddCase(v_switch, case_tag.value, case_block->block);
 			}
 
 			lb_start_block(p, block_false);
@@ -10403,6 +10403,9 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 	lbValue data = {x, t_rawptr};
 	lbValue seed = {y, t_uintptr};
 
+	LLVMAttributeRef nonnull_attr = lb_create_enum_attribute(m->ctx, "nonnull");
+	LLVMAddAttributeAtIndex(p->value, 1+0, nonnull_attr);
+
 	if (is_type_simple_compare(type)) {
 		lbValue res = lb_simple_compare_hash(p, type, data, seed);
 		LLVMBuildRet(p->builder, res.value);
@@ -10425,6 +10428,38 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 			seed = lb_emit_call(p, field_hasher, args);
 		}
 		LLVMBuildRet(p->builder, seed.value);
+	} else if (type->kind == Type_Union)  {
+		lbBlock *end_block = lb_create_block(p, "bend");
+
+		data = lb_emit_conv(p, data, pt);
+
+		lbValue tag_ptr = lb_emit_union_tag_ptr(p, data);
+		lbValue tag = lb_emit_load(p, tag_ptr);
+
+		LLVMValueRef v_switch = LLVMBuildSwitch(p->builder, tag.value, end_block->block, cast(unsigned)type->Union.variants.count);
+
+		auto args = array_make<lbValue>(permanent_allocator(), 2);
+		for_array(i, type->Union.variants) {
+			lbBlock *case_block = lb_create_block(p, "bcase");
+			lb_start_block(p, case_block);
+
+			Type *v = type->Union.variants[i];
+			Type *vp = alloc_type_pointer(v);
+			lbValue case_tag = lb_const_union_tag(p->module, type, v);
+
+			lbValue variant_hasher = lb_get_hasher_proc_for_type(m, v);
+
+			args[0] = data;
+			args[1] = seed;
+			lbValue res = lb_emit_call(p, variant_hasher, args);
+			LLVMBuildRet(p->builder, res.value);
+
+			LLVMAddCase(v_switch, case_tag.value, case_block->block);
+		}
+
+		lb_start_block(p, end_block);
+		LLVMBuildRet(p->builder, seed.value);
+
 	} else if (type->kind == Type_Array) {
 		lbAddr pres = lb_add_local_generated(p, t_uintptr, false);
 		lb_addr_store(p, pres, seed);

+ 0 - 3
src/types.cpp

@@ -1544,9 +1544,6 @@ bool is_type_valid_for_keys(Type *t) {
 	if (is_type_untyped(t)) {
 		return false;
 	}
-	if (t->kind == Type_Union) {
-		return false;
-	}
 	return is_type_comparable(t);
 }