ソースを参照

Correct static map get; make get take a pointer to simplify compiler internals

gingerBill 2 年 前
コミット
8852d090b6

+ 10 - 11
core/runtime/dynamic_map_internal.odin

@@ -107,7 +107,7 @@ Map_Cell_Info :: struct {
 map_cell_info :: intrinsics.type_map_cell_info
 
 // Same as the above procedure but at runtime with the cell Map_Cell_Info value.
-map_cell_index_dynamic :: #force_inline proc "contextless" (base: uintptr, info: ^Map_Cell_Info, index: uintptr) -> uintptr {
+map_cell_index_dynamic :: #force_inline proc "contextless" (base: uintptr, #no_alias info: ^Map_Cell_Info, index: uintptr) -> uintptr {
 	// Micro-optimize the common cases to save on integer division.
 	elements_per_cell := uintptr(info.elements_per_cell)
 	size_of_cell      := uintptr(info.size_of_cell)
@@ -355,13 +355,13 @@ map_alloc_dynamic :: proc "odin" (info: ^Map_Info, log2_capacity: uintptr, alloc
 // there is no type information.
 //
 // This procedure returns the address of the just inserted value.
-map_insert_hash_dynamic :: proc "odin" (m: ^Raw_Map, #no_alias info: ^Map_Info, h: Map_Hash, ik: uintptr, iv: uintptr) -> (result: uintptr) {
+map_insert_hash_dynamic :: proc "odin" (#no_alias m: ^Raw_Map, #no_alias info: ^Map_Info, h: Map_Hash, ik: uintptr, iv: uintptr) -> (result: uintptr) {
+	h        := h
 	pos      := map_desired_position(m^, h)
 	distance := uintptr(0)
 	mask     := (uintptr(1) << map_log2_cap(m^)) - 1
 
 	ks, vs, hs, sk, sv := map_kvh_data_dynamic(m^, info)
-	_, _ = sk, sv
 
 	// Avoid redundant loads of these values
 	size_of_k := info.ks.size_of_type
@@ -376,7 +376,6 @@ map_insert_hash_dynamic :: proc "odin" (m: ^Raw_Map, #no_alias info: ^Map_Info,
 	tk := map_cell_index_dynamic(sk, info.ks, 1)
 	tv := map_cell_index_dynamic(sv, info.vs, 1)
 
-	h := h
 
 	for {
 		hp := &hs[pos]
@@ -660,19 +659,19 @@ map_get :: proc "contextless" (m: $T/map[$K]$V, key: K) -> (stored_key: K, store
 	}
 }
 
-__dynamic_map_get_with_hash :: proc "contextless" (m: Raw_Map, #no_alias info: ^Map_Info, h: Map_Hash, key: rawptr) -> (ptr: rawptr) {
+__dynamic_map_get_with_hash :: proc "contextless" (#no_alias m: ^Raw_Map, #no_alias info: ^Map_Info, h: Map_Hash, key: rawptr) -> (ptr: rawptr) {
 	if m.len == 0 {
 		return nil
 	}
-	pos := map_desired_position(m, h)
+	pos := map_desired_position(m^, h)
 	distance := uintptr(0)
-	mask := (uintptr(1) << map_log2_cap(m)) - 1
-	ks, vs, hs, _, _ := map_kvh_data_dynamic(m, info)
+	mask := (uintptr(1) << map_log2_cap(m^)) - 1
+	ks, vs, hs, _, _ := map_kvh_data_dynamic(m^, info)
 	for {
 		element_hash := hs[pos]
 		if map_hash_is_empty(element_hash) {
 			return nil
-		} else if distance > map_probe_distance(m, element_hash, pos) {
+		} else if distance > map_probe_distance(m^, element_hash, pos) {
 			return nil
 		} else if element_hash == h && info.key_equal(key, rawptr(map_cell_index_dynamic(ks, info.ks, pos))) {
 			return rawptr(map_cell_index_dynamic(vs, info.vs, pos))
@@ -682,7 +681,7 @@ __dynamic_map_get_with_hash :: proc "contextless" (m: Raw_Map, #no_alias info: ^
 	}
 }
 
-__dynamic_map_get :: proc "contextless" (m: Raw_Map, #no_alias info: ^Map_Info, key: rawptr) -> (ptr: rawptr) {
+__dynamic_map_get :: proc "contextless" (#no_alias m: ^Raw_Map, #no_alias info: ^Map_Info, key: rawptr) -> (ptr: rawptr) {
 	if m.len == 0 {
 		return nil
 	}
@@ -693,7 +692,7 @@ __dynamic_map_get :: proc "contextless" (m: Raw_Map, #no_alias info: ^Map_Info,
 __dynamic_map_set :: proc "odin" (#no_alias m: ^Raw_Map, #no_alias info: ^Map_Info, key, value: rawptr, loc := #caller_location) -> rawptr {
 	hash := info.key_hasher(key, 0)
 
-	if found := __dynamic_map_get_with_hash(m^, info, hash, key); found != nil {
+	if found := __dynamic_map_get_with_hash(m, info, hash, key); found != nil {
 		intrinsics.mem_copy_non_overlapping(found, value, info.vs.size_of_type)
 		return found
 	}

+ 32 - 26
src/llvm_backend.cpp

@@ -140,7 +140,7 @@ lbContextData *lb_push_context_onto_stack(lbProcedure *p, lbAddr ctx) {
 }
 
 
-lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type) {
+lbValue lb_equal_proc_for_type(lbModule *m, Type *type) {
 	type = base_type(type);
 	GB_ASSERT(is_type_comparable(type));
 
@@ -296,7 +296,7 @@ lbValue lb_simple_compare_hash(lbProcedure *p, Type *type, lbValue data, lbValue
 	return lb_emit_runtime_call(p, "default_hasher_n", args);
 }
 
-lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
+lbValue lb_hasher_proc_for_type(lbModule *m, Type *type) {
 	type = core_type(type);
 	GB_ASSERT_MSG(is_type_valid_for_keys(type), "%s", type_to_string(type));
 
@@ -343,7 +343,7 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 			GB_ASSERT(type->Struct.offsets != nullptr);
 			i64 offset = type->Struct.offsets[i];
 			Entity *field = type->Struct.fields[i];
-			lbValue field_hasher = lb_get_hasher_proc_for_type(m, field->type);
+			lbValue field_hasher = lb_hasher_proc_for_type(m, field->type);
 			lbValue ptr = lb_emit_ptr_offset(p, data, lb_const_int(m, t_uintptr, offset));
 
 			args[0] = ptr;
@@ -356,7 +356,7 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 
 		if (is_type_union_maybe_pointer(type)) {
 			Type *v = type->Union.variants[0];
-			lbValue variant_hasher = lb_get_hasher_proc_for_type(m, v);
+			lbValue variant_hasher = lb_hasher_proc_for_type(m, v);
 
 			args[0] = data;
 			args[1] = seed;
@@ -379,7 +379,7 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 			Type *v = type->Union.variants[i];
 			lbValue case_tag = lb_const_union_tag(p->module, type, v);
 
-			lbValue variant_hasher = lb_get_hasher_proc_for_type(m, v);
+			lbValue variant_hasher = lb_hasher_proc_for_type(m, v);
 
 			args[0] = data;
 			args[1] = seed;
@@ -397,7 +397,7 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 		lb_addr_store(p, pres, seed);
 
 		auto args = array_make<lbValue>(permanent_allocator(), 2);
-		lbValue elem_hasher = lb_get_hasher_proc_for_type(m, type->Array.elem);
+		lbValue elem_hasher = lb_hasher_proc_for_type(m, type->Array.elem);
 
 		auto loop_data = lb_loop_start(p, cast(isize)type->Array.count, t_i32);
 
@@ -418,7 +418,7 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 		lb_addr_store(p, res, seed);
 
 		auto args = array_make<lbValue>(permanent_allocator(), 2);
-		lbValue elem_hasher = lb_get_hasher_proc_for_type(m, type->EnumeratedArray.elem);
+		lbValue elem_hasher = lb_hasher_proc_for_type(m, type->EnumeratedArray.elem);
 
 		auto loop_data = lb_loop_start(p, cast(isize)type->EnumeratedArray.count, t_i32);
 
@@ -454,7 +454,7 @@ lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type) {
 }
 
 
-lbValue lb_get_map_get_proc_for_type(lbModule *m, Type *type) {
+lbValue lb_map_get_proc_for_type(lbModule *m, Type *type) {
 	GB_ASSERT(build_context.use_static_map_calls);
 	type = base_type(type);
 	GB_ASSERT(type->kind == Type_Map);
@@ -468,7 +468,7 @@ lbValue lb_get_map_get_proc_for_type(lbModule *m, Type *type) {
 	static u32 proc_index = 0;
 
 	char buf[32] = {};
-	isize n = gb_snprintf(buf, 32, "__$map_get%u", ++proc_index);
+	isize n = gb_snprintf(buf, 32, "__$map_get_%u", ++proc_index);
 	char *str = gb_alloc_str_len(permanent_allocator(), buf, n-1);
 	String proc_name = make_string_c(str);
 
@@ -489,9 +489,23 @@ lbValue lb_get_map_get_proc_for_type(lbModule *m, Type *type) {
 	LLVMAddAttributeAtIndex(p->value, 1+1, nonnull_attr);
 	LLVMAddAttributeAtIndex(p->value, 1+1, noalias_attr);
 
+	lbBlock *loop_block = lb_create_block(p, "loop");
+	lbBlock *hash_block = lb_create_block(p, "hash");
+	lbBlock *probe_block = lb_create_block(p, "probe");
+	lbBlock *increment_block = lb_create_block(p, "increment");
+	lbBlock *hash_compare_block = lb_create_block(p, "hash_compare");
+	lbBlock *key_compare_block = lb_create_block(p, "key_compare");
+	lbBlock *value_block = lb_create_block(p, "value");
+	lbBlock *nil_block = lb_create_block(p, "nil");
+
 	map_ptr = lb_emit_conv(p, map_ptr, t_raw_map_ptr);
 	lbValue map = lb_emit_load(p, map_ptr);
 
+	lbValue length = lb_map_len(p, map);
+
+	lb_emit_if(p, lb_emit_comp(p, Token_CmpEq, length, lb_const_nil(m, t_int)), nil_block, hash_block);
+	lb_start_block(p, hash_block);
+
 	key_ptr = lb_emit_conv(p, key_ptr, alloc_type_pointer(type->Map.key));
 	lbValue key = lb_emit_load(p, key_ptr);
 
@@ -521,16 +535,8 @@ lbValue lb_get_map_get_proc_for_type(lbModule *m, Type *type) {
 	// lbValue res =
 	// LLVMBuildRet(p->builder, res.value);
 
-	lbBlock *loop = lb_create_block(p, "loop");
-	lbBlock *probe_block = lb_create_block(p, "probe");
-	lbBlock *increment_block = lb_create_block(p, "increment");
-	lbBlock *hash_compare_block = lb_create_block(p, "hash_compare");
-	lbBlock *key_compare_block = lb_create_block(p, "key_compare");
-	lbBlock *value_block = lb_create_block(p, "value");
-	lbBlock *nil_block = lb_create_block(p, "nil");
-
-	lb_emit_jump(p, loop);
-	lb_start_block(p, loop);
+	lb_emit_jump(p, loop_block);
+	lb_start_block(p, loop_block);
 
 	lbValue element_hash = lb_emit_load(p, lb_emit_ptr_offset(p, hs, lb_addr_load(p, pos)));
 	{
@@ -577,7 +583,7 @@ lbValue lb_get_map_get_proc_for_type(lbModule *m, Type *type) {
 		lb_addr_store(p, pos, pp);
 		lb_emit_increment(p, distance.addr);
 	}
-	lb_emit_jump(p, loop);
+	lb_emit_jump(p, loop_block);
 
 	lb_start_block(p, nil_block);
 	{
@@ -678,8 +684,8 @@ lbValue lb_gen_map_info_ptr(lbModule *m, Type *map_type) {
 	LLVMValueRef const_values[4] = {};
 	const_values[0] = key_cell_info;
 	const_values[1] = value_cell_info;
-	const_values[2] = lb_get_hasher_proc_for_type(m, map_type->Map.key).value;
-	const_values[3] = lb_get_equal_proc_for_type(m, map_type->Map.key).value;
+	const_values[2] = lb_hasher_proc_for_type(m, map_type->Map.key).value;
+	const_values[3] = lb_equal_proc_for_type(m, map_type->Map.key).value;
 
 	LLVMValueRef llvm_res = llvm_const_named_struct(m, t_map_info, const_values, gb_count_of(const_values));
 	lbValue res = {llvm_res, t_map_info};
@@ -746,7 +752,7 @@ lbValue lb_gen_map_key_hash(lbProcedure *p, lbValue key, Type *key_type, lbValue
 
 	lbValue hashed_key = lb_const_hash(p->module, key, key_type);
 	if (hashed_key.value == nullptr) {
-		lbValue hasher = lb_get_hasher_proc_for_type(p->module, key_type);
+		lbValue hasher = lb_hasher_proc_for_type(p->module, key_type);
 
 		auto args = array_make<lbValue>(permanent_allocator(), 2);
 		args[0] = key_ptr;
@@ -767,16 +773,16 @@ lbValue lb_internal_dynamic_map_get_ptr(lbProcedure *p, lbValue const &map_ptr,
 	key_ptr = lb_emit_conv(p, key_ptr, t_rawptr);
 
 	if (build_context.use_static_map_calls) {
-		lbValue map_get_proc = lb_get_map_get_proc_for_type(p->module, map_type);
+		lbValue map_get_proc = lb_map_get_proc_for_type(p->module, map_type);
 
 		auto args = array_make<lbValue>(permanent_allocator(), 2);
-		args[0] = map_ptr;
+		args[0] = lb_emit_conv(p, map_ptr, t_rawptr);
 		args[1] = key_ptr;
 
 		ptr = lb_emit_call(p, map_get_proc, args);
 	} else {
 		auto args = array_make<lbValue>(permanent_allocator(), 3);
-		args[0] = lb_emit_transmute(p, lb_emit_load(p, map_ptr), t_raw_map);
+		args[0] = lb_emit_transmute(p, map_ptr, t_raw_map_ptr);
 		args[1] = lb_gen_map_info_ptr(p->module, map_type);
 		args[2] = key_ptr;
 

+ 2 - 2
src/llvm_backend.hpp

@@ -463,8 +463,8 @@ lbValue lb_emit_source_code_location_const(lbProcedure *p, String const &procedu
 
 lbValue lb_handle_param_value(lbProcedure *p, Type *parameter_type, ParameterValue const &param_value, TokenPos const &pos);
 
-lbValue lb_get_equal_proc_for_type(lbModule *m, Type *type);
-lbValue lb_get_hasher_proc_for_type(lbModule *m, Type *type);
+lbValue lb_equal_proc_for_type(lbModule *m, Type *type);
+lbValue lb_hasher_proc_for_type(lbModule *m, Type *type);
 lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t);
 
 LLVMMetadataRef lb_debug_type(lbModule *m, Type *type);

+ 1 - 1
src/llvm_backend_expr.cpp

@@ -2215,7 +2215,7 @@ lbValue lb_compare_records(lbProcedure *p, TokenKind op_kind, lbValue left, lbVa
 		args[2] = lb_const_int(p->module, t_int, type_size_of(type));
 		res = lb_emit_runtime_call(p, "memory_equal", args);
 	} else {
-		lbValue value = lb_get_equal_proc_for_type(p->module, type);
+		lbValue value = lb_equal_proc_for_type(p->module, type);
 		auto args = array_make<lbValue>(permanent_allocator(), 2);
 		args[0] = lb_emit_conv(p, left_ptr, t_rawptr);
 		args[1] = lb_emit_conv(p, right_ptr, t_rawptr);

+ 2 - 2
src/llvm_backend_proc.cpp

@@ -2319,10 +2319,10 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
 
 
 	case BuiltinProc_type_equal_proc:
-		return lb_get_equal_proc_for_type(p->module, ce->args[0]->tav.type);
+		return lb_equal_proc_for_type(p->module, ce->args[0]->tav.type);
 
 	case BuiltinProc_type_hasher_proc:
-		return lb_get_hasher_proc_for_type(p->module, ce->args[0]->tav.type);
+		return lb_hasher_proc_for_type(p->module, ce->args[0]->tav.type);
 
 	case BuiltinProc_type_map_info:
 		return lb_gen_map_info_ptr(p->module, ce->args[0]->tav.type);

+ 2 - 2
src/llvm_backend_type.cpp

@@ -666,7 +666,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 				}
 
 				if (is_type_comparable(t) && !is_type_simple_compare(t)) {
-					vals[3] = lb_get_equal_proc_for_type(m, t).value;
+					vals[3] = lb_equal_proc_for_type(m, t).value;
 				}
 
 				vals[4] = lb_const_bool(m, t_bool, t->Union.custom_align != 0).value;
@@ -702,7 +702,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 				vals[6] = is_raw_union.value;
 				vals[7] = is_custom_align.value;
 				if (is_type_comparable(t) && !is_type_simple_compare(t)) {
-					vals[8] = lb_get_equal_proc_for_type(m, t).value;
+					vals[8] = lb_equal_proc_for_type(m, t).value;
 				}
 
 

+ 1 - 1
src/llvm_backend_utility.cpp

@@ -1434,7 +1434,7 @@ lbValue lb_dynamic_array_allocator(lbProcedure *p, lbValue da) {
 }
 
 lbValue lb_map_len(lbProcedure *p, lbValue value) {
-	GB_ASSERT(is_type_map(value.type));
+	GB_ASSERT_MSG(is_type_map(value.type) || are_types_identical(value.type, t_raw_map), "%s", type_to_string(value.type));
 	lbValue len = lb_emit_struct_ev(p, value, 1);
 	return lb_emit_conv(p, len, t_int);
 }