Browse Source

Merge pull request #2894 from jakubtomsu/union-tag-intrinsics

New built-in procedures for unions
gingerBill 1 year ago
parent
commit
955be66f1a
3 changed files with 216 additions and 1 deletions
  1. 7 0
      core/intrinsics/intrinsics.odin
  2. 196 0
      src/check_builtin.cpp
  3. 13 1
      src/checker_builtin_procs.hpp

+ 7 - 0
core/intrinsics/intrinsics.odin

@@ -162,7 +162,14 @@ type_is_matrix           :: proc($T: typeid) -> bool ---
 type_has_nil :: proc($T: typeid) -> bool ---
 
 type_is_specialization_of :: proc($T, $S: typeid) -> bool ---
+
 type_is_variant_of :: proc($U, $V: typeid) -> bool where type_is_union(U) ---
+type_union_tag_type :: proc($T: typeid) -> typeid where type_is_union(T) ---
+type_union_tag_offset :: proc($T: typeid) -> uintptr where type_is_union(T) ---
+type_union_base_tag_value :: proc($T: typeid) -> int where type_is_union(U) ---
+type_union_variant_count :: proc($T: typeid) -> int where type_is_union(T) ---
+type_variant_type_of :: proc($T: typeid, $index: int) -> typeid where type_is_union(T) ---
+type_variant_index_of :: proc($U, $V: typeid) -> int where type_is_union(U) ---
 
 type_has_field :: proc($T: typeid, $name: string) -> bool ---
 type_field_type :: proc($T: typeid, $name: string) -> typeid ---

+ 196 - 0
src/check_builtin.cpp

@@ -5117,6 +5117,202 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 		}
 		break;
 
+	case BuiltinProc_type_union_tag_type:
+		{
+			if (operand->mode != Addressing_Type) {
+				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			Type *u = operand->type;
+
+			if (!is_type_union(u)) {
+				error(operand->expr, "Expected a union type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			u = base_type(u);
+			GB_ASSERT(u->kind == Type_Union);
+			
+			operand->mode = Addressing_Type;
+			operand->type = union_tag_type(u);
+		}
+		break;
+
+	case BuiltinProc_type_union_tag_offset:
+		{
+			if (operand->mode != Addressing_Type) {
+				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			Type *u = operand->type;
+
+			if (!is_type_union(u)) {
+				error(operand->expr, "Expected a union type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			u = base_type(u);
+			GB_ASSERT(u->kind == Type_Union);
+			
+			// NOTE(jakubtomsu): forces calculation of variant_block_size
+			type_size_of(u);
+			i64 tag_offset = u->Union.variant_block_size;
+			GB_ASSERT(tag_offset > 0);
+			
+			operand->mode = Addressing_Constant;
+			operand->type = t_untyped_integer;
+			operand->value = exact_value_i64(tag_offset);
+		}
+		break;
+
+	case BuiltinProc_type_union_base_tag_value:
+		{
+			if (operand->mode != Addressing_Type) {
+				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			Type *u = operand->type;
+
+			if (!is_type_union(u)) {
+				error(operand->expr, "Expected a union type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			u = base_type(u);
+			GB_ASSERT(u->kind == Type_Union);
+			
+			operand->mode = Addressing_Constant;
+			operand->type = t_untyped_integer;
+			operand->value = exact_value_i64(u->Union.kind == UnionType_no_nil ? 0 : 1);
+		} break;
+
+	case BuiltinProc_type_union_variant_count:
+		{
+			if (operand->mode != Addressing_Type) {
+				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			Type *u = operand->type;
+
+			if (!is_type_union(u)) {
+				error(operand->expr, "Expected a union type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			u = base_type(u);
+			GB_ASSERT(u->kind == Type_Union);
+			
+			operand->mode = Addressing_Constant;
+			operand->type = t_untyped_integer;
+			operand->value = exact_value_i64(u->Union.variants.count);
+		} break;
+
+	case BuiltinProc_type_variant_type_of:
+		{
+			if (operand->mode != Addressing_Type) {
+				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			Type *u = operand->type;
+
+			if (!is_type_union(u)) {
+				error(operand->expr, "Expected a union type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			u = base_type(u);
+			GB_ASSERT(u->kind == Type_Union);
+			Operand x = {};
+			check_expr_or_type(c, &x, ce->args[1]);
+			if (!is_type_integer(x.type) || x.mode != Addressing_Constant) {
+				error(call, "Expected a constant integer for '%.*s", LIT(builtin_name));
+				operand->mode = Addressing_Type;
+				operand->type = t_invalid;
+				return false;
+			}
+			
+			i64 index = big_int_to_i64(&x.value.value_integer);
+			if (index < 0 || index >= u->Union.variants.count) {
+				error(call, "Variant tag out of bounds index for '%.*s", LIT(builtin_name));
+				operand->mode = Addressing_Type;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			operand->mode = Addressing_Type;
+			operand->type = u->Union.variants[index];
+		}
+		break;
+	
+	case BuiltinProc_type_variant_index_of:
+		{
+			if (operand->mode != Addressing_Type) {
+				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			Type *u = operand->type;
+
+			if (!is_type_union(u)) {
+				error(operand->expr, "Expected a union type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+
+			Type *v = check_type(c, ce->args[1]);
+			u = base_type(u);
+			GB_ASSERT(u->kind == Type_Union);
+
+			i64 index = -1;			
+			for_array(i, u->Union.variants) {
+				Type *vt = u->Union.variants[i];
+				if (union_variant_index_types_equal(v, vt)) {
+					index = i64(i);
+					break;
+				}
+			}
+			
+			if (index < 0) {
+				error(operand->expr, "Expected a variant type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
+			
+			operand->mode = Addressing_Constant;
+			operand->type = t_untyped_integer;
+			operand->value = exact_value_i64(index);
+		}
+		break;
+
 	case BuiltinProc_type_struct_field_count:
 		operand->value = exact_value_i64(0);
 		if (operand->mode != Addressing_Type) {

+ 13 - 1
src/checker_builtin_procs.hpp

@@ -260,6 +260,12 @@ BuiltinProc__type_simple_boolean_end,
 	BuiltinProc_type_is_specialization_of,
 
 	BuiltinProc_type_is_variant_of,
+	BuiltinProc_type_union_tag_type,
+	BuiltinProc_type_union_tag_offset,
+	BuiltinProc_type_union_base_tag_value,
+	BuiltinProc_type_union_variant_count,
+	BuiltinProc_type_variant_type_of,
+	BuiltinProc_type_variant_index_of,
 
 	BuiltinProc_type_struct_field_count,
 
@@ -557,7 +563,13 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 
 	{STR_LIT("type_is_specialization_of"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 
-	{STR_LIT("type_is_variant_of"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("type_is_variant_of"),          2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("type_union_tag_type"),         1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("type_union_tag_offset"),       1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("type_union_base_tag_value"),   1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("type_union_variant_count"),    1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("type_variant_type_of"),        2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("type_variant_index_of"),       2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 
 	{STR_LIT("type_struct_field_count"),   1, false, Expr_Expr, BuiltinProcPkg_intrinsics},