Browse Source

`union #shared_nil`

This adds a feature to `union` which requires all the variants to have a `nil` value and on assign to the union, checks whether that value is `nil` or not. If the value is `nil`, the union will be `nil` (thus sharing the `nil` value)
gingerBill 3 years ago
parent
commit
3f935bea25

+ 1 - 0
core/runtime/core.odin

@@ -136,6 +136,7 @@ Type_Info_Union :: struct {
 	custom_align: bool,
 	no_nil:       bool,
 	maybe:        bool,
+	shared_nil:   bool,
 }
 Type_Info_Enum :: struct {
 	base:      ^Type_Info,

+ 5 - 2
src/check_expr.cpp

@@ -10047,8 +10047,11 @@ gbString write_expr_to_string(gbString str, Ast *node, bool shorthand) {
 			str = write_expr_to_string(str, st->polymorphic_params, shorthand);
 			str = gb_string_appendc(str, ") ");
 		}
-		if (st->no_nil) str = gb_string_appendc(str, "#no_nil ");
-		if (st->maybe)  str = gb_string_appendc(str, "#maybe ");
+		switch (st->kind) {
+		case UnionType_maybe:      str = gb_string_appendc(str, "#maybe ");      break;
+		case UnionType_no_nil:     str = gb_string_appendc(str, "#no_nil ");     break;
+		case UnionType_shared_nil: str = gb_string_appendc(str, "#shared_nil "); break;
+		}
 		if (st->align) {
 			str = gb_string_appendc(str, "#align ");
 			str = write_expr_to_string(str, st->align, shorthand);

+ 14 - 5
src/check_type.cpp

@@ -675,22 +675,31 @@ void check_union_type(CheckerContext *ctx, Type *union_type, Ast *node, Array<Op
 			}
 			if (ok) {
 				array_add(&variants, t);
+
+				if (ut->kind == UnionType_shared_nil) {
+					if (!type_has_nil(t)) {
+						gbString s = type_to_string(t);
+						error(node, "Each variant of a union with #shared_nil must have a 'nil' value, got %s", s);
+						gb_string_free(s);
+					}
+				}
 			}
 		}
 	}
 
 	union_type->Union.variants = slice_from_array(variants);
