2
0
Эх сурвалжийг харах

Add intrinsics.type_equal_proc; Make `map` use an internal equal procedure to compare keys

gingerBill 4 жил өмнө
parent
commit
39bed567b3

+ 3 - 0
core/intrinsics/intrinsics.odin

@@ -1,4 +1,5 @@
 // This is purely for documentation
+//+ignore
 package intrinsics
 
 // Types
@@ -152,3 +153,5 @@ type_polymorphic_record_parameter_value :: proc($T: typeid, index: int) -> $V --
 
 
 type_field_index_of :: proc($T: typeid, $name: string) -> uintptr ---
+
+type_equal_proc :: proc($T: typeid) -> (equal: proc "contextless" (rawptr, rawptr) -> bool) ---

+ 4 - 2
core/runtime/core.odin

@@ -42,6 +42,8 @@ Platform_Endianness :: enum u8 {
 	Big      = 2,
 }
 
+Equal_Proc :: distinct proc "contextless" (rawptr, rawptr) -> bool;
+
 Type_Info_Struct_Soa_Kind :: enum u8 {
 	None    = 0,
 	Fixed   = 1,
@@ -89,7 +91,6 @@ Type_Info_Tuple :: struct { // Only used for procedures parameters and results
 	names:        []string,
 };
 
-Type_Struct_Equal_Proc :: distinct proc "contextless" (rawptr, rawptr) -> bool;
 Type_Info_Struct :: struct {
 	types:        []^Type_Info,
 	names:        []string,
@@ -100,7 +101,7 @@ Type_Info_Struct :: struct {
 	is_raw_union: bool,
 	custom_align: bool,
 
-	equal: Type_Struct_Equal_Proc, // set only when the struct has .Comparable set but does not have .Simple_Compare set
+	equal: Equal_Proc, // set only when the struct has .Comparable set but does not have .Simple_Compare set
 
 	// These are only set iff this structure is an SOA structure
 	soa_kind:      Type_Info_Struct_Soa_Kind,
@@ -351,6 +352,7 @@ Raw_Map :: struct {
 	entries: Raw_Dynamic_Array,
 }
 
+
 /////////////////////////////
 // Init Startup Procedures //
 /////////////////////////////

+ 5 - 10
core/runtime/dynamic_map_internal.odin

@@ -63,13 +63,13 @@ Map_Entry_Header :: struct {
 
 Map_Header :: struct {
 	m:             ^Raw_Map,
-	is_key_string: bool,
+	equal:         Equal_Proc,
 
 	entry_size:    int,
 	entry_align:   int,
 
-	key_offset:  uintptr,
-	key_size:    int,
+	key_offset:    uintptr,
+	key_size:      int,
 
 	value_offset:  uintptr,
 	value_size:    int,
@@ -115,7 +115,7 @@ __get_map_header :: proc "contextless" (m: ^$T/map[$K]$V) -> Map_Header {
 		value: V,
 	};
 
-	header.is_key_string = intrinsics.type_is_string(K);
+	header.equal = intrinsics.type_equal_proc(K);
 
 	header.entry_size    = int(size_of(Entry));
 	header.entry_align   = int(align_of(Entry));
@@ -275,12 +275,7 @@ __dynamic_map_hash_equal :: proc(h: Map_Header, a, b: Map_Hash) -> bool {
 		if a.key_ptr == b.key_ptr {
 			return true;
 		}
-		assert(a.key_ptr != nil && b.key_ptr != nil);
-
-		if h.is_key_string {
-			return (^string)(a.key_ptr)^ == (^string)(b.key_ptr)^;
-		}
-		return memory_equal(a.key_ptr, b.key_ptr, h.key_size);
+		return h.equal(a.key_ptr, b.key_ptr);
 	}
 	return false;
 }

+ 21 - 0
src/check_expr.cpp

@@ -6060,6 +6060,27 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 			break;
 		}
 		break;
+
+	case BuiltinProc_type_equal_proc:
+		{
+			Operand op = {};
+			Type *bt = check_type(c, ce->args[0]);
+			Type *type = base_type(bt);
+			if (type == nullptr || type == t_invalid) {
+				error(ce->args[0], "Expected a type for '%.*s'", LIT(builtin_name));
+				return false;
+			}
+			if (!is_type_comparable(type)) {
+				gbString t = type_to_string(type);
+				error(ce->args[0], "Expected a comparable type for '%.*s', got %s", LIT(builtin_name), t);
+				gb_string_free(t);
+				return false;
+			}
+
+			operand->mode = Addressing_Value;
+			operand->type = t_equal_proc;
+			break;
+		}
 	}
 
 	return true;

+ 8 - 0
src/checker.cpp

@@ -726,6 +726,14 @@ void init_universal(void) {
 	}
 	add_global_type_entity(str_lit("byte"), &basic_types[Basic_u8]);
 
+	{
+		void set_procedure_abi_types(Type *type);
+
+		Type *args[2] = {t_rawptr, t_rawptr};
+		t_equal_proc = alloc_type_proc_from_types(args, 2, t_bool, false, ProcCC_Contextless);
+		set_procedure_abi_types(t_equal_proc);
+	}
+
 // Constants
 	add_global_constant(str_lit("true"),  t_untyped_bool, exact_value_bool(true));
 	add_global_constant(str_lit("false"), t_untyped_bool, exact_value_bool(false));

+ 4 - 0
src/checker_builtin_procs.hpp

@@ -183,6 +183,8 @@ BuiltinProc__type_simple_boolean_end,
 
 	BuiltinProc_type_field_index_of,
 
+	BuiltinProc_type_equal_proc,
+
 BuiltinProc__type_end,
 
 
@@ -367,5 +369,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 
 	{STR_LIT("type_field_index_of"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 
+	{STR_LIT("type_equal_proc"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+
 	{STR_LIT(""), 0, false, Expr_Stmt, BuiltinProcPkg_intrinsics},
 };

+ 50 - 34
src/ir.cpp

@@ -529,6 +529,8 @@ Type *ir_type(irValue *value);
 irValue *ir_gen_anonymous_proc_lit(irModule *m, String prefix_name, Ast *expr, irProcedure *proc = nullptr);
 void ir_begin_procedure_body(irProcedure *proc);
 void ir_end_procedure_body(irProcedure *proc);
+irValue *ir_get_equal_proc_for_type(irModule *m, Type *type);
+
 
 irAddr ir_addr(irValue *addr) {
 	irAddr v = {irAddr_Default, addr};
@@ -3591,7 +3593,6 @@ irValue *ir_gen_map_header(irProcedure *proc, irValue *map_val_ptr, Type *map_ty
 	irValue *m = ir_emit_conv(proc, map_val_ptr, type_deref(ir_type(gep0)));
 	ir_emit_store(proc, gep0, m);
 
-	ir_emit_store(proc, ir_emit_struct_ep(proc, h, 1), ir_const_bool(is_type_string(key_type)));
 
 	i64 entry_size   = type_size_of  (map_type->Map.entry_type);
 	i64 entry_align  = type_align_of (map_type->Map.entry_type);
@@ -3600,6 +3601,7 @@ irValue *ir_gen_map_header(irProcedure *proc, irValue *map_val_ptr, Type *map_ty
 	i64 value_offset = type_offset_of(map_type->Map.entry_type, 3);
 	i64 value_size   = type_size_of  (map_type->Map.value);
 
+	ir_emit_store(proc, ir_emit_struct_ep(proc, h, 1), ir_get_equal_proc_for_type(proc->module, key_type));
 	ir_emit_store(proc, ir_emit_struct_ep(proc, h, 2), ir_const_int(entry_size));
 	ir_emit_store(proc, ir_emit_struct_ep(proc, h, 3), ir_const_int(entry_align));
 	ir_emit_store(proc, ir_emit_struct_ep(proc, h, 4), ir_const_uintptr(key_offset));
@@ -4867,11 +4869,9 @@ irValue *ir_emit_comp_against_nil(irProcedure *proc, TokenKind op_kind, irValue
 	return nullptr;
 }
 
-irValue *ir_get_compare_proc_for_type(irModule *m, Type *type) {
+irValue *ir_get_equal_proc_for_type(irModule *m, Type *type) {
 	Type *original_type = type;
 	type = base_type(type);
-	GB_ASSERT(type->kind == Type_Struct);
-	type_set_offsets(type);
 	Type *pt = alloc_type_pointer(type);
 
 	auto key = hash_type(type);
@@ -4879,12 +4879,6 @@ irValue *ir_get_compare_proc_for_type(irModule *m, Type *type) {
 	if (found) {
 		return *found;
 	}
-	static Type *proc_type = nullptr;
-	if (proc_type == nullptr) {
-		Type *args[2] = {t_rawptr, t_rawptr};
-		proc_type = alloc_type_proc_from_types(args, 2, t_bool, false, ProcCC_Contextless);
-		set_procedure_abi_types(proc_type);
-	}
 
 	static u32 proc_index = 0;
 
@@ -4895,9 +4889,9 @@ irValue *ir_get_compare_proc_for_type(irModule *m, Type *type) {
 
 
 	Ast *body = alloc_ast_node(nullptr, Ast_Invalid);
-	Entity *e = alloc_entity_procedure(nullptr, make_token_ident(proc_name), proc_type, 0);
+	Entity *e = alloc_entity_procedure(nullptr, make_token_ident(proc_name), t_equal_proc, 0);
 	e->Procedure.link_name = proc_name;
-	irValue *p = ir_value_procedure(m, e, proc_type, nullptr, body, proc_name);
+	irValue *p = ir_value_procedure(m, e, t_equal_proc, nullptr, body, proc_name);
 	map_set(&m->values, hash_entity(e), p);
 	string_map_set(&m->members, proc_name, p);
 
@@ -4908,38 +4902,58 @@ irValue *ir_get_compare_proc_for_type(irModule *m, Type *type) {
 	// ir_start_block(proc, proc->decl_block);
 	GB_ASSERT(proc->curr_block != nullptr);
 
-	irBlock *done = ir_new_block(proc, nullptr, "done"); // NOTE(bill): Append later
-
 	irValue *x = proc->params[0];
 	irValue *y = proc->params[1];
 	irValue *lhs = ir_emit_conv(proc, x, pt);
 	irValue *rhs = ir_emit_conv(proc, y, pt);
 
-	irBlock *block_false = ir_new_block(proc, nullptr, "bfalse");
+	irBlock *block_same_ptr = ir_new_block(proc, nullptr, "same_ptr");
+	irBlock *block_diff_ptr = ir_new_block(proc, nullptr, "diff_ptr");
 
-	for_array(i, type->Struct.fields) {
-		irBlock *next_block = ir_new_block(proc, nullptr, "btrue");
+	irValue *same_ptr = ir_emit_comp(proc, Token_CmpEq, lhs, rhs);
+	ir_emit_if(proc, same_ptr, block_same_ptr, block_diff_ptr);
+	ir_start_block(proc, block_same_ptr);
+	ir_emit(proc, ir_instr_return(proc, ir_const_bool(true)));
 
-		irValue *pleft  = ir_emit_struct_ep(proc, lhs, cast(i32)i);
-		irValue *pright = ir_emit_struct_ep(proc, rhs, cast(i32)i);
-		irValue *left = ir_emit_load(proc, pleft);
-		irValue *right = ir_emit_load(proc, pright);
-		irValue *ok = ir_emit_comp(proc, Token_CmpEq, left, right);
+	ir_start_block(proc, block_diff_ptr);
 
-		ir_emit_if(proc, ok, next_block, block_false);
+	if (type->kind == Type_Struct) {
+		type_set_offsets(type);
 
-		ir_emit_jump(proc, next_block);
-		ir_start_block(proc, next_block);
-	}
+		irBlock *done = ir_new_block(proc, nullptr, "done"); // NOTE(bill): Append later
 
-	ir_emit_jump(proc, done);
-	ir_start_block(proc, block_false);
+		irBlock *block_false = ir_new_block(proc, nullptr, "bfalse");
 
-	ir_emit(proc, ir_instr_return(proc, ir_const_bool(false)));
+		for_array(i, type->Struct.fields) {
+			irBlock *next_block = ir_new_block(proc, nullptr, "btrue");
 
-	ir_emit_jump(proc, done);
-	ir_start_block(proc, done);
-	ir_emit(proc, ir_instr_return(proc, ir_const_bool(true)));
+			irValue *pleft  = ir_emit_struct_ep(proc, lhs, cast(i32)i);
+			irValue *pright = ir_emit_struct_ep(proc, rhs, cast(i32)i);
+			irValue *left = ir_emit_load(proc, pleft);
+			irValue *right = ir_emit_load(proc, pright);
+			irValue *ok = ir_emit_comp(proc, Token_CmpEq, left, right);
+
+			ir_emit_if(proc, ok, next_block, block_false);
+
+			ir_emit_jump(proc, next_block);
+			ir_start_block(proc, next_block);
+		}
+
+		ir_emit_jump(proc, done);
+		ir_start_block(proc, block_false);
+
+		ir_emit(proc, ir_instr_return(proc, ir_const_bool(false)));
+
+		ir_emit_jump(proc, done);
+		ir_start_block(proc, done);
+		ir_emit(proc, ir_instr_return(proc, ir_const_bool(true)));
+	} else {
+		irValue *left = ir_emit_load(proc, lhs);
+		irValue *right = ir_emit_load(proc, rhs);
+		irValue *ok = ir_emit_comp(proc, Token_CmpEq, left, right);
+		ok = ir_emit_conv(proc, ok, t_bool);
+		ir_emit(proc, ir_instr_return(proc, ok));
+	}
 
 	ir_end_procedure_body(proc);
 
@@ -5093,7 +5107,7 @@ irValue *ir_emit_comp(irProcedure *proc, TokenKind op_kind, irValue *left, irVal
 			args[2] = ir_const_int(type_size_of(a));
 			res = ir_emit_runtime_call(proc, "memory_equal", args);
 		} else {
-			irValue *value = ir_get_compare_proc_for_type(proc->module, a);
+			irValue *value = ir_get_equal_proc_for_type(proc->module, a);
 			auto args = array_make<irValue *>(permanent_allocator(), 2);
 			args[0] = ir_emit_conv(proc, left_ptr, t_rawptr);
 			args[1] = ir_emit_conv(proc, right_ptr, t_rawptr);
@@ -7572,6 +7586,8 @@ irValue *ir_build_builtin_proc(irProcedure *proc, Ast *expr, TypeAndValue tv, Bu
 		return ir_emit(proc, ir_instr_atomic_cxchg(proc, type, address, old_value, new_value, id));
 	}
 
+	case BuiltinProc_type_equal_proc:
+		return ir_get_equal_proc_for_type(proc->module, ce->args[0]->tav.type);
 
 	}
 
@@ -12353,7 +12369,7 @@ void ir_setup_type_info_data(irProcedure *proc) { // NOTE(bill): Setup type_info
 				ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 7), is_custom_align);
 
 				if (is_type_comparable(t) && !is_type_simple_compare(t)) {
-					ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 8), ir_get_compare_proc_for_type(proc->module, t));
+					ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 8), ir_get_equal_proc_for_type(proc->module, t));
 				}
 
 

+ 48 - 31
src/llvm_backend.cpp

@@ -8555,6 +8555,10 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
 		res.type = fix_typed;
 		return res;
 	}
+
+
+	case BuiltinProc_type_equal_proc:
+		return lb_get_equal_proc_for_type(p->module, ce->args[0]->tav.type);
 	}
 
 	GB_PANIC("Unhandled built-in procedure %.*s", LIT(builtin_procs[id].name));
@@ -9156,11 +9160,11 @@ lbValue lb_emit_comp_against_nil(lbProcedure *p, TokenKind op_kind, lbValue x) {
 	return {};
 }
 
-lbValue lb_get_compare_proc_for_type(lbModule *m, Type *type) {
+lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) {
 	Type *original_type = type;
 	type = base_type(type);
-	GB_ASSERT(type->kind == Type_Struct);
-	type_set_offsets(type);
+	GB_ASSERT(is_type_comparable(type));
+
 	Type *pt = alloc_type_pointer(type);
 	LLVMTypeRef ptr_type = lb_type(m, pt);
 
@@ -9170,13 +9174,6 @@ lbValue lb_get_compare_proc_for_type(lbModule *m, Type *type) {
 	if (found) {
 		compare_proc = *found;
 	} else {
-		static Type *proc_type = nullptr;
-		if (proc_type == nullptr) {
-			Type *args[2] = {t_rawptr, t_rawptr};
-			proc_type = alloc_type_proc_from_types(args, 2, t_bool, false, ProcCC_Contextless);
-			set_procedure_abi_types(proc_type);
-		}
-
 		static u32 proc_index = 0;
 
 		char buf[16] = {};
@@ -9184,7 +9181,7 @@ lbValue lb_get_compare_proc_for_type(lbModule *m, Type *type) {
 		char *str = gb_alloc_str_len(permanent_allocator(), buf, n-1);
 		String proc_name = make_string_c(str);
 
-		lbProcedure *p = lb_create_dummy_procedure(m, proc_name, proc_type);
+		lbProcedure *p = lb_create_dummy_procedure(m, proc_name, t_equal_proc);
 		lb_begin_procedure_body(p);
 
 		LLVMValueRef x = LLVMGetParam(p->value, 0);
@@ -9194,29 +9191,50 @@ lbValue lb_get_compare_proc_for_type(lbModule *m, Type *type) {
 		lbValue lhs = {x, pt};
 		lbValue rhs = {y, pt};
 
-		lbBlock *block_false = lb_create_block(p, "bfalse");
 
-		lbValue res = lb_const_bool(m, t_bool, true);
-		for_array(i, type->Struct.fields) {
-			lbBlock *next_block = lb_create_block(p, "btrue");
+		lbBlock *block_same_ptr = lb_create_block(p, "same_ptr");
+		lbBlock *block_diff_ptr = lb_create_block(p, "diff_ptr");
 
-			lbValue pleft  = lb_emit_struct_ep(p, lhs, cast(i32)i);
-			lbValue pright = lb_emit_struct_ep(p, rhs, cast(i32)i);
-			lbValue left = lb_emit_load(p, pleft);
-			lbValue right = lb_emit_load(p, pright);
-			lbValue ok = lb_emit_comp(p, Token_CmpEq, left, right);
+		lbValue same_ptr = lb_emit_comp(p, Token_CmpEq, lhs, rhs);
+		lb_emit_if(p, same_ptr, block_same_ptr, block_diff_ptr);
+		lb_start_block(p, block_same_ptr);
+		LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 1, false));
 
-			lb_emit_if(p, ok, next_block, block_false);
+		lb_start_block(p, block_diff_ptr);
 
-			lb_emit_jump(p, next_block);
-			lb_start_block(p, next_block);
-		}
+		if (type->kind == Type_Struct)  {
+			type_set_offsets(type);
 
-		LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 1, false));
+			lbBlock *block_false = lb_create_block(p, "bfalse");
+			lbValue res = lb_const_bool(m, t_bool, true);
+
+			for_array(i, type->Struct.fields) {
+				lbBlock *next_block = lb_create_block(p, "btrue");
+
+				lbValue pleft  = lb_emit_struct_ep(p, lhs, cast(i32)i);
+				lbValue pright = lb_emit_struct_ep(p, rhs, cast(i32)i);
+				lbValue left = lb_emit_load(p, pleft);
+				lbValue right = lb_emit_load(p, pright);
+				lbValue ok = lb_emit_comp(p, Token_CmpEq, left, right);
+
+				lb_emit_if(p, ok, next_block, block_false);
+
+				lb_emit_jump(p, next_block);
+				lb_start_block(p, next_block);
+			}
 
-		lb_start_block(p, block_false);
+			LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 1, false));
 
-		LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 0, false));
+			lb_start_block(p, block_false);
+
+			LLVMBuildRet(p->builder, LLVMConstInt(lb_type(m, t_bool), 0, false));
+		} else {
+			lbValue left = lb_emit_load(p, lhs);
+			lbValue right = lb_emit_load(p, rhs);
+			lbValue ok = lb_emit_comp(p, Token_CmpEq, left, right);
+			ok = lb_emit_conv(p, ok, t_bool);
+			LLVMBuildRet(p->builder, ok.value);
+		}
 
 		lb_end_procedure_body(p);
 
@@ -9370,7 +9388,7 @@ lbValue lb_emit_comp(lbProcedure *p, TokenKind op_kind, lbValue left, lbValue ri
 			args[2] = lb_const_int(p->module, t_int, type_size_of(a));
 			res = lb_emit_runtime_call(p, "memory_equal", args);
 		} else {
-			lbValue value = lb_get_compare_proc_for_type(p->module, a);
+			lbValue value = lb_get_equal_proc_for_type(p->module, a);
 			auto args = array_make<lbValue>(permanent_allocator(), 2);
 			args[0] = lb_emit_conv(p, left_ptr, t_rawptr);
 			args[1] = lb_emit_conv(p, right_ptr, t_rawptr);
@@ -10255,8 +10273,6 @@ lbValue lb_gen_map_header(lbProcedure *p, lbValue map_val_ptr, Type *map_type) {
 	lbValue m = lb_emit_conv(p, map_val_ptr, type_deref(gep0.type));
 	lb_emit_store(p, gep0, m);
 
-	lb_emit_store(p, lb_emit_struct_ep(p, h.addr, 1), lb_const_bool(p->module, t_bool, is_type_string(key_type)));
-
 	i64 entry_size   = type_size_of  (map_type->Map.entry_type);
 	i64 entry_align  = type_align_of (map_type->Map.entry_type);
 
@@ -10266,6 +10282,7 @@ lbValue lb_gen_map_header(lbProcedure *p, lbValue map_val_ptr, Type *map_type) {
 	i64 value_offset = type_offset_of(map_type->Map.entry_type, 3);
 	i64 value_size   = type_size_of  (map_type->Map.value);
 
+	lb_emit_store(p, lb_emit_struct_ep(p, h.addr, 1), lb_get_equal_proc_for_type(p->module, key_type));
 	lb_emit_store(p, lb_emit_struct_ep(p, h.addr, 2), lb_const_int(p->module, t_int, entry_size));
 	lb_emit_store(p, lb_emit_struct_ep(p, h.addr, 3), lb_const_int(p->module, t_int, entry_align));
 	lb_emit_store(p, lb_emit_struct_ep(p, h.addr, 4), lb_const_int(p->module, t_uintptr, key_offset));
@@ -12204,7 +12221,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 				vals[6] = is_raw_union.value;
 				vals[7] = is_custom_align.value;
 				if (is_type_comparable(t) && !is_type_simple_compare(t)) {
-					vals[8] = lb_get_compare_proc_for_type(m, t).value;
+					vals[8] = lb_get_equal_proc_for_type(m, t).value;
 				}
 
 

+ 2 - 0
src/llvm_backend.hpp

@@ -379,6 +379,8 @@ lbValue lb_emit_source_code_location(lbProcedure *p, String const &procedure, To
 
 lbValue lb_handle_param_value(lbProcedure *p, Type *parameter_type, ParameterValue const &param_value, TokenPos const &pos);
 
+lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type);
+lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t);
 
 #define LB_STARTUP_RUNTIME_PROC_NAME   "__$startup_runtime"
 #define LB_STARTUP_TYPE_INFO_PROC_NAME "__$startup_type_info"

+ 2 - 0
src/types.cpp

@@ -690,6 +690,8 @@ gb_global Type *t_map_header                     = nullptr;
 gb_global Type *t_vector_x86_mmx                 = nullptr;
 
 
+gb_global Type *t_equal_proc = nullptr;
+
 
 i64      type_size_of               (Type *t);
 i64      type_align_of              (Type *t);