Browse Source

Add `where` clauses to `struct` and `union`

gingerBill 6 years ago
parent
commit
4afc78efc6
8 changed files with 125 additions and 61 deletions
  1. 11 7
      core/odin/ast/ast.odin
  2. 33 9
      core/odin/parser/parser.odin
  3. 1 1
      src/check_decl.cpp
  4. 9 13
      src/check_expr.cpp
  5. 4 1
      src/check_type.cpp
  6. 1 1
      src/checker.cpp
  7. 50 17
      src/parser.cpp
  8. 16 12
      src/parser.hpp

+ 11 - 7
core/odin/ast/ast.odin

@@ -568,13 +568,15 @@ Dynamic_Array_Type :: struct {
 
 Struct_Type :: struct {
 	using node: Expr,
-	tok_pos:   token.Pos,
-	poly_params:  ^Field_List,
-	align:        ^Expr,
-	is_packed:    bool,
-	is_raw_union: bool,
-	fields:       ^Field_List,
-	name_count:  int,
+	tok_pos:       token.Pos,
+	poly_params:   ^Field_List,
+	align:         ^Expr,
+	fields:        ^Field_List,
+	name_count:    int,
+	where_token:   token.Token,
+	where_clauses: []^Expr,
+	is_packed:     bool,
+	is_raw_union:  bool,
 }
 
 Union_Type :: struct {
@@ -583,6 +585,8 @@ Union_Type :: struct {
 	poly_params: ^Field_List,
 	align:       ^Expr,
 	variants:    []^Expr,
+	where_token: token.Token,
+	where_clauses: []^Expr,
 }
 
 Enum_Type :: struct {

+ 33 - 9
core/odin/parser/parser.odin

@@ -2174,17 +2174,29 @@ parse_operand :: proc(p: ^Parser, lhs: bool) -> ^ast.Expr {
 			error(p, tok.pos, "'#raw_union' cannot also be '#packed");
 		}
 
+		where_token: token.Token;
+		where_clauses: []^ast.Expr;
+		if (p.curr_tok.kind == token.Where) {
+			where_token = expect_token(p, token.Where);
+			prev_level := p.expr_level;
+			p.expr_level = -1;
+			where_clauses = parse_rhs_expr_list(p);
+			p.expr_level = prev_level;
+		}
+
 		expect_token(p, token.Open_Brace);
 		fields, name_count = parse_field_list(p, token.Close_Brace, ast.Field_Flags_Struct);
 		close := expect_token(p, token.Close_Brace);
 
 		st := ast.new(ast.Struct_Type, tok.pos, end_pos(close));
-		st.poly_params  = poly_params;
-		st.align        = align;
-		st.is_packed    = is_packed;
-		st.is_raw_union = is_raw_union;
-		st.fields       = fields;
-		st.name_count   = name_count;
+		st.poly_params   = poly_params;
+		st.align         = align;
+		st.is_packed     = is_packed;
+		st.is_raw_union  = is_raw_union;
+		st.fields        = fields;
+		st.name_count    = name_count;
+		st.where_token   = where_token;
+		st.where_clauses = where_clauses;
 		return st;
 
 	case token.Union:
@@ -2217,6 +2229,16 @@ parse_operand :: proc(p: ^Parser, lhs: bool) -> ^ast.Expr {
 		}
 		p.expr_level = prev_level;
 
+		where_token: token.Token;
+		where_clauses: []^ast.Expr;
+		if (p.curr_tok.kind == token.Where) {
+			where_token = expect_token(p, token.Where);
+			prev_level := p.expr_level;
+			p.expr_level = -1;
+			where_clauses = parse_rhs_expr_list(p);
+			p.expr_level = prev_level;
+		}
+
 		variants: [dynamic]^ast.Expr;
 
 		expect_token_after(p, token.Open_Brace, "union");
@@ -2234,9 +2256,11 @@ parse_operand :: proc(p: ^Parser, lhs: bool) -> ^ast.Expr {
 		close := expect_token(p, token.Close_Brace);
 
 		ut := ast.new(ast.Union_Type, tok.pos, end_pos(close));
-		ut.poly_params = poly_params;
-		ut.variants    = variants[:];
-		ut.align       = align;
+		ut.poly_params   = poly_params;
+		ut.variants      = variants[:];
+		ut.align         = align;
+		ut.where_token   = where_token;
+		ut.where_clauses = where_clauses;
 
 		return ut;
 

+ 1 - 1
src/check_decl.cpp

@@ -1176,7 +1176,7 @@ void check_proc_body(CheckerContext *ctx_, Token token, DeclInfo *decl, Type *ty
 	}
 
 
-	bool where_clause_ok = evaluate_where_clauses(ctx, decl, true);
+	bool where_clause_ok = evaluate_where_clauses(ctx, decl->scope, &decl->proc_lit->ProcLit.where_clauses, true);
 	if (!where_clause_ok) {
 		// NOTE(bill, 2019-08-31): Don't check the body as the where clauses failed
 		return;

+ 9 - 13
src/check_expr.cpp

@@ -5471,15 +5471,10 @@ Entity **populate_proc_parameter_list(CheckerContext *c, Type *proc_type, isize
 }
 
 
-bool evaluate_where_clauses(CheckerContext *ctx, DeclInfo *decl, bool print_err) {
-	Ast *proc_lit = decl->proc_lit;
-	GB_ASSERT(proc_lit != nullptr);
-	GB_ASSERT(proc_lit->kind == Ast_ProcLit);
-
-	if (proc_lit->ProcLit.where_token.kind != Token_Invalid) {
-		auto &clauses = proc_lit->ProcLit.where_clauses;
-		for_array(i, clauses) {
-			Ast *clause = clauses[i];
+bool evaluate_where_clauses(CheckerContext *ctx, Scope *scope, Array<Ast *> *clauses, bool print_err) {
+	if (clauses != nullptr) {
+		for_array(i, *clauses) {
+			Ast *clause = (*clauses)[i];
 			Operand o = {};
 			check_expr(ctx, &o, clause);
 			if (o.mode != Addressing_Constant) {
@@ -5494,10 +5489,10 @@ bool evaluate_where_clauses(CheckerContext *ctx, DeclInfo *decl, bool print_err)
 					error(clause, "'where' clause evaluated to false:\n\t%s", str);
 					gb_string_free(str);
 
-					if (decl->scope != nullptr) {
+					if (scope != nullptr) {
 						isize print_count = 0;
-						for_array(j, decl->scope->elements.entries) {
-							Entity *e = decl->scope->elements.entries[j].value;
+						for_array(j, scope->elements.entries) {
+							Entity *e = scope->elements.entries[j].value;
 							switch (e->kind) {
 							case Entity_TypeName: {
 								if (print_count == 0) error_line("\n\tWith the following definitions:\n");
@@ -5790,7 +5785,8 @@ CallArgumentData check_call_arguments(CheckerContext *c, Operand *operand, Type
 					ctx.curr_proc_decl = decl;
 					ctx.curr_proc_sig  = e->type;
 
-					if (!evaluate_where_clauses(&ctx, decl, false)) {
+					GB_ASSERT(decl->proc_lit->kind == Ast_ProcLit);
+					if (!evaluate_where_clauses(&ctx, decl->scope, &decl->proc_lit->ProcLit.where_clauses, false)) {
 						continue;
 					}
 				}

+ 4 - 1
src/check_type.cpp

@@ -504,8 +504,8 @@ void check_struct_type(CheckerContext *ctx, Type *struct_type, Ast *node, Array<
 	struct_type->Struct.polymorphic_params      = polymorphic_params;
 	struct_type->Struct.is_poly_specialized     = is_poly_specialized;
 
-
 	if (!is_polymorphic) {
+		bool where_clause_ok = evaluate_where_clauses(ctx, ctx->scope, &st->where_clauses, true);
 		check_struct_fields(ctx, node, &struct_type->Struct.fields, &struct_type->Struct.tags, st->fields, min_field_count, struct_type, context);
 	}
 
@@ -688,6 +688,9 @@ void check_union_type(CheckerContext *ctx, Type *union_type, Ast *node, Array<Op
 	union_type->Union.is_polymorphic          = is_polymorphic;
 	union_type->Union.is_poly_specialized     = is_poly_specialized;
 
+	bool where_clause_ok = evaluate_where_clauses(ctx, ctx->scope, &ut->where_clauses, true);
+
+
 	for_array(i, ut->variants) {
 		Ast *node = ut->variants[i];
 		Type *t = check_type_expr(ctx, node, nullptr);

+ 1 - 1
src/checker.cpp

@@ -3157,7 +3157,7 @@ void check_add_import_decl(CheckerContext *ctx, Ast *decl) {
 			Entity *e = scope->elements.entries[elem_index].value;
 			if (e->scope == parent_scope) continue;
 
-			if (is_entity_exported(e)) {
+			if (is_entity_exported(e, true)) {
 				Entity *found = scope_lookup_current(parent_scope, name);
 				if (found != nullptr) {
 					// NOTE(bill):

+ 50 - 17
src/parser.cpp

@@ -349,10 +349,12 @@ Ast *clone_ast(Ast *node) {
 		n->StructType.fields = clone_ast_array(n->StructType.fields);
 		n->StructType.polymorphic_params = clone_ast(n->StructType.polymorphic_params);
 		n->StructType.align  = clone_ast(n->StructType.align);
+		n->StructType.where_clauses  = clone_ast_array(n->StructType.where_clauses);
 		break;
 	case Ast_UnionType:
 		n->UnionType.variants = clone_ast_array(n->UnionType.variants);
 		n->UnionType.polymorphic_params = clone_ast(n->UnionType.polymorphic_params);
+		n->UnionType.where_clauses = clone_ast_array(n->UnionType.where_clauses);
 		break;
 	case Ast_EnumType:
 		n->EnumType.base_type = clone_ast(n->EnumType.base_type);
@@ -921,7 +923,8 @@ Ast *ast_dynamic_array_type(AstFile *f, Token token, Ast *elem) {
 
 Ast *ast_struct_type(AstFile *f, Token token, Array<Ast *> fields, isize field_count,
                      Ast *polymorphic_params, bool is_packed, bool is_raw_union,
-                     Ast *align) {
+                     Ast *align,
+                     Token where_token, Array<Ast *> const &where_clauses) {
 	Ast *result = alloc_ast_node(f, Ast_StructType);
 	result->StructType.token              = token;
 	result->StructType.fields             = fields;
@@ -930,17 +933,22 @@ Ast *ast_struct_type(AstFile *f, Token token, Array<Ast *> fields, isize field_c
 	result->StructType.is_packed          = is_packed;
 	result->StructType.is_raw_union       = is_raw_union;
 	result->StructType.align              = align;
+	result->StructType.where_token        = where_token;
+	result->StructType.where_clauses      = where_clauses;
 	return result;
 }
 
 
-Ast *ast_union_type(AstFile *f, Token token, Array<Ast *> variants, Ast *polymorphic_params, Ast *align, bool no_nil) {
+Ast *ast_union_type(AstFile *f, Token token, Array<Ast *> variants, Ast *polymorphic_params, Ast *align, bool no_nil,
+                    Token where_token, Array<Ast *> const &where_clauses) {
 	Ast *result = alloc_ast_node(f, Ast_UnionType);
 	result->UnionType.token              = token;
 	result->UnionType.variants           = variants;
 	result->UnionType.polymorphic_params = polymorphic_params;
 	result->UnionType.align              = align;
 	result->UnionType.no_nil             = no_nil;
+	result->UnionType.where_token        = where_token;
+	result->UnionType.where_clauses      = where_clauses;
 	return result;
 }
 
@@ -2020,6 +2028,18 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 			syntax_error(token, "'#raw_union' cannot also be '#packed'");
 		}
 
+		Token where_token = {};
+		Array<Ast *> where_clauses = {};
+
+		if (f->curr_token.kind == Token_where) {
+			where_token = expect_token(f, Token_where);
+			isize prev_level = f->expr_level;
+			f->expr_level = -1;
+			where_clauses = parse_rhs_expr_list(f);
+			f->expr_level = prev_level;
+		}
+
+
 		Token open = expect_token_after(f, Token_OpenBrace, "struct");
 
 		isize    name_count = 0;
@@ -2032,7 +2052,7 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 			decls = fields->FieldList.list;
 		}
 
-		return ast_struct_type(f, token, decls, name_count, polymorphic_params, is_packed, is_raw_union, align);
+		return ast_struct_type(f, token, decls, name_count, polymorphic_params, is_packed, is_raw_union, align, where_token, where_clauses);
 	} break;
 
 	case Token_union: {
@@ -2073,6 +2093,18 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 			}
 		}
 
+		Token where_token = {};
+		Array<Ast *> where_clauses = {};
+
+		if (f->curr_token.kind == Token_where) {
+			where_token = expect_token(f, Token_where);
+			isize prev_level = f->expr_level;
+			f->expr_level = -1;
+			where_clauses = parse_rhs_expr_list(f);
+			f->expr_level = prev_level;
+		}
+
+
 		Token open = expect_token_after(f, Token_OpenBrace, "union");
 
 		while (f->curr_token.kind != Token_CloseBrace &&
@@ -2088,7 +2120,7 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 
 		Token close = expect_token(f, Token_CloseBrace);
 
-		return ast_union_type(f, token, variants, polymorphic_params, align, no_nil);
+		return ast_union_type(f, token, variants, polymorphic_params, align, no_nil, where_token, where_clauses);
 	} break;
 
 	case Token_enum: {
@@ -4424,19 +4456,6 @@ bool determine_path_from_string(gbMutex *file_mutex, Ast *node, String base_dir,
 	}
 
 
-	if (is_package_name_reserved(file_str)) {
-		*path = file_str;
-		return true;
-	}
-
-	if (file_mutex) gb_mutex_lock(file_mutex);
-	defer (if (file_mutex) gb_mutex_unlock(file_mutex));
-
-
-	if (node->kind == Ast_ForeignImportDecl) {
-		node->ForeignImportDecl.collection_name = collection_name;
-	}
-
 	if (collection_name.len > 0) {
 		if (collection_name == "system") {
 			if (node->kind != Ast_ForeignImportDecl) {
@@ -4467,6 +4486,20 @@ bool determine_path_from_string(gbMutex *file_mutex, Ast *node, String base_dir,
 #endif
 	}
 
+
+	if (is_package_name_reserved(file_str)) {
+		*path = file_str;
+		return true;
+	}
+
+	if (file_mutex) gb_mutex_lock(file_mutex);
+	defer (if (file_mutex) gb_mutex_unlock(file_mutex));
+
+
+	if (node->kind == Ast_ForeignImportDecl) {
+		node->ForeignImportDecl.collection_name = collection_name;
+	}
+
 	if (has_windows_drive) {
 		*path = file_str;
 	} else {

+ 16 - 12
src/parser.hpp

@@ -491,20 +491,24 @@ AST_KIND(_TypeBegin, "", bool) \
 		Ast *elem; \
 	}) \
 	AST_KIND(StructType, "struct type", struct { \
-		Token token;              \
-		Array<Ast *> fields;      \
-		isize field_count;        \
-		Ast *polymorphic_params;  \
-		Ast *align;               \
-		bool is_packed;           \
-		bool is_raw_union;        \
+		Token token;                \
+		Array<Ast *> fields;        \
+		isize field_count;          \
+		Ast *polymorphic_params;    \
+		Ast *align;                 \
+		Token where_token;          \
+		Array<Ast *> where_clauses; \
+		bool is_packed;             \
+		bool is_raw_union;          \
 	}) \
 	AST_KIND(UnionType, "union type", struct { \
-		Token        token;      \
-		Array<Ast *> variants;   \
-		Ast *polymorphic_params; \
-		Ast *        align;      \
-		bool         no_nil;     \
+		Token        token;         \
+		Array<Ast *> variants;      \
+		Ast *polymorphic_params;    \
+		Ast *        align;         \
+		bool         no_nil;        \
+		Token where_token;          \
+		Array<Ast *> where_clauses; \
 	}) \
 	AST_KIND(EnumType, "enum type", struct { \
 		Token        token; \