Browse Source

Add `equal` procedure field to `runtime.Type_Info_Struct`

gingerBill 4 years ago
parent
commit
4e370e6ed8
5 changed files with 149 additions and 56 deletions
  1. 78 0
      core/reflect/reflect.odin
  2. 5 0
      core/runtime/core.odin
  3. 8 3
      src/ir.cpp
  4. 9 5
      src/llvm_backend.cpp
  5. 49 48
      src/types.cpp

+ 78 - 0
core/reflect/reflect.odin

@@ -1206,3 +1206,81 @@ as_raw_data :: proc(a: any) -> (value: rawptr, valid: bool) {
 
 	return;
 }
+
+
+not_equal :: proc(a, b: any) -> bool {
+	return !equal(a, b);
+}
+equal :: proc(a, b: any) -> bool {
+	if a == nil && b == nil {
+		return true;
+	}
+
+	if a.id != b.id {
+		return false;
+	}
+
+	if a.data == b.data {
+		return true;
+	}
+
+	t := type_info_of(a.id);
+	if .Comparable not_in t.flags {
+		return false;
+	}
+
+	if t.size == 0 {
+		return true;
+	}
+
+	if .Simple_Compare in t.flags {
+		return mem.compare_byte_ptrs((^byte)(a.data), (^byte)(b.data), t.size) == 0;
+	}
+
+	t = runtime.type_info_core(t);
+
+	#partial switch v in t.variant {
+	case Type_Info_String:
+		if v.is_cstring {
+			x := string((^cstring)(a.data)^);
+			y := string((^cstring)(b.data)^);
+			return x == y;
+		} else {
+			x := (^string)(a.data)^;
+			y := (^string)(b.data)^;
+			return x == y;
+		}
+
+	case Type_Info_Array:
+		for i in 0..<v.count {
+			x := rawptr(uintptr(a.data) + uintptr(v.elem_size*i));
+			y := rawptr(uintptr(b.data) + uintptr(v.elem_size*i));
+			if !equal(any{x, v.elem.id}, any{y, v.elem.id}) {
+				return false;
+			}
+		}
+	case Type_Info_Enumerated_Array:
+		for i in 0..<v.count {
+			x := rawptr(uintptr(a.data) + uintptr(v.elem_size*i));
+			y := rawptr(uintptr(b.data) + uintptr(v.elem_size*i));
+			if !equal(any{x, v.elem.id}, any{y, v.elem.id}) {
+				return false;
+			}
+		}
+	case Type_Info_Struct:
+		if v.equal != nil {
+			return v.equal(a.data, b.data);
+		} else {
+			for offset, i in v.offsets {
+				x := rawptr(uintptr(a.data) + offset);
+				y := rawptr(uintptr(b.data) + offset);
+				id := v.types[i].id;
+				if !equal(any{x, id}, any{y, id}) {
+					return false;
+				}
+			}
+		}
+	}
+
+	return true;
+}

+ 5 - 0
core/runtime/core.odin

@@ -88,6 +88,8 @@ Type_Info_Tuple :: struct { // Only used for procedures parameters and results
 	types:        []^Type_Info,
 	names:        []string,
 };
+
+Type_Struct_Equal_Proc :: distinct proc "contextless" (rawptr, rawptr) -> bool;
 Type_Info_Struct :: struct {
 	types:        []^Type_Info,
 	names:        []string,
@@ -97,6 +99,9 @@ Type_Info_Struct :: struct {
 	is_packed:    bool,
 	is_raw_union: bool,
 	custom_align: bool,
+
+	equal: Type_Struct_Equal_Proc, // set only when the struct has .Comparable set but does not have .Simple_Compare set
+
 	// These are only set iff this structure is an SOA structure
 	soa_kind:      Type_Info_Struct_Soa_Kind,
 	soa_base_type: ^Type_Info,

+ 8 - 3
src/ir.cpp

@@ -12353,8 +12353,13 @@ void ir_setup_type_info_data(irProcedure *proc) { // NOTE(bill): Setup type_info
 				ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 6), is_raw_union);
 				ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 7), is_custom_align);
 
+				if (is_type_comparable(t) && !is_type_simple_compare(t)) {
+					ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 8), ir_get_compare_proc_for_type(proc->module, t));
+				}
+
+
 				if (t->Struct.soa_kind != StructSoa_None) {
-					irValue *kind = ir_emit_struct_ep(proc, tag, 8);
+					irValue *kind = ir_emit_struct_ep(proc, tag, 9);
 					Type *kind_type = type_deref(ir_type(kind));
 
 					irValue *soa_kind = ir_value_constant(kind_type, exact_value_i64(t->Struct.soa_kind));
@@ -12363,8 +12368,8 @@ void ir_setup_type_info_data(irProcedure *proc) { // NOTE(bill): Setup type_info
 
 
 					ir_emit_store(proc, kind, soa_kind);
-					ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 9), soa_type);
-					ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 10), soa_len);
+					ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 10), soa_type);
+					ir_emit_store(proc, ir_emit_struct_ep(proc, tag, 11), soa_len);
 				}
 			}
 

