Browse Source

Procedure Literal

gingerBill 9 years ago
parent
commit
f8fd6fce0b
7 changed files with 124 additions and 47 deletions
  1. 0 5
      examples/other.odin
  2. 18 2
      examples/test.odin
  3. 13 9
      src/checker/checker.cpp
  4. 25 5
      src/checker/expression.cpp
  5. 22 19
      src/checker/statements.cpp
  6. 5 0
      src/checker/type.cpp
  7. 41 7
      src/parser.cpp

+ 0 - 5
examples/other.odin

@@ -1,5 +0,0 @@
-import "test"
-
-add :: proc(a, b: int) -> int {
-	return a + b;
-}

+ 18 - 2
examples/test.odin

@@ -1,11 +1,27 @@
-import "other"
+// import "other"
 
 TAU :: 6.28;
 PI :: PI/2;
 
+type AddProc: proc(a, b: int) -> int;
+
+
+do_thing :: proc(p: AddProc) {
+	p(1, 2);
+}
+
+add :: proc(a, b: int) -> int {
+	return a + b;
+}
+
+
 main :: proc() {
 	x : int = 2;
 	x = x * 3;
 
-	y := add(1, x);
+	// do_thing(add(1, x));
+	do_thing(proc(a, b: int) -> f32 {
+		return a*b - a%b;
+	});
+
 }

+ 13 - 9
src/checker/checker.cpp

@@ -151,6 +151,11 @@ gb_global BuiltinProcedure builtin_procedures[BuiltinProcedure_Count] = {
 	{STR_LIT("println"),          1, true,  Expression_Statement},
 };
 
