Browse Source

Basic procedure type parameter specialization

Ginger Bill 8 years ago
parent
commit
054948e701
4 changed files with 204 additions and 26 deletions
  1. 138 16
      src/check_expr.cpp
  2. 23 4
      src/ir.cpp
  3. 17 4
      src/parser.cpp
  4. 26 2
      src/types.cpp

+ 138 - 16
src/check_expr.cpp

@@ -990,6 +990,9 @@ void check_struct_type(Checker *c, Type *struct_type, AstNode *node, Array<Opera
 
 	Type *polymorphic_params = nullptr;
 	bool is_polymorphic = false;
+	bool can_check_fields = true;
+	bool is_poly_specialized = false;
+
 	if (st->polymorphic_params != nullptr) {
 		ast_node(field_list, FieldList, st->polymorphic_params);
 		Array<AstNode *> params = field_list->list;
@@ -1026,7 +1029,18 @@ void check_struct_type(Checker *c, Type *struct_type, AstNode *node, Array<Opera
 				}
 				if (type_expr->kind == AstNode_TypeType) {
 					is_type_param = true;
-					type = make_type_generic(c->allocator, 0, str_lit(""));
+					Type *specialization = nullptr;
+					if (type_expr->TypeType.specialization != nullptr) {
+						AstNode *s = type_expr->TypeType.specialization;
+						specialization = check_type(c, s);
+						if (!is_type_polymorphic_struct(specialization)) {
+							gbString str = type_to_string(specialization);
+							defer (gb_string_free(str));
+							error(s, "Expected a polymorphic record, got %s", str);
+							specialization = nullptr;
+						}
+					}
+					type = make_type_generic(c->allocator, 0, str_lit(""), specialization);
 				} else {
 					type = check_type(c, type_expr);
 					if (is_type_polymorphic(type)) {
@@ -1074,6 +1088,10 @@ void check_struct_type(Checker *c, Type *struct_type, AstNode *node, Array<Opera
 						Operand operand = (*poly_operands)[entities.count];
 						if (is_type_param) {
 							GB_ASSERT(operand.mode == Addressing_Type);
+							if (is_type_polymorphic(base_type(operand.type))) {
+								is_polymorphic = true;
+								can_check_fields = false;
+							}
 							e = make_entity_type_name(c->allocator, scope, token, operand.type);
 							e->TypeName.is_type_alias = true;
 						} else {
@@ -1104,7 +1122,19 @@ void check_struct_type(Checker *c, Type *struct_type, AstNode *node, Array<Opera
 		}
 	}
 
-	is_polymorphic = polymorphic_params != nullptr && poly_operands == nullptr;
+	if (!is_polymorphic) {
+		is_polymorphic = polymorphic_params != nullptr && poly_operands == nullptr;
+	}
+	if (poly_operands != nullptr) {
+		is_poly_specialized = true;
+		for (isize i = 0; i < poly_operands->count; i++) {
+			Operand o = (*poly_operands)[i];
+			if (is_type_polymorphic(o.type)) {
+				is_poly_specialized = false;
+				break;
+			}
+		}
+	}
 
 	Array<Entity *> fields = {};
 
@@ -1120,6 +1150,7 @@ void check_struct_type(Checker *c, Type *struct_type, AstNode *node, Array<Opera
 	struct_type->Record.field_count         = fields.count;
 	struct_type->Record.polymorphic_params  = polymorphic_params;
 	struct_type->Record.is_polymorphic      = is_polymorphic;
+	struct_type->Record.is_poly_specialized = is_poly_specialized;
 
 
 	type_set_offsets(c->allocator, struct_type);
@@ -1514,6 +1545,21 @@ void check_bit_field_type(Checker *c, Type *bit_field_type, Type *named_type, As
 	}
 }
 
+bool is_polymorphic_type_assignable_to_specific(Checker *c, Type *source, Type *specific) {
+	if (!is_type_struct(specific)) {
+		return false;
+	}
+
+	if (!is_type_struct(source)) {
+		return false;
+	}
+
+	source = base_type(source);
+	GB_ASSERT(source->kind == Type_Record && source->Record.kind == TypeRecord_Struct);
+
+	return are_types_identical(source->Record.polymorphic_parent, specific);
+}
+
 bool is_polymorphic_type_assignable(Checker *c, Type *poly, Type *source, bool compound, bool modify_type) {
 	Operand o = {Addressing_Value};
 	o.type = source;
@@ -1527,6 +1573,12 @@ bool is_polymorphic_type_assignable(Checker *c, Type *poly, Type *source, bool c
 		return check_is_assignable_to(c, &o, poly);
 
 	case Type_Generic: {
+		if (poly->Generic.specific != nullptr) {
+			Type *s = poly->Generic.specific;
+			if (!is_polymorphic_type_assignable_to_specific(c, source, s)) {
+				return false;
+			}
+		}
 		if (modify_type) {
 			Type *ds = default_type(source);
 			gb_memmove(poly, ds, gb_size_of(Type));
@@ -1661,14 +1713,39 @@ Type *determine_type_from_polymorphic(Checker *c, Type *poly_type, Operand opera
 		gbString ots = type_to_string(operand.type);
 		defer (gb_string_free(pts));
 		defer (gb_string_free(ots));
-		error(operand.expr,
-		      "Cannot determine polymorphic type from parameter: `%s` to `%s`\n"
-		      "\tNote: Record and procedure types are not yet supported",
-		      ots, pts);
+		error(operand.expr, "Cannot determine polymorphic type from parameter: `%s` to `%s`", ots, pts);
 	}
 	return t_invalid;
 }
 
+bool check_type_specialization_to(Checker *c, Type *type, Type *specialization) {
+	if (type == nullptr ||
+	    type == t_invalid) {
+		return true;
+	}
+
+	Type *t = base_type(type);
+	Type *s = base_type(specialization);
+	if (t->kind != s->kind) {
+		return false;
+	}
+	// gb_printf_err("#1 %s %s\n", type_to_string(type), type_to_string(specialization));
+	if (t->kind != Type_Record) {
+		return false;
+	}
+	// gb_printf_err("#2 %s %s\n", type_to_string(type), type_to_string(specialization));
+	if (t->Record.polymorphic_parent == specialization) {
+		return true;
+	}
+	// gb_printf_err("#3 %s %s\n", type_to_string(t->Record.polymorphic_parent), type_to_string(specialization));
+	if (t->Record.polymorphic_parent == s->Record.polymorphic_parent) {
+		return true;
+	}
+
+
+	return false;
+}
+
 Type *check_get_params(Checker *c, Scope *scope, AstNode *_params, bool *is_variadic_, bool *success_, Array<Operand> *operands) {
 	if (_params == nullptr) {
 		return nullptr;
@@ -1720,6 +1797,7 @@ Type *check_get_params(Checker *c, Scope *scope, AstNode *_params, bool *is_vari
 		bool is_type_param = false;
 		bool is_type_polymorphic_type = false;
 		bool detemine_type_from_operand = false;
+		Type *specialization = nullptr;
 
 
 		if (type_expr == nullptr) {
@@ -1752,12 +1830,25 @@ Type *check_get_params(Checker *c, Scope *scope, AstNode *_params, bool *is_vari
 				}
 			}
 			if (type_expr->kind == AstNode_TypeType) {
+				ast_node(tt, TypeType, type_expr);
 				is_type_param = true;
+				specialization = check_type(c, tt->specialization);
+				if (specialization == t_invalid){
+					specialization = nullptr;
+				}
+				if (specialization) {
+					if (!is_type_polymorphic(specialization)) {
+						gbString str = type_to_string(specialization);
+						error(tt->specialization, "Type specialization requires a polymorphic type, got %s", str);
+						gb_string_free(str);
+					}
+				}
+
 				if (operands != nullptr) {
 					detemine_type_from_operand = true;
 					type = t_invalid;
 				} else {
-					type = make_type_generic(c->allocator, 0, str_lit(""));
+					type = make_type_generic(c->allocator, 0, str_lit(""), specialization);
 				}
 			} else {
 				bool prev = c->context.allow_polymorphic_types;
@@ -1845,6 +1936,18 @@ Type *check_get_params(Checker *c, Scope *scope, AstNode *_params, bool *is_vari
 						success = false;
 						type = t_invalid;
 					}
+					if (is_type_polymorphic_struct(type)) {
+						error(o.expr, "Cannot pass polymorphic struct as a parameter");
+						type = t_invalid;
+					}
+					if (specialization != nullptr && !check_type_specialization_to(c, type, specialization)) {
+						gbString t = type_to_string(type);
+						gbString s = type_to_string(specialization);
+						error(o.expr, "Cannot convert type `%s` to the specialization `%s`", t, s);
+						gb_string_free(s);
+						gb_string_free(t);
+						type = t_invalid;
+					}
 				}
 				param = make_entity_type_name(c->allocator, scope, name->Ident.token, type);
 				param->TypeName.is_type_alias = true;
@@ -2654,7 +2757,18 @@ bool check_type_internal(Checker *c, AstNode *e, Type **type, Type *named_type)
 		}
 
 		Token token = ident->Ident.token;
-		Type *t = make_type_generic(c->allocator, 0, token.string);
+		Type *specific = nullptr;
+		if (pt->specialization != nullptr) {
+			AstNode *s = pt->specialization;
+			specific = check_type(c, s);
+			if (!is_type_polymorphic_struct(specific)) {
+				gbString str = type_to_string(specific);
+				error(s, "Expected a polymorphic record, got %s", str);
+				gb_string_free(str);
+				specific = nullptr;
+			}
+		}
+		Type *t = make_type_generic(c->allocator, 0, token.string, specific);
 		if (c->context.allow_polymorphic_types) {
 			Scope *ps = c->context.polymorphic_scope;
 			Scope *s = c->context.scope;
@@ -6259,6 +6373,7 @@ isize lookup_polymorphic_struct_parameter(TypeRecord *st, String parameter_name)
 	return -1;
 }
 
+
 CallArgumentError check_polymorphic_struct_type(Checker *c, Operand *operand, AstNode *call) {
 	ast_node(ce, CallExpr, call);
 
@@ -6413,7 +6528,11 @@ CallArgumentError check_polymorphic_struct_type(Checker *c, Operand *operand, As
 		err = CallArgumentError_TooFewArguments;
 	}
 
-	if (err == 0) {
+	if (err != 0) {
+		return err;
+	}
+
+	{
 		// TODO(bill): Check for previous types
 		gbAllocator a = c->allocator;
 
@@ -6462,6 +6581,7 @@ CallArgumentError check_polymorphic_struct_type(Checker *c, Operand *operand, As
 		check_struct_type(c, struct_type, node, &ordered_operands);
 		check_close_scope(c);
 		struct_type->Record.node = node;
+		struct_type->Record.polymorphic_parent = original_type;
 
 		Entity *e = nullptr;
 
@@ -6481,13 +6601,15 @@ CallArgumentError check_polymorphic_struct_type(Checker *c, Operand *operand, As
 
 		named_type->Named.type_name = e;
 
-		if (found_gen_types) {
-			array_add(found_gen_types, e);
-		} else {
-			Array<Entity *> array = {};
-			array_init(&array, heap_allocator());
-			array_add(&array, e);
-			map_set(&c->info.gen_types, hash_pointer(original_type), array);
+		if (!struct_type->Record.is_polymorphic) {
+			if (found_gen_types) {
+				array_add(found_gen_types, e);
+			} else {
+				Array<Entity *> array = {};
+				array_init(&array, heap_allocator());
+				array_add(&array, e);
+				map_set(&c->info.gen_types, hash_pointer(original_type), array);
+			}
 		}
 
 		operand->mode = Addressing_Type;

+ 23 - 4
src/ir.cpp

@@ -3670,7 +3670,18 @@ void ir_pop_target_list(irProcedure *proc) {
 void ir_gen_global_type_name(irModule *m, Entity *e, String name) {
 	if (e->type == nullptr) return;
 
-	if (is_type_polymorphic(base_type(e->type))) {
+
+	Type *bt = base_type(e->type);
+
+	bool is_poly = is_type_polymorphic(bt);
+	if (!is_poly) {
+		if (bt->kind == Type_Record &&
+		    bt->Record.is_polymorphic &&
+		    !bt->Record.is_poly_specialized) {
+			is_poly = true;
+		}
+	}
+	if (is_poly) {
 		auto found = map_get(&m->info->gen_types, hash_pointer(e->type));
 		if (found == nullptr) {
 			return;
@@ -3696,7 +3707,6 @@ void ir_gen_global_type_name(irModule *m, Entity *e, String name) {
 	}
 	#endif
 
-	Type *bt = base_type(e->type);
 	if (bt->kind == Type_Record) {
 		Scope *s = bt->Record.scope;
 		if (s != nullptr) {
@@ -7503,14 +7513,23 @@ void ir_gen_tree(irGen *s) {
 			continue;
 		}
 
-		if (map_get(&m->min_dep_map, hash_entity(e)) == nullptr) {
+
+		bool polymorphic_struct = false;
+		if (e->type != nullptr && e->kind == Entity_TypeName) {
+			Type *bt = base_type(e->type);
+			if (bt->kind == Type_Record) {
+				polymorphic_struct = bt->Record.is_polymorphic;
+			}
+		}
+
+		if (!polymorphic_struct && map_get(&m->min_dep_map, hash_entity(e)) == nullptr) {
 			// NOTE(bill): Nothing depends upon it so doesn't need to be built
 			continue;
 		}
 
 		String original_name = name;
 
-		if (!scope->is_global || is_type_polymorphic(e->type)) {
+		if (!scope->is_global || polymorphic_struct || is_type_polymorphic(e->type)) {
 			if (e->kind == Entity_Procedure && (e->Procedure.tags & ProcTag_export) != 0) {
 			} else if (e->kind == Entity_Procedure && e->Procedure.link_name.len > 0) {
 				// Handle later

+ 17 - 4
src/parser.cpp

@@ -379,6 +379,7 @@ AST_NODE_KIND(_DeclEnd,   "", i32) \
 AST_NODE_KIND(_TypeBegin, "", i32) \
 	AST_NODE_KIND(TypeType, "type", struct { \
 		Token token; \
+		AstNode *specialization; \
 	}) \
 	AST_NODE_KIND(HelperType, "helper type", struct { \
 		Token token; \
@@ -387,6 +388,7 @@ AST_NODE_KIND(_TypeBegin, "", i32) \
 	AST_NODE_KIND(PolyType, "polymorphic type", struct { \
 		Token    token; \
 		AstNode *type;  \
+		AstNode *specialization;  \
 	}) \
 	AST_NODE_KIND(ProcType, "procedure type", struct { \
 		Token    token;   \
@@ -1390,9 +1392,10 @@ AstNode *ast_union_field(AstFile *f, AstNode *name, AstNode *list) {
 }
 
 
-AstNode *ast_type_type(AstFile *f, Token token) {
+AstNode *ast_type_type(AstFile *f, Token token, AstNode *specialization) {
 	AstNode *result = make_ast_node(f, AstNode_TypeType);
 	result->TypeType.token = token;
+	result->TypeType.specialization = specialization;
 	return result;
 }
 
@@ -1404,10 +1407,11 @@ AstNode *ast_helper_type(AstFile *f, Token token, AstNode *type) {
 }
 
 
-AstNode *ast_poly_type(AstFile *f, Token token, AstNode *type) {
+AstNode *ast_poly_type(AstFile *f, Token token, AstNode *type, AstNode *specialization) {
 	AstNode *result = make_ast_node(f, AstNode_PolyType);
 	result->PolyType.token = token;
 	result->PolyType.type   = type;
+	result->PolyType.specialization = specialization;
 	return result;
 }
 
@@ -2359,7 +2363,11 @@ AstNode *parse_operand(AstFile *f, bool lhs) {
 	case Token_Dollar: {
 		Token token = expect_token(f, Token_Dollar);
 		AstNode *type = parse_ident(f);
-		return ast_poly_type(f, token, type);
+		AstNode *specialization = nullptr;
+		if (allow_token(f, Token_Quo)) {
+			specialization = parse_type(f);
+		}
+		return ast_poly_type(f, token, type, specialization);
 	} break;
 
 	case Token_type_of: {
@@ -3481,7 +3489,12 @@ AstNode *parse_var_type(AstFile *f, bool allow_ellipsis, bool allow_type_token)
 	AstNode *type = nullptr;
 	if (allow_type_token &&
 	    f->curr_token.kind == Token_type) {
-		type = ast_type_type(f, expect_token(f, Token_type));
+		Token token = expect_token(f, Token_type);
+		AstNode *specialization = nullptr;
+		if (allow_token(f, Token_Quo)) {
+			specialization = parse_type(f);
+		}
+		type = ast_type_type(f, token, specialization);
 	} else {
 		type = parse_type_attempt(f);
 	}

+ 26 - 2
src/types.cpp

@@ -95,7 +95,9 @@ struct TypeRecord {
 	bool     is_packed;
 	bool     is_ordered;
 	bool     is_polymorphic;
+	bool     is_poly_specialized;
 	Type *   polymorphic_params; // Type_Tuple
+	Type *   polymorphic_parent;
 
 	i64      custom_align; // NOTE(bill): Only used in structs at the moment
 	Entity * names;
@@ -103,7 +105,11 @@ struct TypeRecord {
 
 #define TYPE_KINDS                                        \
 	TYPE_KIND(Basic,   BasicType)                         \
-	TYPE_KIND(Generic, struct{ i64 id; String name; })    \
+	TYPE_KIND(Generic, struct {                           \
+		i64    id;                                        \
+		String name;                                      \
+		Type * specific;                                  \
+	})                                                    \
 	TYPE_KIND(Pointer, struct { Type *elem; })            \
 	TYPE_KIND(Atomic,  struct { Type *elem; })            \
 	TYPE_KIND(Array,   struct { Type *elem; i64 count; }) \
@@ -480,10 +486,11 @@ Type *make_type_basic(gbAllocator a, BasicType basic) {
 	return t;
 }
 
-Type *make_type_generic(gbAllocator a, i64 id, String name) {
+Type *make_type_generic(gbAllocator a, i64 id, String name, Type *specific) {
 	Type *t = alloc_type(a, Type_Generic);
 	t->Generic.id = id;
 	t->Generic.name = name;
+	t->Generic.specific = specific;
 	return t;
 }
 
@@ -948,11 +955,24 @@ bool is_type_polymorphic_struct(Type *t) {
 	return false;
 }
 
+bool is_type_polymorphic_struct_specialized(Type *t) {
+	t = base_type(t);
+	if (t->kind == Type_Record &&
+	    t->Record.kind == TypeRecord_Struct) {
+		return t->Record.is_polymorphic && t->Record.is_poly_specialized;
+	}
+	return false;
+}
+
+
 bool is_type_polymorphic(Type *t) {
 	switch (t->kind) {
 	case Type_Generic:
 		return true;
 
+	case Type_Named:
+		return is_type_polymorphic_struct(t->Named.base);
+
 	case Type_Pointer:
 		return is_type_polymorphic(t->Pointer.elem);
 	case Type_Atomic:
@@ -2271,6 +2291,10 @@ gbString write_type_to_string(gbString str, Type *type) {
 			String name = type->Generic.name;
 			str = gb_string_appendc(str, "$");
 			str = gb_string_append_length(str, name.text, name.len);
+			if (type->Generic.specific != nullptr) {
+				str = gb_string_appendc(str, "/");
+				str = write_type_to_string(str, type->Generic.specific);
+			}
 		}
 		break;