+ 9 - 5
src/llvm_backend.cpp

@@ -12178,7 +12178,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 		case Type_Struct: {
 			tag = lb_const_ptr_cast(m, variant_ptr, t_type_info_struct_ptr);
 
-			LLVMValueRef vals[11] = {};
+			LLVMValueRef vals[12] = {};
 
 
 			{
@@ -12188,18 +12188,22 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 				vals[5] = is_packed.value;
 				vals[6] = is_raw_union.value;
 				vals[7] = is_custom_align.value;
+				if (is_type_comparable(t) && !is_type_simple_compare(t)) {
+					vals[8] = lb_get_compare_proc_for_type(m, t).value;
+				}
+
 
 				if (t->Struct.soa_kind != StructSoa_None) {
-					lbValue kind = lb_emit_struct_ep(p, tag, 8);
+					lbValue kind = lb_emit_struct_ep(p, tag, 9);
 					Type *kind_type = type_deref(kind.type);
 
 					lbValue soa_kind = lb_const_value(m, kind_type, exact_value_i64(t->Struct.soa_kind));
 					lbValue soa_type = lb_type_info(m, t->Struct.soa_elem);
 					lbValue soa_len = lb_const_int(m, t_int, t->Struct.soa_count);
 
-					vals[8]  = soa_kind.value;
-					vals[9]  = soa_type.value;
-					vals[10] = soa_len.value;
+					vals[9]  = soa_kind.value;
+					vals[1]  = soa_type.value;
+					vals[11] = soa_len.value;
 				}
 			}
 

+ 49 - 48
src/types.cpp

@@ -1321,54 +1321,6 @@ Type *core_array_type(Type *t) {
 	return t;
 }
 
-// NOTE(bill): type can be easily compared using memcmp
-bool is_type_simple_compare(Type *t) {
-	t = core_type(t);
-	switch (t->kind) {
-	case Type_Array:
-		return is_type_simple_compare(t->Array.elem);
-
-	case Type_EnumeratedArray:
-		return is_type_simple_compare(t->EnumeratedArray.elem);
-
-	case Type_Basic:
-		if (t->Basic.flags & BasicFlag_SimpleCompare) {
-			return true;
-		}
-		return false;
-
-	case Type_Pointer:
-	case Type_Proc:
-	case Type_BitSet:
-	case Type_BitField:
-		return true;
-
-	case Type_Struct:
-		for_array(i, t->Struct.fields) {
-			Entity *f = t->Struct.fields[i];
-			if (!is_type_simple_compare(f->type)) {
-				return false;
-			}
-		}
-		return true;
-
-	case Type_Union:
-		for_array(i, t->Union.variants) {
-			Type *v = t->Union.variants[i];
-			if (!is_type_simple_compare(v)) {
-				return false;
-			}
-		}
-		return true;
-
-	case Type_SimdVector:
-		return is_type_simple_compare(t->SimdVector.elem);
-
-	}
-
-	return false;
-}
-
 
 
 Type *base_complex_elem_type(Type *t) {
@@ -1978,6 +1930,55 @@ bool is_type_comparable(Type *t) {
 	return false;
 }
 
+// NOTE(bill): type can be easily compared using memcmp
+bool is_type_simple_compare(Type *t) {
+	t = core_type(t);
+	switch (t->kind) {
+	case Type_Array:
+		return is_type_simple_compare(t->Array.elem);
+
+	case Type_EnumeratedArray:
+		return is_type_simple_compare(t->EnumeratedArray.elem);
+
+	case Type_Basic:
+		if (t->Basic.flags & BasicFlag_SimpleCompare) {
+			return true;
+		}
+		return false;
+
+	case Type_Pointer:
+	case Type_Proc:
+	case Type_BitSet:
+	case Type_BitField:
+		return true;
+
+	case Type_Struct:
+		for_array(i, t->Struct.fields) {
+			Entity *f = t->Struct.fields[i];
+			if (!is_type_simple_compare(f->type)) {
+				return false;
+			}
+		}
+		return true;
+
+	case Type_Union:
+		for_array(i, t->Union.variants) {
+			Type *v = t->Union.variants[i];
+			if (!is_type_simple_compare(v)) {
+				return false;
+			}
+		}
+		return true;
+
+	case Type_SimdVector:
+		return is_type_simple_compare(t->SimdVector.elem);
+
+	}
+
+	return false;
+}
+
+
 Type *strip_type_aliasing(Type *x) {
 	if (x == nullptr) {
 		return x;