+struct ProcedureContext {
+	Scope *scope;
+	DeclarationInfo *decl;
+};
+
 
 struct Checker {
 	Parser *               parser;
@@ -169,8 +174,7 @@ struct Checker {
 	gbArena     arena;
 	gbAllocator allocator;
 
-	Scope *            curr_scope;
-	DeclarationInfo *  decl;
+	ProcedureContext   proc_context;
 
 	gbArray(Type *) procedure_stack;
 	b32 in_defer; // TODO(bill): Actually handle correctly
@@ -250,10 +254,10 @@ void add_dependency(DeclarationInfo *d, Entity *e) {
 }
 
 void add_declaration_dependency(Checker *c, Entity *e) {
-	if (c->decl) {
+	if (c->proc_context.decl) {
 		auto found = map_get(&c->entities, hash_pointer(e));
 		if (found) {
-			add_dependency(c->decl, e);
+			add_dependency(c->proc_context.decl, e);
 		}
 	}
 }
@@ -345,7 +349,7 @@ void init_checker(Checker *c, Parser *parser) {
 	c->allocator = gb_arena_allocator(&c->arena);
 
 	c->global_scope = make_scope(universal_scope, c->allocator);
-	c->curr_scope = c->global_scope;
+	c->proc_context.scope = c->global_scope;
 }
 
 void destroy_checker(Checker *c) {
@@ -500,13 +504,13 @@ void add_scope(Checker *c, AstNode *node, Scope *scope) {
 
 
 void check_open_scope(Checker *c, AstNode *statement) {
-	Scope *scope = make_scope(c->curr_scope, c->allocator);
+	Scope *scope = make_scope(c->proc_context.scope, c->allocator);
 	add_scope(c, statement, scope);
-	c->curr_scope = scope;
+	c->proc_context.scope = scope;
 }
 
 void check_close_scope(Checker *c) {
-	c->curr_scope = c->curr_scope->parent;
+	c->proc_context.scope = c->proc_context.scope->parent;
 }
 
 void push_procedure(Checker *c, Type *procedure_type) {
@@ -564,7 +568,7 @@ void check_parsed_files(Checker *c) {
 					     name = name->next, value = value->next) {
 						GB_ASSERT(name->kind == AstNode_Identifier);
 						ExactValue v = {ExactValue_Invalid};
-						Entity *e = make_entity_constant(c->allocator, c->curr_scope, name->identifier.token, NULL, v);
+						Entity *e = make_entity_constant(c->allocator, c->proc_context.scope, name->identifier.token, NULL, v);
 						DeclarationInfo *di = make_declaration_info(c->allocator, c->global_scope);
 						di->type_expr = vd->type_expression;
 						di->init_expr = value;

+ 25 - 5
src/checker/expression.cpp

@@ -10,6 +10,7 @@ void           check_not_tuple         (Checker *c, Operand *operand);
 void           convert_to_typed        (Checker *c, Operand *operand, Type *target_type);
 gbString       expression_to_string    (AstNode *expression);
 void           check_entity_declaration(Checker *c, Entity *e, Type *named_type);
+void           check_procedure_body(Checker *c, Token token, DeclarationInfo *decl, Type *type, AstNode *body);
 
 
 void check_struct_type(Checker *c, Type *struct_type, AstNode *node) {
@@ -41,7 +42,7 @@ void check_struct_type(Checker *c, Type *struct_type, AstNode *node) {
 			GB_ASSERT(name->kind == AstNode_Identifier);
 			Token name_token = name->identifier.token;
 			// TODO(bill): is the curr_scope correct?
-			Entity *e = make_entity_field(c->allocator, c->curr_scope, name_token, type);
+			Entity *e = make_entity_field(c->allocator, c->proc_context.scope, name_token, type);
 			u64 key = hash_string(name_token.string);
 			if (map_get(&entity_map, key)) {
 				// TODO(bill): Scope checking already checks the declaration
@@ -122,10 +123,10 @@ void check_procedure_type(Checker *c, Type *type, AstNode *proc_type_node) {
 
 	// gb_printf("%td -> %td\n", param_count, result_count);
 
-	Type *params  = check_get_params(c, c->curr_scope, proc_type_node->procedure_type.param_list,   param_count);
-	Type *results = check_get_results(c, c->curr_scope, proc_type_node->procedure_type.results_list, result_count);
+	Type *params  = check_get_params(c, c->proc_context.scope, proc_type_node->procedure_type.param_list,   param_count);
+	Type *results = check_get_results(c, c->proc_context.scope, proc_type_node->procedure_type.results_list, result_count);
 
-	type->procedure.scope         = c->curr_scope;
+	type->procedure.scope         = c->proc_context.scope;
 	type->procedure.params        = params;
 	type->procedure.params_count  = proc_type_node->procedure_type.param_count;
 	type->procedure.results       = results;
@@ -138,7 +139,7 @@ void check_identifier(Checker *c, Operand *o, AstNode *n, Type *named_type) {
 	o->mode = Addressing_Invalid;
 	o->expression = n;
 	Entity *e = NULL;
-	scope_lookup_parent_entity(c->curr_scope, n->identifier.token.string, NULL, &e);
+	scope_lookup_parent_entity(c->proc_context.scope, n->identifier.token.string, NULL, &e);
 	if (e == NULL) {
 		checker_err(c, n->identifier.token,
 		            "Undeclared type or identifier `%.*s`", LIT(n->identifier.token.string));
@@ -1532,6 +1533,21 @@ ExpressionKind check__expression_base(Checker *c, Operand *o, AstNode *node, Typ
 		o->value = make_exact_value_from_basic_literal(lit);
 	} break;
 
+	case AstNode_ProcedureLiteral: {
+		Scope *origin_curr_scope = c->proc_context.scope;
+		Type *proc_type = check_type(c, node->procedure_literal.type);
+		if (proc_type != NULL) {
+			check_procedure_body(c, empty_token, c->proc_context.decl, proc_type, node->procedure_literal.body);
+			o->mode = Addressing_Value;
+			o->type = proc_type;
+		} else {
+			gbString str = expression_to_string(node);
+			checker_err(c, ast_node_token(node), "Invalid procedure literal `%s`", str);
+			gb_string_free(str);
+			goto error;
+		}
+	} break;
+
 	case AstNode_ParenExpression:
 		kind = check_expression_base(c, o, node->paren_expression.expression, type_hint);
 		o->expression = node;
@@ -1887,6 +1903,10 @@ gbString write_expression_to_string(gbString str, AstNode *node) {
 		str = string_append_token(str, node->basic_literal);
 		break;
 
+	case AstNode_ProcedureLiteral:
+		str = write_expression_to_string(str, node->procedure_literal.type);
+		break;
+
 	case AstNode_TagExpression:
 		str = gb_string_appendc(str, "#");
 		str = string_append_token(str, node->tag_expression.name);

+ 22 - 19
src/checker/statements.cpp

@@ -172,7 +172,7 @@ Type *check_assign_variable(Checker *c, Operand *op_a, AstNode *lhs) {
 	Entity *e = NULL;
 	b32 used = false;
 	if (node->kind == AstNode_Identifier) {
-		scope_lookup_parent_entity(c->curr_scope, node->identifier.token.string,
+		scope_lookup_parent_entity(c->proc_context.scope, node->identifier.token.string,
 		                           NULL, &e);
 		if (e != NULL && e->kind == Entity_Variable) {
 			used = e->variable.used; // TODO(bill): Make backup just in case
@@ -381,8 +381,11 @@ void check_type_declaration(Checker *c, Entity *e, AstNode *type_expr, Type *nam
 
 void check_procedure_body(Checker *c, Token token, DeclarationInfo *decl, Type *type, AstNode *body) {
 	GB_ASSERT(body->kind == AstNode_BlockStatement);
-	Scope *origin_curr_scope = c->curr_scope;
-	c->curr_scope = decl->scope;
+
+	ProcedureContext old_proc_context = c->proc_context;
+	c->proc_context.scope = decl->scope;
+	c->proc_context.decl = decl;
+
 	push_procedure(c, type);
 	check_statement_list(c, body->block_statement.list);
 	if (type->procedure.results_count > 0) {
@@ -392,7 +395,7 @@ void check_procedure_body(Checker *c, Token token, DeclarationInfo *decl, Type *
 	}
 	pop_procedure(c);
 
-	c->curr_scope = origin_curr_scope;
+	c->proc_context = old_proc_context;
 }
 
 void check_procedure_declaration(Checker *c, Entity *e, DeclarationInfo *d, b32 check_body_later) {
@@ -403,8 +406,8 @@ void check_procedure_declaration(Checker *c, Entity *e, DeclarationInfo *d, b32
 	auto *pd = &d->proc_decl->procedure_declaration;
 
 #if 1
-	Scope *origin_curr_scope = c->curr_scope;
-	c->curr_scope = c->global_scope;
+	Scope *original_curr_scope = c->proc_context.scope;
+	c->proc_context.scope = c->global_scope;
 	check_open_scope(c, pd->procedure_type);
 #endif
 	check_procedure_type(c, proc_type, pd->procedure_type);
@@ -438,7 +441,7 @@ void check_procedure_declaration(Checker *c, Entity *e, DeclarationInfo *d, b32
 			            "A procedure tagged as `#foreign` cannot have a body");
 		}
 
-		d->scope = c->curr_scope;
+		d->scope = c->proc_context.scope;
 
 		GB_ASSERT(pd->body->kind == AstNode_BlockStatement);
 		if (check_body_later) {
@@ -450,7 +453,7 @@ void check_procedure_declaration(Checker *c, Entity *e, DeclarationInfo *d, b32
 
 #if 1
 	check_close_scope(c);
-	c->curr_scope = origin_curr_scope;
+	c->proc_context.scope = original_curr_scope;
 #endif
 
 }
@@ -503,11 +506,11 @@ void check_entity_declaration(Checker *c, Entity *e, Type *named_type) {
 
 	switch (e->kind) {
 	case Entity_Constant:
-		c->decl = d;
+		c->proc_context.decl = d;
 		check_constant_declaration(c, e, d->type_expr, d->init_expr);
 		break;
 	case Entity_Variable:
-		c->decl = d;
+		c->proc_context.decl = d;
 		check_variable_declaration(c, e, d->entities, d->entity_count, d->type_expr, d->init_expr);
 		break;
 	case Entity_TypeName:
@@ -736,10 +739,10 @@ void check_statement(Checker *c, AstNode *node) {
 					// NOTE(bill): Ignore assignments to `_`
 					b32 can_be_ignored = are_strings_equal(str, make_string("_"));
 					if (!can_be_ignored) {
-						found = current_scope_lookup_entity(c->curr_scope, str);
+						found = current_scope_lookup_entity(c->proc_context.scope, str);
 					}
 					if (found == NULL) {
-						entity = make_entity_variable(c->allocator, c->curr_scope, token, NULL);
+						entity = make_entity_variable(c->allocator, c->proc_context.scope, token, NULL);
 						if (!can_be_ignored) {
 							new_entities[new_entity_count++] = entity;
 						}
@@ -780,7 +783,7 @@ void check_statement(Checker *c, AstNode *node) {
 
 			AstNode *name = vd->name_list;
 			for (isize i = 0; i < new_entity_count; i++, name = name->next) {
-				add_entity(c, c->curr_scope, name, new_entities[i]);
+				add_entity(c, c->proc_context.scope, name, new_entities[i]);
 			}
 
 		} break;
@@ -791,7 +794,7 @@ void check_statement(Checker *c, AstNode *node) {
 			     name = name->next, value = value->next) {
 				GB_ASSERT(name->kind == AstNode_Identifier);
 				ExactValue v = {ExactValue_Invalid};
-				Entity *e = make_entity_constant(c->allocator, c->curr_scope, name->identifier.token, NULL, v);
+				Entity *e = make_entity_constant(c->allocator, c->proc_context.scope, name->identifier.token, NULL, v);
 				entities[entity_index++] = e;
 				check_constant_declaration(c, e, vd->type_expression, value);
 			}
@@ -808,7 +811,7 @@ void check_statement(Checker *c, AstNode *node) {
 
 			AstNode *name = vd->name_list;
 			for (isize i = 0; i < entity_count; i++, name = name->next) {
-				add_entity(c, c->curr_scope, name, entities[i]);
+				add_entity(c, c->proc_context.scope, name, entities[i]);
 			}
 		} break;
 
@@ -820,8 +823,8 @@ void check_statement(Checker *c, AstNode *node) {
 
 	case AstNode_ProcedureDeclaration: {
 		auto *pd = &node->procedure_declaration;
-		Entity *e = make_entity_procedure(c->allocator, c->curr_scope, pd->name->identifier.token, NULL);
-		add_entity(c, c->curr_scope, pd->name, e);
+		Entity *e = make_entity_procedure(c->allocator, c->proc_context.scope, pd->name->identifier.token, NULL);
+		add_entity(c, c->proc_context.scope, pd->name, e);
 
 		DeclarationInfo decl = {};
 		init_declaration_info(&decl, e->parent);
@@ -833,8 +836,8 @@ void check_statement(Checker *c, AstNode *node) {
 	case AstNode_TypeDeclaration: {
 		auto *td = &node->type_declaration;
 		AstNode *name = td->name;
-		Entity *e = make_entity_type_name(c->allocator, c->curr_scope, name->identifier.token, NULL);
-		add_entity(c, c->curr_scope, name, e);
+		Entity *e = make_entity_type_name(c->allocator, c->proc_context.scope, name->identifier.token, NULL);
+		add_entity(c, c->proc_context.scope, name, e);
 		check_type_declaration(c, e, td->type_expression, NULL);
 	} break;
 	}

+ 5 - 0
src/checker/type.cpp

@@ -307,6 +307,11 @@ b32 are_types_identical(Type *x, Type *y) {
 	if (x == y)
 		return true;
 
+	if ((x == NULL && y != NULL) ||
+	    (x != NULL && y == NULL)) {
+		return false;
+	}
+
 	switch (x->kind) {
 	case Type_Basic:
 		if (y->kind == Type_Basic)

+ 41 - 7
src/parser.cpp

@@ -48,6 +48,7 @@ enum AstNodeKind {
 
 	AstNode_BasicLiteral,
 	AstNode_Identifier,
+	AstNode_ProcedureLiteral,
 
 AstNode__ExpressionBegin,
 	AstNode_BadExpression, // NOTE(bill): Naughty expression
@@ -121,6 +122,11 @@ struct AstNode {
 			Token token;
 			AstEntity *entity;
 		} identifier;
+		struct {
+			AstNode *type; // AstNode_ProcedureType
+			AstNode *body; // AstNode_BlockStatement
+		} procedure_literal;
+
 		struct {
 			Token token;
 			Token name;
@@ -273,6 +279,8 @@ Token ast_node_token(AstNode *node) {
 		return node->basic_literal;
 	case AstNode_Identifier:
 		return node->identifier.token;
+	case AstNode_ProcedureLiteral:
+		return ast_node_token(node->procedure_literal.type);
 	case AstNode_TagExpression:
 		return node->tag_expression.token;
 	case AstNode_BadExpression:
@@ -558,6 +566,13 @@ gb_inline AstNode *make_identifier(AstFile *f, Token token, AstEntity *entity =
 	return result;
 }
 
+gb_inline AstNode *make_procedure_literal(AstFile *f, AstNode *type, AstNode *body) {
+	AstNode *result = make_node(f, AstNode_ProcedureLiteral);
+	result->procedure_literal.type = type;
+	result->procedure_literal.body = body;
+	return result;
+}
+
 gb_inline AstNode *make_bad_statement(AstFile *f, Token begin, Token end) {
 	AstNode *result = make_node(f, AstNode_BadStatement);
 	result->bad_statement.begin = begin;
@@ -773,9 +788,6 @@ gb_inline b32 allow_token(AstFile *f, TokenKind kind) {
 }
 
 
-
-
-
 gb_internal void add_ast_entity(AstFile *f, AstScope *scope, AstNode *declaration, AstNode *name_list) {
 	for (AstNode *n = name_list; n != NULL; n = n->next) {
 		if (n->kind != AstNode_Identifier) {
@@ -797,7 +809,15 @@ gb_internal void add_ast_entity(AstFile *f, AstScope *scope, AstNode *declaratio
 	}
 }
 
+
+
+
+
 AstNode *parse_expression(AstFile *f, b32 lhs);
+AstNode *parse_procedure_type(AstFile *f, AstScope **scope_);
+AstNode *parse_statement_list(AstFile *f, isize *list_count_);
+AstNode *parse_statement(AstFile *f);
+AstNode *parse_body(AstFile *f, AstScope *scope);
 
 AstNode *parse_identifier(AstFile *f) {
 	Token token = f->cursor[0];
@@ -830,6 +850,8 @@ AstNode *unparen_expression(AstNode *node) {
 	}
 }
 
+
+
 AstNode *parse_atom_expression(AstFile *f, b32 lhs) {
 	AstNode *operand = NULL; // Operand
 	switch (f->cursor[0].kind) {
@@ -861,6 +883,18 @@ AstNode *parse_atom_expression(AstFile *f, b32 lhs) {
 		operand = parse_tag_expression(f, NULL);
 		operand->tag_expression.expression = parse_expression(f, false);
 	} break;
+
+	// Parse Procedure Type or Literal
+	case Token_proc: {
+		AstScope *scope = NULL;
+		AstNode *type = parse_procedure_type(f, &scope);
+		if (f->cursor[0].kind != Token_OpenBrace) {
+			return type;
+		}
+
+		AstNode *body = parse_body(f, scope);
+		return make_procedure_literal(f, type, body);
+	} break;
 	}
 
 	b32 loop = true;
@@ -1123,9 +1157,7 @@ AstNode *parse_simple_statement(AstFile *f) {
 	return make_expression_statement(f, lhs_expression_list);
 }
 
-AstNode *parse_statement_list(AstFile *f, isize *list_count_);
-AstNode *parse_statement(AstFile *f);
-AstNode *parse_body(AstFile *f, AstScope *scope);
+
 
 AstNode *parse_block_statement(AstFile *f) {
 	if (f->curr_scope == f->file_scope) {
@@ -1796,7 +1828,9 @@ void parse_file(Parser *p, AstFile *f) {
 	}
 
 	for (AstNode *node = f->declarations; node != NULL; node = node->next) {
-		if (!is_ast_node_declaration(node)) {
+		if (!is_ast_node_declaration(node) &&
+		    node->kind != AstNode_BadStatement &&
+		    node->kind != AstNode_EmptyStatement) {
 			// NOTE(bill): Sanity check
 			ast_file_err(f, ast_node_token(node), "Only declarations are allowed at file scope");
 		} else {