Browse Source

`where` clauses for procedure literals

gingerBill 6 years ago
parent
commit
b9d3129fb3

+ 2 - 0
core/odin/ast/ast.odin

@@ -99,6 +99,8 @@ Proc_Lit :: struct {
 	body: ^Stmt,
 	tags: Proc_Tags,
 	inlining: Proc_Inlining,
+	where_token: token.Token,
+	where_clauses: []^Expr,
 }
 
 Comp_Lit :: struct {

+ 18 - 0
core/odin/parser/parser.odin

@@ -1985,13 +1985,29 @@ parse_operand :: proc(p: ^Parser, lhs: bool) -> ^ast.Expr {
 
 		type := parse_proc_type(p, tok);
 
+		where_token: token.Token;
+		where_clauses: []^ast.Expr;
+		if (p.curr_tok.kind == token.Where) {
+			where_token = expect_token(p, token.Where);
+			prev_level := p.expr_level;
+			p.expr_level = -1;
+			where_clauses = parse_rhs_expr_list(p);
+			p.expr_level = prev_level;
+		}
+
 		if p.allow_type && p.expr_level < 0 {
+			if where_token.kind != token.Invalid {
+				error(p, where_token.pos, "'where' clauses are not allowed on procedure types");
+			}
 			return type;
 		}
 		body: ^ast.Stmt;
 
 		if allow_token(p, token.Undef) {
 			// Okay
+			if where_token.kind != token.Invalid {
+				error(p, where_token.pos, "'where' clauses are not allowed on procedure literals without a defined body (replaced with ---");
+			}
 		} else if p.curr_tok.kind == token.Open_Brace {
 			prev_proc := p.curr_proc;
 			p.curr_proc = type;
@@ -2009,6 +2025,8 @@ parse_operand :: proc(p: ^Parser, lhs: bool) -> ^ast.Expr {
 		pl := ast.new(ast.Proc_Lit, tok.pos, end_pos(p.prev_tok));
 		pl.type = type;
 		pl.body = body;
+		pl.where_token = where_token;
+		pl.where_clauses = where_clauses;
 		return pl;
 
 	case token.Dollar:

+ 2 - 0
core/odin/token/token.odin

@@ -118,6 +118,7 @@ using Kind :: enum u32 {
 		Package,
 		Typeid,
 		When,
+		Where,
 		If,
 		Else,
 		For,
@@ -252,6 +253,7 @@ tokens := [Kind.COUNT]string {
 	"package",
 	"typeid",
 	"when",
+	"where",
 	"if",
 	"else",
 	"for",

+ 65 - 0
examples/demo/demo.odin

@@ -4,6 +4,7 @@ import "core:fmt"
 import "core:mem"
 import "core:os"
 import "core:reflect"
+import "intrinsics"
 
 when os.OS == "windows" {
 	import "core:thread"
@@ -1094,6 +1095,69 @@ inline_for_statement :: proc() {
 	}
 }
 
+procedure_where_clauses :: proc() {
+	fmt.println("\n#procedure 'where' clauses");
+
+	{ // Sanity checks
+		simple_sanity_check :: proc(x: [2]int)
+			where len(x) > 1,
+			      type_of(x) == [2]int {
+			fmt.println(x);
+		}
+	}
+	{ // Parametric polymorphism checks
+		cross_2d :: proc(a, b: $T/[2]$E) -> E
+			where intrinsics.type_is_numeric(E) {
+			return a.x*b.y - a.y*b.x;
+		}
+		cross_3d :: proc(a, b: $T/[3]$E) -> T
+			where intrinsics.type_is_numeric(E) {
+			x := a.y*b.z - a.z*b.y;
+			y := a.z*b.x - a.x*b.z;
+			z := a.x*b.y - a.y*b.z;
+			return T{x, y, z};
+		}
+
+		a := [2]int{1, 2};
+		b := [2]int{5, -3};
+		fmt.println(cross_2d(a, b));
+
+		x := [3]f32{1, 4, 9};
+		y := [3]f32{-5, 0, 3};
+		fmt.println(cross_3d(x, y));
+
+		// Failure case
+		// i := [2]bool{true, false};
+		// j := [2]bool{false, true};
+		// fmt.println(cross_2d(i, j));
+
+	}
+
+	{ // Procedure groups usage
+		foo :: proc(x: [$N]int) -> bool
+			where N > 2 {
+			fmt.println(#procedure, "was called with the parameter", x);
+			return true;
+		}
+
+		bar :: proc(x: [$N]int) -> bool
+			where 0 < N,
+			      N <= 2 {
+			fmt.println(#procedure, "was called with the parameter", x);
+			return false;
+		}
+
+		baz :: proc{foo, bar};
+
+		x := [3]int{1, 2, 3};
+		y := [2]int{4, 9};
+		ok_x := baz(x);
+		ok_y := baz(y);
+		assert(ok_x == true);
+		assert(ok_y == false);
+	}
+}
+
 main :: proc() {
 	when true {
 		general_stuff();
@@ -1115,5 +1179,6 @@ main :: proc() {
 		reflection();
 		quaternions();
 		inline_for_statement();
+		procedure_where_clauses();
 	}
 }

+ 26 - 7
src/check_decl.cpp

@@ -936,7 +936,6 @@ void check_proc_group_decl(CheckerContext *ctx, Entity *pg_entity, DeclInfo *d)
 
 	ptr_set_destroy(&entity_set);
 
-
 	for_array(j, pge->entities) {
 		Entity *p = pge->entities[j];
 		if (p->type == t_invalid) {
@@ -962,27 +961,40 @@ void check_proc_group_decl(CheckerContext *ctx, Entity *pg_entity, DeclInfo *d)
 			defer (end_error_block());
 
 			ProcTypeOverloadKind kind = are_proc_types_overload_safe(p->type, q->type);
-			switch (kind) {
+			bool both_have_where_clauses = false;
+			if (p->decl_info->proc_lit != nullptr && q->decl_info->proc_lit != nullptr) {
+				GB_ASSERT(p->decl_info->proc_lit->kind == Ast_ProcLit);
+				GB_ASSERT(q->decl_info->proc_lit->kind == Ast_ProcLit);
+				auto pl = &p->decl_info->proc_lit->ProcLit;
+				auto ql = &q->decl_info->proc_lit->ProcLit;
+
+				// Allow collisions if the procedures both have 'where' clauses and are both polymorphic
+				bool pw = pl->where_token.kind != Token_Invalid && is_type_polymorphic(p->type, true);
+				bool qw = ql->where_token.kind != Token_Invalid && is_type_polymorphic(q->type, true);
+				both_have_where_clauses = pw && qw;
+			}
+
+			if (!both_have_where_clauses) switch (kind) {
 			case ProcOverload_Identical:
-				error(p->token, "Overloaded procedure '%.*s' as the same type as another procedure in this scope", LIT(name));
+				error(p->token, "Overloaded procedure '%.*s' as the same type as another procedure in the procedure group '%.*s'", LIT(name), LIT(proc_group_name));
 				is_invalid = true;
 				break;
 			// case ProcOverload_CallingConvention:
-				// error(p->token, "Overloaded procedure '%.*s' as the same type as another procedure in this scope", LIT(name));
+				// error(p->token, "Overloaded procedure '%.*s' as the same type as another procedure in the procedure group '%.*s'", LIT(name), LIT(proc_group_name));
 				// is_invalid = true;
 				// break;
 			case ProcOverload_ParamVariadic:
-				error(p->token, "Overloaded procedure '%.*s' as the same type as another procedure in this scope", LIT(name));
+				error(p->token, "Overloaded procedure '%.*s' as the same type as another procedure in the procedure group '%.*s'", LIT(name), LIT(proc_group_name));
 				is_invalid = true;
 				break;
 			case ProcOverload_ResultCount:
 			case ProcOverload_ResultTypes:
-				error(p->token, "Overloaded procedure '%.*s' as the same parameters but different results in this scope", LIT(name));
+				error(p->token, "Overloaded procedure '%.*s' as the same parameters but different results in the procedure group '%.*s'", LIT(name), LIT(proc_group_name));
 				is_invalid = true;
 				break;
 			case ProcOverload_Polymorphic:
 				#if 0
-				error(p->token, "Overloaded procedure '%.*s' has a polymorphic counterpart in this scope which is not allowed", LIT(name));
+				error(p->token, "Overloaded procedure '%.*s' has a polymorphic counterpart in the procedure group '%.*s' which is not allowed", LIT(name), LIT(proc_group_name));
 				is_invalid = true;
 				#endif
 				break;
@@ -1163,6 +1175,13 @@ void check_proc_body(CheckerContext *ctx_, Token token, DeclInfo *decl, Type *ty
 		}
 	}
 
+
+	bool where_clause_ok = evaluate_where_clauses(ctx, decl, true);
+	if (!where_clause_ok) {
+		// NOTE(bill, 2019-08-31): Don't check the body as the where clauses failed
+		return;
+	}
+
 	check_open_scope(ctx, body);
 	{
 		for_array(i, using_entities) {

+ 110 - 5
src/check_expr.cpp

@@ -5470,6 +5470,74 @@ Entity **populate_proc_parameter_list(CheckerContext *c, Type *proc_type, isize
 	return lhs;
 }
 
+
+bool evaluate_where_clauses(CheckerContext *ctx, DeclInfo *decl, bool print_err) {
+	Ast *proc_lit = decl->proc_lit;
+	GB_ASSERT(proc_lit != nullptr);
+	GB_ASSERT(proc_lit->kind == Ast_ProcLit);
+
+	if (proc_lit->ProcLit.where_token.kind != Token_Invalid) {
+		auto &clauses = proc_lit->ProcLit.where_clauses;
+		for_array(i, clauses) {
+			Ast *clause = clauses[i];
+			Operand o = {};
+			check_expr(ctx, &o, clause);
+			if (o.mode != Addressing_Constant) {
+				if (print_err) error(clause, "'where' clauses expect a constant boolean evaluation");
+				return false;
+			} else if (o.value.kind != ExactValue_Bool) {
+				if (print_err) error(clause, "'where' clauses expect a constant boolean evaluation");
+				return false;
+			} else if (!o.value.value_bool) {
+				if (print_err) {
+					gbString str = expr_to_string(clause);
+					error(clause, "'where' clause evaluated to false:\n\t%s", str);
+					gb_string_free(str);
+
+					if (decl->scope != nullptr) {
+						isize print_count = 0;
+						for_array(j, decl->scope->elements.entries) {
+							Entity *e = decl->scope->elements.entries[j].value;
+							switch (e->kind) {
+							case Entity_TypeName: {
+								if (print_count == 0) error_line("\n\tWith the following definitions:\n");
+
+								gbString str = type_to_string(e->type);
+								error_line("\t\t%.*s :: %s;\n", LIT(e->token.string), str);
+								gb_string_free(str);
+								print_count += 1;
+								break;
+							}
+							case Entity_Constant: {
+								if (print_count == 0) error_line("\n\tWith the following definitions:\n");
+
+								gbString str = exact_value_to_string(e->Constant.value);
+								if (is_type_untyped(e->type)) {
+									error_line("\t\t%.*s :: %s;\n", LIT(e->token.string), str);
+								} else {
+									gbString t = type_to_string(e->type);
+									error_line("\t\t%.*s : %s : %s;\n", LIT(e->token.string), t, str);
+									gb_string_free(t);
+								}
+								gb_string_free(str);
+
+								print_count += 1;
+								break;
+							}
+							}
+						}
+					}
+
+				}
+				return false;
+			}
+		}
+	}
+
+	return true;
+}
+
+
 CallArgumentData check_call_arguments(CheckerContext *c, Operand *operand, Type *proc_type, Ast *call) {
 	ast_node(ce, CallExpr, call);
 
@@ -5710,11 +5778,26 @@ CallArgumentData check_call_arguments(CheckerContext *c, Operand *operand, Type
 
 				err = call_checker(&ctx, call, pt, p, operands, CallArgumentMode_NoErrors, &data);
 
-				if (err == CallArgumentError_None) {
-					valids[valid_count].index = i;
-					valids[valid_count].score = data.score;
-					valid_count++;
+				if (err != CallArgumentError_None) {
+					continue;
 				}
+				if (data.gen_entity != nullptr) {
+					Entity *e = data.gen_entity;
+					DeclInfo *decl = data.gen_entity->decl_info;
+					ctx.scope = decl->scope;
+					ctx.decl = decl;
+					ctx.proc_name = e->token.string;
+					ctx.curr_proc_decl = decl;
+					ctx.curr_proc_sig  = e->type;
+
+					if (!evaluate_where_clauses(&ctx, decl, false)) {
+						continue;
+					}
+				}
+
+				valids[valid_count].index = i;
+				valids[valid_count].score = data.score;
+				valid_count++;
 			}
 		}
 
@@ -5822,7 +5905,29 @@ CallArgumentData check_call_arguments(CheckerContext *c, Operand *operand, Type
 				if (proc->kind == Entity_Variable) {
 					sep = ":=";
 				}
-				error_line("\t%.*s %s %s at %.*s(%td:%td)\n", LIT(name), sep, pt, LIT(pos.file), pos.line, pos.column);
+				error_line("\t%.*s %s %s ", LIT(name), sep, pt);
+				if (proc->decl_info->proc_lit != nullptr) {
+					GB_ASSERT(proc->decl_info->proc_lit->kind == Ast_ProcLit);
+					auto *pl = &proc->decl_info->proc_lit->ProcLit;
+					if (pl->where_token.kind != Token_Invalid) {
+						error_line("\n\t\twhere ");
+						for_array(j, pl->where_clauses) {
+							Ast *clause = pl->where_clauses[j];
+							if (j != 0) {
+								error_line("\t\t      ");
+							}
+							gbString str = expr_to_string(clause);
+							error_line("%s", str);
+							gb_string_free(str);
+
+							if (j != pl->where_clauses.count-1) {
+								error_line(",");
+							}
+						}
+						error_line("\n\t");
+					}
+				}
+				error_line("at %.*s(%td:%td)\n", LIT(pos.file), pos.line, pos.column);
 				// error_line("\t%.*s %s %s at %.*s(%td:%td) %lld\n", LIT(name), sep, pt, LIT(pos.file), pos.line, pos.column, valids[i].score);
 			}
 			result_type = t_invalid;

+ 8 - 0
src/checker.cpp

@@ -3697,6 +3697,14 @@ void check_proc_info(Checker *c, ProcInfo pi) {
 		return;
 	}
 
+	if (pt->is_polymorphic && pt->is_poly_specialized) {
+		Entity *e = pi.decl->entity;
+		if ((e->flags & EntityFlag_Used) == 0) {
+			// NOTE(bill, 2019-08-31): It was never used, don't check
+			return;
+		}
+	}
+
 	bool bounds_check    = (pi.tags & ProcTag_bounds_check)    != 0;
 	bool no_bounds_check = (pi.tags & ProcTag_no_bounds_check) != 0;
 

+ 38 - 12
src/parser.cpp

@@ -144,6 +144,7 @@ Ast *clone_ast(Ast *node) {
 	case Ast_ProcLit:
 		n->ProcLit.type = clone_ast(n->ProcLit.type);
 		n->ProcLit.body = clone_ast(n->ProcLit.body);
+		n->ProcLit.where_clauses = clone_ast_array(n->ProcLit.where_clauses);
 		break;
 	case Ast_CompoundLit:
 		n->CompoundLit.type  = clone_ast(n->CompoundLit.type);
@@ -612,11 +613,13 @@ Ast *ast_proc_group(AstFile *f, Token token, Token open, Token close, Array<Ast
 	return result;
 }
 
-Ast *ast_proc_lit(AstFile *f, Ast *type, Ast *body, u64 tags) {
+Ast *ast_proc_lit(AstFile *f, Ast *type, Ast *body, u64 tags, Token where_token, Array<Ast *> const &where_clauses) {
 	Ast *result = alloc_ast_node(f, Ast_ProcLit);
 	result->ProcLit.type = type;
 	result->ProcLit.body = body;
 	result->ProcLit.tags = tags;
+	result->ProcLit.where_token = where_token;
+	result->ProcLit.where_clauses = where_clauses;
 	return result;
 }
 
@@ -1827,15 +1830,41 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 		}
 
 		Ast *type = parse_proc_type(f, token);
+		Token where_token = {};
+		Array<Ast *> where_clauses = {};
+		u64 tags = 0;
+
+		if (f->curr_token.kind == Token_where) {
+			where_token = expect_token(f, Token_where);
+			isize prev_level = f->expr_level;
+			f->expr_level = -1;
+			where_clauses = parse_rhs_expr_list(f);
+			f->expr_level = prev_level;
+		}
+
+		parse_proc_tags(f, &tags);
+		if ((tags & ProcTag_require_results) != 0) {
+			syntax_error(f->curr_token, "#require_results has now been replaced as an attribute @(require_results) on the declaration");
+			tags &= ~ProcTag_require_results;
+		}
+		GB_ASSERT(type->kind == Ast_ProcType);
+		type->ProcType.tags = tags;
 
 		if (f->allow_type && f->expr_level < 0) {
+			if (tags != 0) {
+				syntax_error(token, "A procedure type cannot have suffix tags");
+			}
+			if (where_token.kind != Token_Invalid) {
+				syntax_error(where_token, "'where' clauses are not allowed on procedure types");
+			}
 			return type;
 		}
 
-		u64 tags = type->ProcType.tags;
-
 		if (allow_token(f, Token_Undef)) {
-			return ast_proc_lit(f, type, nullptr, tags);
+			if (where_token.kind != Token_Invalid) {
+				syntax_error(where_token, "'where' clauses are not allowed on procedure literals without a defined body (replaced with ---)");
+			}
+			return ast_proc_lit(f, type, nullptr, tags, where_token, where_clauses);
 		} else if (f->curr_token.kind == Token_OpenBrace) {
 			Ast *curr_proc = f->curr_proc;
 			Ast *body = nullptr;
@@ -1843,7 +1872,7 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 			body = parse_body(f);
 			f->curr_proc = curr_proc;
 
-			return ast_proc_lit(f, type, body, tags);
+			return ast_proc_lit(f, type, body, tags, where_token, where_clauses);
 		} else if (allow_token(f, Token_do)) {
 			Ast *curr_proc = f->curr_proc;
 			Ast *body = nullptr;
@@ -1851,12 +1880,15 @@ Ast *parse_operand(AstFile *f, bool lhs) {
 			body = convert_stmt_to_body(f, parse_stmt(f));
 			f->curr_proc = curr_proc;
 
-			return ast_proc_lit(f, type, body, tags);
+			return ast_proc_lit(f, type, body, tags, where_token, where_clauses);
 		}
 
 		if (tags != 0) {
 			syntax_error(token, "A procedure type cannot have suffix tags");
 		}
+		if (where_token.kind != Token_Invalid) {
+			syntax_error(where_token, "'where' clauses are not allowed on procedure types");
+		}
 
 		return type;
 	}
@@ -2827,12 +2859,6 @@ Ast *parse_proc_type(AstFile *f, Token proc_token) {
 	results = parse_results(f, &diverging);
 
 	u64 tags = 0;
-	parse_proc_tags(f, &tags);
-	if ((tags & ProcTag_require_results) != 0) {
-		syntax_error(f->curr_token, "#require_results has now been replaced as an attribute @(require_results) on the declaration");
-		tags &= ~ProcTag_require_results;
-	}
-
 	bool is_generic = false;
 
 	for_array(i, params->FieldList.list) {

+ 2 - 0
src/parser.hpp

@@ -229,6 +229,8 @@ enum StmtAllowFlag {
 		Ast *body; \
 		u64  tags; \
 		ProcInlining inlining; \
+		Token where_token; \
+		Array<Ast *> where_clauses; \
 	}) \
 	AST_KIND(CompoundLit, "compound literal", struct { \
 		Ast *type; \

+ 1 - 0
src/tokenizer.cpp

@@ -86,6 +86,7 @@ TOKEN_KIND(Token__KeywordBegin, ""), \
 	TOKEN_KIND(Token_package,     "package"),     \
 	TOKEN_KIND(Token_typeid,      "typeid"),      \
 	TOKEN_KIND(Token_when,        "when"),        \
+	TOKEN_KIND(Token_where,       "where"),       \
 	TOKEN_KIND(Token_if,          "if"),          \
 	TOKEN_KIND(Token_else,        "else"),        \
 	TOKEN_KIND(Token_for,         "for"),         \