Browse Source

Parametric polymorphic union type

gingerBill 7 years ago
parent
commit
3cd6ae311d
9 changed files with 318 additions and 59 deletions
  1. 7 1
      core/fmt/fmt.odin
  2. 1 0
      core/runtime/core.odin
  3. 18 0
      examples/demo/demo.odin
  4. 45 23
      src/check_expr.cpp
  5. 171 7
      src/check_type.cpp
  6. 7 3
      src/ir.cpp
  7. 23 10
      src/parser.cpp
  8. 4 3
      src/parser.hpp
  9. 42 12
      src/types.cpp

+ 7 - 1
core/fmt/fmt.odin

@@ -353,7 +353,13 @@ write_type :: proc(buf: ^String_Buffer, ti: ^runtime.Type_Info) {
 		write_byte(buf, '}');
 
 	case runtime.Type_Info_Union:
-		write_string(buf, "union {");
+		write_string(buf, "union ");
+		if info.custom_align {
+			write_string(buf, "#align ");
+			write_i64(buf, i64(ti.align), 10);
+			write_byte(buf, ' ');
+		}
+		write_byte(buf, '{');
 		for variant, i in info.variants {
 			if i > 0 do write_string(buf, ", ");
 			write_type(buf, variant);

+ 1 - 0
core/runtime/core.odin

@@ -82,6 +82,7 @@ Type_Info_Union :: struct {
 	variants:   []^Type_Info,
 	tag_offset: uintptr,
 	tag_type:   ^Type_Info,
+	custom_align: bool,
 };
 Type_Info_Enum :: struct {
 	base:      ^Type_Info,

+ 18 - 0
examples/demo/demo.odin

@@ -471,6 +471,24 @@ parametric_polymorphism :: proc() {
 		// and let the user specify the hashing function or make the user store
 		// the hashing procedure with the table
 	}
+
+	{ // Parametric polymorphic union
+		Error :: enum {
+			Foo0,
+			Foo1,
+			Foo2,
+			Foo3,
+		}
+		Para_Union :: union(T: typeid) {T, Error};
+		r: Para_Union(int);
+		fmt.println(typeid_of(type_of(r)));
+
+		fmt.println(r);
+		r = 123;
+		fmt.println(r);
+		r = Error.Foo0;
+		fmt.println(r);
+	}
 }
 
 

+ 45 - 23
src/check_expr.cpp

@@ -60,7 +60,7 @@ Type *   make_optional_ok_type          (Type *value);
 void     check_type_decl                (CheckerContext *c, Entity *e, Ast *type_expr, Type *def);
 Entity * check_selector                 (CheckerContext *c, Operand *operand, Ast *node, Type *type_hint);
 Entity * check_ident                    (CheckerContext *c, Operand *o, Ast *n, Type *named_type, Type *type_hint, bool allow_import_name);
-Entity * find_polymorphic_struct_entity (CheckerContext *c, Type *original_type, isize param_count, Array<Operand> ordered_operands);
+Entity * find_polymorphic_record_entity (CheckerContext *c, Type *original_type, isize param_count, Array<Operand> ordered_operands);
 void     check_not_tuple                (CheckerContext *c, Operand *operand);
 void     convert_to_typed               (CheckerContext *c, Operand *operand, Type *target_type);
 gbString expr_to_string                 (Ast *expression);
@@ -77,6 +77,9 @@ bool     check_representable_as_constant(CheckerContext *c, ExactValue in_value,
 bool     check_procedure_type           (CheckerContext *c, Type *type, Ast *proc_type_node, Array<Operand> *operands = nullptr);
 void     check_struct_type              (CheckerContext *c, Type *struct_type, Ast *node, Array<Operand> *poly_operands,
                                          Type *named_type = nullptr, Type *original_type_for_poly = nullptr);
+void     check_union_type               (CheckerContext *c, Type *union_type, Ast *node, Array<Operand> *poly_operands,
+                                         Type *named_type = nullptr, Type *original_type_for_poly = nullptr);
+
 CallArgumentData check_call_arguments   (CheckerContext *c, Operand *operand, Type *proc_type, Ast *call);
 Type *           check_init_variable    (CheckerContext *c, Entity *e, Operand *operand, String context_name);
 
@@ -4562,10 +4565,15 @@ CallArgumentData check_call_arguments(CheckerContext *c, Operand *operand, Type
 }
 
 
-isize lookup_polymorphic_struct_parameter(TypeStruct *st, String parameter_name) {
-	if (!st->is_polymorphic) return -1;
+isize lookup_polymorphic_record_parameter(Type *t, String parameter_name) {
+	if (!is_type_polymorphic_record(t)) {
+		return -1;
+	}
 
-	TypeTuple *params = &st->polymorphic_params->Tuple;
+	TypeTuple *params = get_record_polymorphic_params(t);
+	if (params == nullptr) {
+		return -1;
+	}
 	for_array(i, params->variables) {
 		Entity *e = params->variables[i];
 		String name = e->token.string;
@@ -4580,14 +4588,12 @@ isize lookup_polymorphic_struct_parameter(TypeStruct *st, String parameter_name)
 }
 
 
-CallArgumentError check_polymorphic_struct_type(CheckerContext *c, Operand *operand, Ast *call) {
+CallArgumentError check_polymorphic_record_type(CheckerContext *c, Operand *operand, Ast *call) {
 	ast_node(ce, CallExpr, call);
 
 	Type *original_type = operand->type;
 	Type *struct_type = base_type(operand->type);
-	GB_ASSERT(struct_type->kind == Type_Struct);
-	TypeStruct *st = &struct_type->Struct;
-	GB_ASSERT(st->is_polymorphic);
+	GB_ASSERT(is_type_polymorphic_record(original_type));
 
 	bool show_error = true;
 
@@ -4617,7 +4623,7 @@ CallArgumentError check_polymorphic_struct_type(CheckerContext *c, Operand *oper
 
 	CallArgumentError err = CallArgumentError_None;
 
-	TypeTuple *tuple = &st->polymorphic_params->Tuple;
+	TypeTuple *tuple = get_record_polymorphic_params(original_type);
 	isize param_count = tuple->variables.count;
 
 	Array<Operand> ordered_operands = operands;
@@ -4640,7 +4646,7 @@ CallArgumentError check_polymorphic_struct_type(CheckerContext *c, Operand *oper
 				continue;
 			}
 			String name = fv->field->Ident.token.string;
-			isize index = lookup_polymorphic_struct_parameter(st, name);
+			isize index = lookup_polymorphic_record_parameter(original_type, name);
 			if (index < 0) {
 				if (show_error) {
 					error(arg, "No parameter named '%.*s' for this polymorphic type", LIT(name));
@@ -4740,10 +4746,9 @@ CallArgumentError check_polymorphic_struct_type(CheckerContext *c, Operand *oper
 	}
 
 	{
-		// TODO(bill): Check for previous types
 		gbAllocator a = c->allocator;
 
-		Entity *found_entity = find_polymorphic_struct_entity(c, original_type, param_count, ordered_operands);
+		Entity *found_entity = find_polymorphic_record_entity(c, original_type, param_count, ordered_operands);
 		if (found_entity) {
 			operand->mode = Addressing_Type;
 			operand->type = found_entity->type;
@@ -4753,15 +4758,30 @@ CallArgumentError check_polymorphic_struct_type(CheckerContext *c, Operand *oper
 		String generated_name = make_string_c(expr_to_string(call));
 
 		Type *named_type = alloc_type_named(generated_name, nullptr, nullptr);
-		Ast *node = clone_ast(st->node);
-		Type *struct_type = alloc_type_struct();
-		struct_type->Struct.node = node;
-		struct_type->Struct.polymorphic_parent = original_type;
-		set_base_type(named_type, struct_type);
-
-		check_open_scope(c, node);
-		check_struct_type(c, struct_type, node, &ordered_operands, named_type, original_type);
-		check_close_scope(c);
+		Type *bt = base_type(original_type);
+		if (bt->kind == Type_Struct) {
+			Ast *node = clone_ast(bt->Struct.node);
+			Type *struct_type = alloc_type_struct();
+			struct_type->Struct.node = node;
+			struct_type->Struct.polymorphic_parent = original_type;
+			set_base_type(named_type, struct_type);
+
+			check_open_scope(c, node);
+			check_struct_type(c, struct_type, node, &ordered_operands, named_type, original_type);
+			check_close_scope(c);
+		} else if (bt->kind == Type_Union) {
+			Ast *node = clone_ast(bt->Union.node);
+			Type *union_type = alloc_type_union();
+			union_type->Union.node = node;
+			union_type->Union.polymorphic_parent = original_type;
+			set_base_type(named_type, union_type);
+
+			check_open_scope(c, node);
+			check_union_type(c, union_type, node, &ordered_operands, named_type, original_type);
+			check_close_scope(c);
+		} else {
+			GB_PANIC("Unsupported parametric polymorphic record type");
+		}
 
 		operand->mode = Addressing_Type;
 		operand->type = named_type;
@@ -4770,6 +4790,8 @@ CallArgumentError check_polymorphic_struct_type(CheckerContext *c, Operand *oper
 }
 
 
+
+
 ExprKind check_call_expr(CheckerContext *c, Operand *operand, Ast *call) {
 	ast_node(ce, CallExpr, call);
 	if (ce->proc != nullptr &&
@@ -4828,8 +4850,8 @@ ExprKind check_call_expr(CheckerContext *c, Operand *operand, Ast *call) {
 
 	if (operand->mode == Addressing_Type) {
 		Type *t = operand->type;
-		if (is_type_polymorphic_struct(t)) {
-			auto err = check_polymorphic_struct_type(c, operand, call);
+		if (is_type_polymorphic_record(t)) {
+			auto err = check_polymorphic_record_type(c, operand, call);
 			if (err == 0) {
 				Ast *ident = operand->expr;
 				while (ident->kind == Ast_SelectorExpr) {

+ 171 - 7
src/check_type.cpp

@@ -241,13 +241,13 @@ bool check_custom_align(CheckerContext *ctx, Ast *node, i64 *align_) {
 }
 
 
-Entity *find_polymorphic_struct_entity(CheckerContext *ctx, Type *original_type, isize param_count, Array<Operand> ordered_operands) {
+Entity *find_polymorphic_record_entity(CheckerContext *ctx, Type *original_type, isize param_count, Array<Operand> ordered_operands) {
 	auto *found_gen_types = map_get(&ctx->checker->info.gen_types, hash_pointer(original_type));
 	if (found_gen_types != nullptr) {
 		for_array(i, *found_gen_types) {
 			Entity *e = (*found_gen_types)[i];
 			Type *t = base_type(e->type);
-			TypeTuple *tuple = &t->Struct.polymorphic_params->Tuple;
+			TypeTuple *tuple = get_record_polymorphic_params(t);
 			bool ok = true;
 			GB_ASSERT(param_count == tuple->variables.count);
 			for (isize j = 0; j < param_count; j++) {
@@ -281,7 +281,7 @@ Entity *find_polymorphic_struct_entity(CheckerContext *ctx, Type *original_type,
 }
 
 
-void add_polymorphic_struct_entity(CheckerContext *ctx, Ast *node, Type *named_type, Type *original_type) {
+void add_polymorphic_record_entity(CheckerContext *ctx, Ast *node, Type *named_type, Type *original_type) {
 	GB_ASSERT(is_type_named(named_type));
 	gbAllocator a = heap_allocator();
 	Scope *s = ctx->scope->parent;
@@ -472,7 +472,7 @@ void check_struct_type(CheckerContext *ctx, Type *struct_type, Ast *node, Array<
 
 		if (original_type_for_poly != nullptr) {
 			GB_ASSERT(named_type != nullptr);
-			add_polymorphic_struct_entity(ctx, node, named_type, original_type_for_poly);
+			add_polymorphic_record_entity(ctx, node, named_type, original_type_for_poly);
 		}
 	}
 
@@ -517,7 +517,7 @@ void check_struct_type(CheckerContext *ctx, Type *struct_type, Ast *node, Array<
 		}
 	}
 }
-void check_union_type(CheckerContext *ctx, Type *union_type, Ast *node) {
+void check_union_type(CheckerContext *ctx, Type *union_type, Ast *node, Array<Operand> *poly_operands, Type *named_type, Type *original_type_for_poly) {
 	GB_ASSERT(is_type_union(union_type));
 	ast_node(ut, UnionType, node);
 
@@ -529,6 +529,170 @@ void check_union_type(CheckerContext *ctx, Type *union_type, Ast *node) {
 
 	union_type->Union.scope = ctx->scope;
 
+	Type *polymorphic_params     = nullptr;
+	bool is_polymorphic          = false;
+	bool can_check_fields        = true;
+	bool is_poly_specialized     = false;
+
+	if (ut->polymorphic_params != nullptr) {
+		ast_node(field_list, FieldList, ut->polymorphic_params);
+		Array<Ast *> params = field_list->list;
+		if (params.count != 0) {
+			isize variable_count = 0;
+			for_array(i, params) {
+				Ast *field = params[i];
+				if (ast_node_expect(field, Ast_Field)) {
+					ast_node(f, Field, field);
+					variable_count += gb_max(f->names.count, 1);
+				}
+			}
+
+			auto entities = array_make<Entity *>(ctx->allocator, 0, variable_count);
+
+			for_array(i, params) {
+				Ast *param = params[i];
+				if (param->kind != Ast_Field) {
+					continue;
+				}
+				ast_node(p, Field, param);
+				Ast *type_expr = p->type;
+				Type *type = nullptr;
+				bool is_type_param = false;
+				bool is_type_polymorphic_type = false;
+				if (type_expr == nullptr) {
+					error(param, "Expected a type for this parameter");
+					continue;
+				}
+				if (type_expr->kind == Ast_Ellipsis) {
+					type_expr = type_expr->Ellipsis.expr;
+					error(param, "A polymorphic parameter cannot be variadic");
+				}
+				if (type_expr->kind == Ast_TypeidType) {
+					is_type_param = true;
+					Type *specialization = nullptr;
+					if (type_expr->TypeidType.specialization != nullptr) {
+						Ast *s = type_expr->TypeidType.specialization;
+						specialization = check_type(ctx, s);
+					}
+					type = alloc_type_generic(ctx->scope, 0, str_lit(""), specialization);
+				} else if (type_expr->kind == Ast_TypeType) {
+					is_type_param = true;
+					Type *specialization = nullptr;
+					if (type_expr->TypeType.specialization != nullptr) {
+						Ast *s = type_expr->TypeType.specialization;
+						specialization = check_type(ctx, s);
+					}
+					type = alloc_type_generic(ctx->scope, 0, str_lit(""), specialization);
+				} else {
+					type = check_type(ctx, type_expr);
+					if (is_type_polymorphic(type)) {
+						is_type_polymorphic_type = true;
+					}
+				}
+
+				if (type == nullptr) {
+					error(params[i], "Invalid parameter type");
+					type = t_invalid;
+				}
+				if (is_type_untyped(type)) {
+					if (is_type_untyped_undef(type)) {
+						error(params[i], "Cannot determine parameter type from ---");
+					} else {
+						error(params[i], "Cannot determine parameter type from a nil");
+					}
+					type = t_invalid;
+				}
+
+				if (is_type_polymorphic_type) {
+					gbString str = type_to_string(type);
+					error(params[i], "Parameter types cannot be polymorphic, got %s", str);
+					gb_string_free(str);
+					type = t_invalid;
+				}
+
+				if (!is_type_param && !is_type_constant_type(type)) {
+					gbString str = type_to_string(type);
+					error(params[i], "A parameter must be a valid constant type, got %s", str);
+					gb_string_free(str);
+				}
+
+				Scope *scope = ctx->scope;
+				for_array(j, p->names) {
+					Ast *name = p->names[j];
+					if (!ast_node_expect(name, Ast_Ident)) {
+						continue;
+					}
+					Entity *e = nullptr;
+
+					Token token = name->Ident.token;
+
+					if (poly_operands != nullptr) {
+						Operand operand = (*poly_operands)[entities.count];
+						if (is_type_param) {
+							GB_ASSERT(operand.mode == Addressing_Type ||
+							          operand.mode == Addressing_Invalid);
+							if (is_type_polymorphic(base_type(operand.type))) {
+								is_polymorphic = true;
+								can_check_fields = false;
+							}
+							e = alloc_entity_type_name(scope, token, operand.type);
+							e->TypeName.is_type_alias = true;
+						} else {
+							GB_ASSERT(operand.mode == Addressing_Constant);
+							e = alloc_entity_constant(scope, token, operand.type, operand.value);
+						}
+					} else {
+						if (is_type_param) {
+							e = alloc_entity_type_name(scope, token, type);
+							e->TypeName.is_type_alias = true;
+						} else {
+							e = alloc_entity_constant(scope, token, type, empty_exact_value);
+						}
+					}
+
+					e->state = EntityState_Resolved;
+					add_entity(ctx->checker, scope, name, e);
+					array_add(&entities, e);
+				}
+			}
+
+			if (entities.count > 0) {
+				Type *tuple = alloc_type_tuple();
+				tuple->Tuple.variables = entities;
+				polymorphic_params = tuple;
+			}
+		}
+
+		if (original_type_for_poly != nullptr) {
+			GB_ASSERT(named_type != nullptr);
+			add_polymorphic_record_entity(ctx, node, named_type, original_type_for_poly);
+		}
+	}
+
+	if (!is_polymorphic) {
+		is_polymorphic = polymorphic_params != nullptr && poly_operands == nullptr;
+	}
+	if (poly_operands != nullptr) {
+		is_poly_specialized = true;
+		for (isize i = 0; i < poly_operands->count; i++) {
+			Operand o = (*poly_operands)[i];
+			if (is_type_polymorphic(o.type)) {
+				is_poly_specialized = false;
+				break;
+			}
+			if (union_type == o.type) {
+				// NOTE(bill): Cycle
+				is_poly_specialized = false;
+				break;
+			}
+		}
+	}
+
+	union_type->Union.scope                   = ctx->scope;
+	union_type->Union.polymorphic_params      = polymorphic_params;
+	union_type->Union.is_polymorphic          = is_polymorphic;
+	union_type->Union.is_poly_specialized     = is_poly_specialized;
+
 	for_array(i, ut->variants) {
 		Ast *node = ut->variants[i];
 		Type *t = check_type_expr(ctx, node, nullptr);
@@ -2021,7 +2185,7 @@ bool check_type_internal(CheckerContext *ctx, Ast *e, Type **type, Type *named_t
 			*type = o.type;
 			if (!ctx->in_polymorphic_specialization) {
 				Type *t = base_type(o.type);
-				if (t != nullptr && is_type_polymorphic_struct_unspecialized(t)) {
+				if (t != nullptr && is_type_polymorphic_record_unspecialized(t)) {
 					err_str = expr_to_string(e);
 					error(e, "Invalid use of a non-specialized polymorphic type '%s'", err_str);
 					return true;
@@ -2209,7 +2373,7 @@ bool check_type_internal(CheckerContext *ctx, Ast *e, Type **type, Type *named_t
 		*type = alloc_type_union();
 		set_base_type(named_type, *type);
 		check_open_scope(&c, e);
-		check_union_type(&c, *type, e);
+		check_union_type(&c, *type, e, nullptr, named_type);
 		check_close_scope(&c);
 		(*type)->Union.node = e;
 		return true;

+ 7 - 3
src/ir.cpp

@@ -8104,9 +8104,10 @@ void ir_setup_type_info_data(irProcedure *proc) { // NOTE(bill): Setup type_info
 			tag = ir_emit_conv(proc, variant_ptr, t_type_info_union_ptr);
 
 			{
-				irValue *variant_types  = ir_emit_struct_ep(proc, tag, 0);
-				irValue *tag_offset_ptr = ir_emit_struct_ep(proc, tag, 1);
-				irValue *tag_type_ptr   = ir_emit_struct_ep(proc, tag, 2);
+				irValue *variant_types    = ir_emit_struct_ep(proc, tag, 0);
+				irValue *tag_offset_ptr   = ir_emit_struct_ep(proc, tag, 1);
+				irValue *tag_type_ptr     = ir_emit_struct_ep(proc, tag, 2);
+				irValue *custom_align_ptr = ir_emit_struct_ep(proc, tag, 3);
 
 				isize variant_count = gb_max(0, t->Union.variants.count);
 				irValue *memory_types = ir_type_info_member_types_offset(proc, variant_count);
@@ -8131,6 +8132,9 @@ void ir_setup_type_info_data(irProcedure *proc) { // NOTE(bill): Setup type_info
 					ir_emit_store(proc, tag_offset_ptr, ir_const_uintptr(tag_offset));
 					ir_emit_store(proc, tag_type_ptr,   ir_type_info(proc, union_tag_type(t)));
 				}
+
+				irValue *is_custom_align = ir_const_bool(t->Union.custom_align != 0);
+				ir_emit_store(proc, custom_align_ptr, is_custom_align);
 			}
 
 			break;

+ 23 - 10
src/parser.cpp

@@ -341,6 +341,7 @@ Ast *clone_ast(Ast *node) {
 		break;
 	case Ast_UnionType:
 		n->UnionType.variants = clone_ast_array(n->UnionType.variants);
+		n->UnionType.polymorphic_params = clone_ast(n->UnionType.polymorphic_params);
 		break;
 	case Ast_EnumType:
 		n->EnumType.base_type = clone_ast(n->EnumType.base_type);
@@ -900,8 +901,8 @@ Ast *ast_dynamic_array_type(AstFile *f, Token token, Ast *elem) {
 }
 
 Ast *ast_struct_type(AstFile *f, Token token, Array<Ast *> fields, isize field_count,
-                         Ast *polymorphic_params, bool is_packed, bool is_raw_union,
-                         Ast *align) {
+                     Ast *polymorphic_params, bool is_packed, bool is_raw_union,
+                     Ast *align) {
 	Ast *result = alloc_ast_node(f, Ast_StructType);
 	result->StructType.token              = token;
 	result->StructType.fields             = fields;
@@ -914,11 +915,12 @@ Ast *ast_struct_type(AstFile *f, Token token, Array<Ast *> fields, isize field_c
 }
 
 
-Ast *ast_union_type(AstFile *f, Token token, Array<Ast *> variants, Ast *align) {
+Ast *ast_union_type(AstFile *f, Token token, Array<Ast *> variants, Ast *polymorphic_params, Ast *align) {
 	Ast *result = alloc_ast_node(f, Ast_UnionType);
-	result->UnionType.token        = token;
-	result->UnionType.variants     = variants;
-	result->UnionType.align = align;
+	result->UnionType.token              = token;
+	result->UnionType.variants           = variants;
+	result->UnionType.polymorphic_params = polymorphic_params;
+	result->UnionType.align              = align;
 	return result;
 }
 
@@ -1827,8 +1829,8 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 	case Token_struct: {
 		Token    token = expect_token(f, Token_struct);
 		Ast *polymorphic_params = nullptr;
-		bool     is_packed          = false;
-		bool     is_raw_union       = false;
+		bool is_packed          = false;
+		bool is_raw_union       = false;
 		Ast *align              = nullptr;
 
 		if (allow_token(f, Token_OpenParen)) {
@@ -1890,13 +1892,23 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 
 	case Token_union: {
 		Token token = expect_token(f, Token_union);
-		Token open = expect_token_after(f, Token_OpenBrace, "union");
 		auto variants = array_make<Ast *>(heap_allocator());
+		Ast *polymorphic_params = nullptr;
 		Ast *align = nullptr;
 
 		CommentGroup *docs = f->lead_comment;
 		Token start_token = f->curr_token;
 
+		if (allow_token(f, Token_OpenParen)) {
+			isize param_count = 0;
+			polymorphic_params = parse_field_list(f, &param_count, 0, Token_CloseParen, false, true);
+			if (param_count == 0) {
+				syntax_error(polymorphic_params, "Expected at least 1 polymorphic parametric");
+				polymorphic_params = nullptr;
+			}
+			expect_token_after(f, Token_CloseParen, "parameter list");
+		}
+
 		while (allow_token(f, Token_Hash)) {
 			Token tag = expect_token_after(f, Token_Ident, "#");
 			 if (tag.string == "align") {
@@ -1909,6 +1921,7 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 			}
 		}
 
+		Token open = expect_token_after(f, Token_OpenBrace, "union");
 
 		while (f->curr_token.kind != Token_CloseBrace &&
 		       f->curr_token.kind != Token_EOF) {
@@ -1923,7 +1936,7 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 
 		Token close = expect_token(f, Token_CloseBrace);
 
-		return ast_union_type(f, token, variants, align);
+		return ast_union_type(f, token, variants, polymorphic_params, align);
 	} break;
 
 	case Token_enum: {

+ 4 - 3
src/parser.hpp

@@ -470,9 +470,10 @@ AST_KIND(_TypeBegin, "", bool) \
 		bool is_raw_union;        \
 	}) \
 	AST_KIND(UnionType, "union type", struct { \
-		Token        token;    \
-		Array<Ast *> variants; \
-		Ast *        align;    \
+		Token        token;      \
+		Array<Ast *> variants;   \
+		Ast *polymorphic_params; \
+		Ast *        align;      \
 	}) \
 	AST_KIND(EnumType, "enum type", struct { \
 		Token        token; \

+ 42 - 12
src/types.cpp

@@ -99,6 +99,20 @@ struct TypeStruct {
 	Entity * names;
 };
 
+struct TypeUnion {
+	Array<Type *> variants;
+	Ast *         node;
+	Scope *       scope;
+	i64           variant_block_size;
+	i64           custom_align;
+	i64           tag_size;
+
+	bool       is_polymorphic;
+	bool       is_poly_specialized;
+	Type *     polymorphic_params; // Type_Tuple
+	Type *     polymorphic_parent;
+};
+
 #define TYPE_KINDS                                        \
 	TYPE_KIND(Basic, BasicType)                           \
 	TYPE_KIND(Named, struct {                             \
@@ -129,6 +143,7 @@ struct TypeStruct {
 		Type *lookup_result_type;                         \
 	})                                                    \
 	TYPE_KIND(Struct,  TypeStruct)                        \
+	TYPE_KIND(Union,   TypeUnion)                         \
 	TYPE_KIND(Enum, struct {                              \
 		Array<Entity *> fields;                           \
 		Ast *node;                                        \
@@ -136,14 +151,6 @@ struct TypeStruct {
 		Entity * names;                                   \
 		Type *   base_type;                               \
 	})                                                    \
-	TYPE_KIND(Union, struct {                             \
-		Array<Type *> variants;                           \
-		Ast *node;                                    \
-		Scope *  scope;                                   \
-		i64      variant_block_size;                      \
-		i64      custom_align;                            \
-		i64      tag_size;                                \
-	})                                                    \
 	TYPE_KIND(Tuple, struct {                             \
 		Array<Entity *> variables; /* Entity_Variable */  \
 		Array<i64>      offsets;                          \
@@ -1026,30 +1033,53 @@ bool is_type_indexable(Type *t) {
 	return false;
 }
 
-bool is_type_polymorphic_struct(Type *t) {
+bool is_type_polymorphic_record(Type *t) {
 	t = base_type(t);
 	if (t->kind == Type_Struct) {
 		return t->Struct.is_polymorphic;
+	} else if (t->kind == Type_Union) {
+		return t->Union.is_polymorphic;
 	}
 	return false;
 }
 
-bool is_type_polymorphic_struct_specialized(Type *t) {
+bool is_type_polymorphic_record_specialized(Type *t) {
 	t = base_type(t);
 	if (t->kind == Type_Struct) {
 		return t->Struct.is_polymorphic && t->Struct.is_poly_specialized;
+	} else if (t->kind == Type_Union) {
+		return t->Union.is_polymorphic && t->Union.is_poly_specialized;
 	}
 	return false;
 }
 
-bool is_type_polymorphic_struct_unspecialized(Type *t) {
+bool is_type_polymorphic_record_unspecialized(Type *t) {
 	t = base_type(t);
 	if (t->kind == Type_Struct) {
 		return t->Struct.is_polymorphic && !t->Struct.is_poly_specialized;
+	} else if (t->kind == Type_Struct) {
+		return t->Struct.is_polymorphic && !t->Struct.is_poly_specialized;
 	}
 	return false;
 }
 
+TypeTuple *get_record_polymorphic_params(Type *t) {
+	t = base_type(t);
+	switch (t->kind) {
+	case Type_Struct:
+		if (t->Struct.polymorphic_params) {
+			return &t->Struct.polymorphic_params->Tuple;
+		}
+		break;
+	case Type_Union:
+		if (t->Union.polymorphic_params) {
+			return &t->Union.polymorphic_params->Tuple;
+		}
+		break;
+	}
+	return nullptr;
+}
+
 
 bool is_type_polymorphic(Type *t) {
 	switch (t->kind) {
@@ -1057,7 +1087,7 @@ bool is_type_polymorphic(Type *t) {
 		return true;
 
 	case Type_Named:
-		return is_type_polymorphic_struct(t->Named.base);
+		return is_type_polymorphic_record(t->Named.base);
 
 	case Type_Pointer:
 		return is_type_polymorphic(t->Pointer.elem);