瀏覽代碼

Fix the intrinsics, add min and max

jakubtomsu 1 年之前
父節點
當前提交
f7e0516254
共有 1 個文件被更改,包括 37 次插入9 次删除
  1. 37 9
      src/check_builtin.cpp

+ 37 - 9
src/check_builtin.cpp

@@ -2597,7 +2597,7 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 
 		Type *original_type = operand->type;
 		Type *type = base_type(operand->type);
-		if (operand->mode == Addressing_Type && is_type_enumerated_array(type)) {
+		if (operand->mode == Addressing_Type && (is_type_enumerated_array(type) || is_type_union(type))) {
 			// Okay
 		} else if (!is_type_ordered(type) || !(is_type_numeric(type) || is_type_string(type))) {
 			gbString type_str = type_to_string(original_type);
@@ -2662,6 +2662,14 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 				operand->type  = bt->EnumeratedArray.index;
 				operand->value = *bt->EnumeratedArray.min_value;
 				return true;
+			} else if (is_type_union(type)) {
+				Type *bt = base_type(type);
+				GB_ASSERT(bt->kind == Type_Union);
+				operand->mode  = Addressing_Constant;
+				operand->type  = t_untyped_integer;
+				i64 min_tag = bt->Union.kind == UnionType_no_nil ? 0 : 1;
+				operand->value = exact_value_i64(min_tag);
+				return true;
 			}
 			gbString type_str = type_to_string(original_type);
 			error(call, "Invalid type for 'min', got %s", type_str);
@@ -2766,7 +2774,7 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 		Type *original_type = operand->type;
 		Type *type = base_type(operand->type);
 
-		if (operand->mode == Addressing_Type && is_type_enumerated_array(type)) {
+		if (operand->mode == Addressing_Type && (is_type_enumerated_array(type) || is_type_union(type))) {
 			// Okay
 		} else if (!is_type_ordered(type) || !(is_type_numeric(type) || is_type_string(type))) {
 			gbString type_str = type_to_string(original_type);
@@ -2836,6 +2844,14 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 				operand->type  = bt->EnumeratedArray.index;
 				operand->value = *bt->EnumeratedArray.max_value;
 				return true;
+			} else if (is_type_union(type)) {
+				Type *bt = base_type(type);
+				GB_ASSERT(bt->kind == Type_Union);
+				operand->mode  = Addressing_Constant;
+				operand->type  = t_untyped_integer;
+				i64 max_tag = (bt->Union.kind == UnionType_no_nil ? 0 : 1) + bt->Union.variants.count - 1;
+				operand->value = exact_value_i64(max_tag);
+				return true;
 			}
 			gbString type_str = type_to_string(original_type);
 			error(call, "Invalid type for 'max', got %s", type_str);
@@ -5180,7 +5196,7 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 		}
 		break;
 
-	case BuiltinProc_type_variant_type:
+	case BuiltinProc_type_variant_type_of:
 		{
 			if (operand->mode != Addressing_Type) {
 				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
@@ -5210,10 +5226,6 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 			}
 			
 			i64 index = big_int_to_i64(&x.value.value_integer);
-			if (u->Union.kind != UnionType_no_nil) {
-				index -= 1;
-			}
-			
 			if (index < 0 || index >= u->Union.variants.count) {
 				error(call, "Variant tag out of bounds index for '%.*s", LIT(builtin_name));
 				operand->mode = Addressing_Type;
@@ -5226,7 +5238,7 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 		}
 		break;
 	
-	case BuiltinProc_type_variant_tag:
+	case BuiltinProc_type_variant_index_of:
 		{
 			if (operand->mode != Addressing_Type) {
 				error(operand->expr, "Expected a type for '%.*s'", LIT(builtin_name));
@@ -5247,10 +5259,26 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 			Type *v = check_type(c, ce->args[1]);
 			u = base_type(u);
 			GB_ASSERT(u->kind == Type_Union);
+
+			i64 index = -1;			
+			for_array(i, u->Union.variants) {
+				Type *vt = u->Union.variants[i];
+				if (union_variant_index_types_equal(v, vt)) {
+					index = i64(i);
+					break;
+				}
+			}
+			
+			if (index < 0) {
+				error(operand->expr, "Expected a variant type for '%.*s'", LIT(builtin_name));
+				operand->mode = Addressing_Invalid;
+				operand->type = t_invalid;
+				return false;
+			}
 			
 			operand->mode = Addressing_Constant;
 			operand->type = t_untyped_integer;
-			operand->value = exact_value_i64(union_variant_index(u, v));
+			operand->value = exact_value_i64(index);
 		}
 		break;