-	union_type->Union.no_nil = ut->no_nil;
-	union_type->Union.maybe = ut->maybe;
-	if (union_type->Union.no_nil) {
+	union_type->Union.kind = ut->kind;
+	switch (ut->kind) {
+	case UnionType_no_nil:
 		if (variants.count < 2) {
 			error(ut->align, "A union with #no_nil must have at least 2 variants");
 		}
-	}
-	if (union_type->Union.maybe) {
+		break;
+	case UnionType_maybe:
 		if (variants.count != 1) {
 			error(ut->align, "A union with #maybe must have at 1 variant, got %lld", cast(long long)variants.count);
 		}
+		break;
 	}
 
 	if (ut->align != nullptr) {

+ 1 - 0
src/docs_format.cpp

@@ -99,6 +99,7 @@ enum OdinDocTypeFlag_Union : u32 {
 	OdinDocTypeFlag_Union_polymorphic = 1<<0,
 	OdinDocTypeFlag_Union_no_nil      = 1<<1,
 	OdinDocTypeFlag_Union_maybe       = 1<<2,
+	OdinDocTypeFlag_Union_shared_nil  = 1<<3,
 };
 
 enum OdinDocTypeFlag_Proc : u32 {

+ 5 - 3
src/docs_writer.cpp

@@ -619,9 +619,11 @@ OdinDocTypeIndex odin_doc_type(OdinDocWriter *w, Type *type) {
 	case Type_Union:
 		doc_type.kind = OdinDocType_Union;
 		if (type->Union.is_polymorphic) { doc_type.flags |= OdinDocTypeFlag_Union_polymorphic; }
-		if (type->Union.no_nil)         { doc_type.flags |= OdinDocTypeFlag_Union_no_nil; }
-		if (type->Union.maybe)          { doc_type.flags |= OdinDocTypeFlag_Union_maybe; }
-
+		switch (type->Union.kind) {
+		case UnionType_maybe:      doc_type.flags |= OdinDocTypeFlag_Union_maybe;      break;
+		case UnionType_no_nil:     doc_type.flags |= OdinDocTypeFlag_Union_no_nil;     break;
+		case UnionType_shared_nil: doc_type.flags |= OdinDocTypeFlag_Union_shared_nil; break;
+		}
 		{
 			auto variants = array_make<OdinDocTypeIndex>(heap_allocator(), type->Union.variants.count);
 			defer (array_free(&variants));

+ 28 - 3
src/llvm_backend_general.cpp

@@ -1176,10 +1176,35 @@ void lb_emit_store_union_variant_tag(lbProcedure *p, lbValue parent, Type *varia
 }
 
 void lb_emit_store_union_variant(lbProcedure *p, lbValue parent, lbValue variant, Type *variant_type) {
-	lbValue underlying = lb_emit_conv(p, parent, alloc_type_pointer(variant_type));
+	Type *pt = base_type(type_deref(parent.type));
+	GB_ASSERT(pt->kind == Type_Union);
+	if (pt->Union.kind == UnionType_shared_nil) {
+		lbBlock *if_nil     = lb_create_block(p, "shared_nil.if_nil");
+		lbBlock *if_not_nil = lb_create_block(p, "shared_nil.if_not_nil");
+		lbBlock *done       = lb_create_block(p, "shared_nil.done");
+
+		lbValue cond_is_nil = lb_emit_comp_against_nil(p, Token_CmpEq, variant);
+		lb_emit_if(p, cond_is_nil, if_nil, if_not_nil);
+
+		lb_start_block(p, if_nil);
+		lb_emit_store(p, parent, lb_const_nil(p->module, type_deref(parent.type)));
+		lb_emit_jump(p, done);
+
+		lb_start_block(p, if_not_nil);
+		lbValue underlying = lb_emit_conv(p, parent, alloc_type_pointer(variant_type));
+		lb_emit_store(p, underlying, variant);
+		lb_emit_store_union_variant_tag(p, parent, variant_type);
+		lb_emit_jump(p, done);
+
+		lb_start_block(p, done);
 
-	lb_emit_store(p, underlying, variant);
-	lb_emit_store_union_variant_tag(p, parent, variant_type);
+
+	} else {
+		lbValue underlying = lb_emit_conv(p, parent, alloc_type_pointer(variant_type));
+
+		lb_emit_store(p, underlying, variant);
+		lb_emit_store_union_variant_tag(p, parent, variant_type);
+	}
 }
 
 

+ 4 - 3
src/llvm_backend_type.cpp

@@ -641,7 +641,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 			tag = lb_const_ptr_cast(m, variant_ptr, t_type_info_union_ptr);
 
 			{
-				LLVMValueRef vals[7] = {};
+				LLVMValueRef vals[8] = {};
 
 				isize variant_count = gb_max(0, t->Union.variants.count);
 				lbValue memory_types = lb_type_info_member_types_offset(p, variant_count);
@@ -675,8 +675,9 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da
 				}
 
 				vals[4] = lb_const_bool(m, t_bool, t->Union.custom_align != 0).value;
-				vals[5] = lb_const_bool(m, t_bool, t->Union.no_nil).value;
-				vals[6] = lb_const_bool(m, t_bool, t->Union.maybe).value;
+				vals[5] = lb_const_bool(m, t_bool, t->Union.kind == UnionType_no_nil).value;
+				vals[6] = lb_const_bool(m, t_bool, t->Union.kind == UnionType_maybe).value;
+				vals[7] = lb_const_bool(m, t_bool, t->Union.kind == UnionType_shared_nil).value;
 
 				for (isize i = 0; i < gb_count_of(vals); i++) {
 					if (vals[i] == nullptr) {

+ 26 - 4
src/parser.cpp

@@ -1071,15 +1071,14 @@ Ast *ast_struct_type(AstFile *f, Token token, Slice<Ast *> fields, isize field_c
 }
 
 
-Ast *ast_union_type(AstFile *f, Token token, Array<Ast *> const &variants, Ast *polymorphic_params, Ast *align, bool no_nil, bool maybe,
+Ast *ast_union_type(AstFile *f, Token token, Array<Ast *> const &variants, Ast *polymorphic_params, Ast *align, UnionTypeKind kind,
                     Token where_token, Array<Ast *> const &where_clauses) {
 	Ast *result = alloc_ast_node(f, Ast_UnionType);
 	result->UnionType.token              = token;
 	result->UnionType.variants           = slice_from_array(variants);
 	result->UnionType.polymorphic_params = polymorphic_params;
 	result->UnionType.align              = align;
-	result->UnionType.no_nil             = no_nil;
-	result->UnionType.maybe              = maybe;
+	result->UnionType.kind               = kind;
 	result->UnionType.where_token        = where_token;
 	result->UnionType.where_clauses      = slice_from_array(where_clauses);
 	return result;
@@ -2475,6 +2474,9 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 		Ast *align = nullptr;
 		bool no_nil = false;
 		bool maybe = false;
+		bool shared_nil = false;
+
+		UnionTypeKind union_kind = UnionType_Normal;
 
 		Token start_token = f->curr_token;
 
@@ -2501,6 +2503,11 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 					syntax_error(tag, "Duplicate union tag '#%.*s'", LIT(tag.string));
 				}
 				no_nil = true;
+			} else if (tag.string == "shared_nil") {
+				if (shared_nil) {
+					syntax_error(tag, "Duplicate union tag '#%.*s'", LIT(tag.string));
+				}
+				shared_nil = true;
 			} else if (tag.string == "maybe") {
 				if (maybe) {
 					syntax_error(tag, "Duplicate union tag '#%.*s'", LIT(tag.string));
@@ -2513,6 +2520,21 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 		if (no_nil && maybe) {
 			syntax_error(f->curr_token, "#maybe and #no_nil cannot be applied together");
 		}
+		if (no_nil && shared_nil) {
+			syntax_error(f->curr_token, "#shared_nil and #no_nil cannot be applied together");
+		}
+		if (shared_nil && maybe) {
+			syntax_error(f->curr_token, "#maybe and #shared_nil cannot be applied together");
+		}
+
+
+		if (maybe) {
+			union_kind = UnionType_maybe;
+		} else if (no_nil) {
+			union_kind = UnionType_no_nil;
+		} else if (shared_nil) {
+			union_kind = UnionType_shared_nil;
+		}
 
 		skip_possible_newline_for_literal(f);
 
@@ -2544,7 +2566,7 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 
 		Token close = expect_closing_brace_of_field_list(f);
 
-		return ast_union_type(f, token, variants, polymorphic_params, align, no_nil, maybe, where_token, where_clauses);
+		return ast_union_type(f, token, variants, polymorphic_params, align, union_kind, where_token, where_clauses);
 	} break;
 
 	case Token_enum: {

+ 8 - 2
src/parser.hpp

@@ -330,6 +330,13 @@ char const *inline_asm_dialect_strings[InlineAsmDialect_COUNT] = {
 	"intel",
 };
 
+enum UnionTypeKind : u8 {
+	UnionType_Normal     = 0,
+	UnionType_maybe      = 1,
+	UnionType_no_nil     = 2,
+	UnionType_shared_nil = 3,
+};
+
 #define AST_KINDS \
 	AST_KIND(Ident,          "identifier",      struct { \
 		Token   token;  \
@@ -678,8 +685,7 @@ AST_KIND(_TypeBegin, "", bool) \
 		Slice<Ast *> variants;      \
 		Ast *polymorphic_params;    \
 		Ast *        align;         \
-		bool         maybe;         \
-		bool         no_nil;        \
+		UnionTypeKind kind;       \
 		Token where_token;          \
 		Slice<Ast *> where_clauses; \
 	}) \

+ 12 - 10
src/types.cpp

@@ -165,9 +165,8 @@ struct TypeUnion {
 
 	i16           tag_size;
 	bool          is_polymorphic;
-	bool          is_poly_specialized : 1;
-	bool          no_nil              : 1;
-	bool          maybe               : 1;
+	bool          is_poly_specialized;
+	UnionTypeKind kind;
 };
 
 struct TypeProc {
@@ -1664,7 +1663,7 @@ bool is_type_map(Type *t) {
 
 bool is_type_union_maybe_pointer(Type *t) {
 	t = base_type(t);
-	if (t->kind == Type_Union && t->Union.maybe) {
+	if (t->kind == Type_Union && t->Union.kind == UnionType_maybe) {
 		if (t->Union.variants.count == 1) {
 			Type *v = t->Union.variants[0];
 			return is_type_pointer(v) || is_type_multi_pointer(v);
@@ -1676,7 +1675,7 @@ bool is_type_union_maybe_pointer(Type *t) {
 
 bool is_type_union_maybe_pointer_original_alignment(Type *t) {
 	t = base_type(t);
-	if (t->kind == Type_Union && t->Union.maybe) {
+	if (t->kind == Type_Union && t->Union.kind == UnionType_maybe) {
 		if (t->Union.variants.count == 1) {
 			Type *v = t->Union.variants[0];
 			if (is_type_pointer(v) || is_type_multi_pointer(v)) {
@@ -2168,7 +2167,7 @@ bool type_has_nil(Type *t) {
 	case Type_Map:
 		return true;
 	case Type_Union:
-		return !t->Union.no_nil;
+		return t->Union.kind != UnionType_no_nil;
 	case Type_Struct:
 		if (is_type_soa_struct(t)) {
 			switch (t->Struct.soa_kind) {
@@ -2454,7 +2453,7 @@ bool are_types_identical_internal(Type *x, Type *y, bool check_tuple_names) {
 		if (y->kind == Type_Union) {
 			if (x->Union.variants.count == y->Union.variants.count &&
 			    x->Union.custom_align == y->Union.custom_align &&
-			    x->Union.no_nil == y->Union.no_nil) {
+			    x->Union.kind == y->Union.kind) {
 				// NOTE(bill): zeroth variant is nullptr
 				for_array(i, x->Union.variants) {
 					if (!are_types_identical(x->Union.variants[i], y->Union.variants[i])) {
@@ -2598,7 +2597,7 @@ i64 union_variant_index(Type *u, Type *v) {
 	for_array(i, u->Union.variants) {
 		Type *vt = u->Union.variants[i];
 		if (are_types_identical(v, vt)) {
-			if (u->Union.no_nil) {
+			if (u->Union.kind == UnionType_no_nil) {
 				return cast(i64)(i+0);
 			} else {
 				return cast(i64)(i+1);
@@ -4021,8 +4020,11 @@ gbString write_type_to_string(gbString str, Type *type, bool shorthand=false) {
 
 	case Type_Union:
 		str = gb_string_appendc(str, "union");
-		if (type->Union.no_nil != 0) str = gb_string_appendc(str, " #no_nil");
-		if (type->Union.maybe != 0)  str = gb_string_appendc(str, " #maybe");
+		switch (type->Union.kind) {
+		case UnionType_maybe:      str = gb_string_appendc(str, " #maybe");      break;
+		case UnionType_no_nil:     str = gb_string_appendc(str, " #no_nil");     break;
+		case UnionType_shared_nil: str = gb_string_appendc(str, " #shared_nil"); break;
+		}
 		if (type->Union.custom_align != 0) str = gb_string_append_fmt(str, " #align %d", cast(int)type->Union.custom_align);
 		str = gb_string_appendc(str, " {");
 		for_array(i, type->Union.variants) {