Browse Source

Support `any` in `match type`

Ginger Bill 8 years ago
parent
commit
24347ced45
8 changed files with 192 additions and 135 deletions
  1. 1 1
      build.bat
  2. 13 0
      code/demo.odin
  3. 45 58
      core/fmt.odin
  4. 13 3
      src/checker/expr.cpp
  5. 52 22
      src/checker/stmt.cpp
  6. 10 11
      src/checker/types.cpp
  7. 1 1
      src/parser.cpp
  8. 57 39
      src/ssa.cpp

+ 1 - 1
build.bat

@@ -4,7 +4,7 @@
 set exe_name=odin.exe
 
 :: Debug = 0, Release = 1
-set release_mode=1
+set release_mode=0
 
 set compiler_flags= -nologo -Oi -TP -W4 -fp:fast -fp:except- -Gm- -MP -FC -GS- -EHsc- -GR-
 

+ 13 - 0
code/demo.odin

@@ -1,6 +1,19 @@
 #import "fmt.odin"
 #import "game.odin"
 
+variadic :: proc(args: ..any) {
+	for i := 0; i < args.count; i++ {
+		match type a : args[i] {
+		case int:    fmt.println("int", a)
+		case f32:    fmt.println("f32", a)
+		case f64:    fmt.println("f64", a)
+		case string: fmt.println("string", a)
+		}
+	}
+}
+
 main :: proc() {
 	fmt.println("Hellope, everybody!")
+
+	variadic(1, 1.0 as f32, 1.0 as f64, "Hellope")
 }

+ 45 - 58
core/fmt.odin

@@ -288,6 +288,7 @@ print_any_to_buffer :: proc(buf: ^[]byte, arg: any) {
 		print_string_to_buffer(buf, "<nil>")
 		return
 	}
+
 	using Type_Info
 	match type info : arg.type_info {
 	case Named:
@@ -314,60 +315,54 @@ print_any_to_buffer :: proc(buf: ^[]byte, arg: any) {
 		}
 
 	case Integer:
-		if info.signed {
-			i: i64 = 0;
-			if arg.data != nil {
-				match info.size {
-				case 1:  i = (arg.data as ^i8)^   as i64
-				case 2:  i = (arg.data as ^i16)^  as i64
-				case 4:  i = (arg.data as ^i32)^  as i64
-				case 8:  i = (arg.data as ^i64)^  as i64
-				}
+		if arg.data != nil {
+			match type i : arg {
+			case i8:  print_i64_to_buffer(buf, i as i64)
+			case i16: print_i64_to_buffer(buf, i as i64)
+			case i32: print_i64_to_buffer(buf, i as i64)
+			case i64: print_i64_to_buffer(buf, i as i64)
+			case u8:  print_u64_to_buffer(buf, i as u64)
+			case u16: print_u64_to_buffer(buf, i as u64)
+			case u32: print_u64_to_buffer(buf, i as u64)
+			case u64: print_u64_to_buffer(buf, i as u64)
 			}
-			print_i64_to_buffer(buf, i)
 		} else {
-			i: u64 = 0;
-			if arg.data != nil {
-				match info.size {
-				case 1:  i = (arg.data as ^u8)^   as u64
-				case 2:  i = (arg.data as ^u16)^  as u64
-				case 4:  i = (arg.data as ^u32)^  as u64
-				case 8:  i = (arg.data as ^u64)^  as u64
-				}
-			}
-			print_u64_to_buffer(buf, i)
+			print_u64_to_buffer(buf, 0)
 		}
 
 	case Float:
-		f: f64 = 0
 		if arg.data != nil {
-			match info.size {
-			case 4: f = (arg.data as ^f32)^ as f64
-			case 8: f = (arg.data as ^f64)^ as f64
+			match type f : arg {
+			case f32: print_f64_to_buffer(buf, f as f64)
+			case f64: print_f64_to_buffer(buf, f as f64)
 			}
+		} else {
+			print_f64_to_buffer(buf, 0)
 		}
-		print_f64_to_buffer(buf, f)
 
 	case String:
-		s := ""
 		if arg.data != nil {
-			s = (arg.data as ^string)^
+			match type s : arg {
+			case string: print_string_to_buffer(buf, s)
+			}
+		} else {
+			print_string_to_buffer(buf, "")
 		}
-		print_string_to_buffer(buf, s)
 
 	case Boolean:
-		v := false;
 		if arg.data != nil {
-			v = (arg.data as ^bool)^
+			match type b : arg {
+			case bool: print_bool_to_buffer(buf, b)
+			}
+		} else {
+			print_bool_to_buffer(buf, false)
 		}
-		print_bool_to_buffer(buf, v)
 
 	case Pointer:
 		if arg.data != nil {
-			if arg.type_info == type_info(^Type_Info) {
-				print_type_to_buffer(buf, (arg.data as ^^Type_Info)^)
-			} else {
-				print_pointer_to_buffer(buf, (arg.data as ^rawptr)^)
+			match type p : arg {
+			case ^Type_Info: print_type_to_buffer(buf, p)
+			default:         print_pointer_to_buffer(buf, (arg.data as ^rawptr)^)
 			}
 		} else {
 			print_pointer_to_buffer(buf, nil)
@@ -384,36 +379,23 @@ print_any_to_buffer :: proc(buf: ^[]byte, arg: any) {
 
 	case Enum:
 		value: i64 = 0
-		match type i : info.base {
-		case Integer:
-			if i.signed {
-				if arg.data != nil {
-					match i.size {
-					case 1:  value = (arg.data as ^i8)^   as i64
-					case 2:  value = (arg.data as ^i16)^  as i64
-					case 4:  value = (arg.data as ^i32)^  as i64
-					case 8:  value = (arg.data as ^i64)^  as i64
-					}
-				}
-			} else {
-				if arg.data != nil {
-					match i.size {
-					case 1:  value = (arg.data as ^u8)^   as i64
-					case 2:  value = (arg.data as ^u16)^  as i64
-					case 4:  value = (arg.data as ^u32)^  as i64
-					case 8:  value = (arg.data as ^u64)^  as i64
-					}
-				}
-			}
+
+		match type i : make_any(info.base, arg.data) {
+		case i8:  value = i as i64
+		case i16: value = i as i64
+		case i32: value = i as i64
+		case i64: value = i as i64
+		case u8:  value = i as i64
+		case u16: value = i as i64
+		case u32: value = i as i64
+		case u64: value = i as i64
 		}
 		print_string_to_buffer(buf, __enum_to_string(arg.type_info, value))
 
-
 	case Array:
 		bprintf(buf, "[%]%{", info.count, info.elem)
 		defer print_string_to_buffer(buf, "}")
 
-
 		for i := 0; i < info.count; i++ {
 			if i > 0 {
 				print_string_to_buffer(buf, ", ")
@@ -466,6 +448,11 @@ print_any_to_buffer :: proc(buf: ^[]byte, arg: any) {
 
 
 	case Struct:
+		if arg.data == nil {
+			print_string_to_buffer(buf, "nil")
+			return
+		}
+
 		bprintf(buf, "%{", arg.type_info)
 		defer print_string_to_buffer(buf, "}")
 

+ 13 - 3
src/checker/expr.cpp

@@ -1543,22 +1543,32 @@ void check_comparison(Checker *c, Operand *x, Operand *y, Token op) {
 
 	if (check_is_assignable_to(c, x, y->type) ||
 	    check_is_assignable_to(c, y, x->type)) {
+		Type *err_type = x->type;
 		b32 defined = false;
 		switch (op.kind) {
 		case Token_CmpEq:
 		case Token_NotEq:
-			defined = is_type_comparable(get_enum_base_type(base_type(x->type)));
+			defined = is_type_comparable(x->type);
 			break;
 		case Token_Lt:
 		case Token_Gt:
 		case Token_LtEq:
 		case Token_GtEq: {
-			defined = is_type_ordered(get_enum_base_type(base_type(x->type)));
+			defined = is_type_ordered(x->type);
 		} break;
 		}
 
+		// CLEANUP(bill) NOTE(bill): there is an auto assignment to `any` which needs to be checked
+		if (is_type_any(x->type) && !is_type_any(y->type)) {
+			err_type = x->type;
+			defined = false;
+		} else if (is_type_any(y->type) && !is_type_any(x->type)) {
+			err_type = y->type;
+			defined = false;
+		}
+
 		if (!defined) {
-			gbString type_string = type_to_string(x->type);
+			gbString type_string = type_to_string(err_type);
 			defer (gb_string_free(type_string));
 			err_str = gb_string_make(c->tmp_allocator,
 			                         gb_bprintf("operator `%.*s` not defined for type `%s`", LIT(op.string), type_string));

+ 52 - 22
src/checker/stmt.cpp

@@ -302,6 +302,18 @@ Type *check_assignment_variable(Checker *c, Operand *op_a, AstNode *lhs) {
 	return op_a->type;
 }
 
+b32 check_valid_type_match_type(Type *type, b32 *is_union_ptr, b32 *is_any) {
+	if (is_type_pointer(type)) {
+		*is_union_ptr = is_type_union(type_deref(type));
+		return *is_union_ptr;
+	}
+	if (is_type_any(type)) {
+		*is_any = true;
+		return *is_any;
+	}
+	return false;
+}
+
 
 void check_stmt(Checker *c, AstNode *node, u32 flags) {
 	u32 prev_stmt_state_flags = c->context.stmt_state_flags;
@@ -729,17 +741,18 @@ void check_stmt(Checker *c, AstNode *node, u32 flags) {
 		check_open_scope(c, node);
 		defer (check_close_scope(c));
 
+		b32 is_union_ptr = false;
+		b32 is_any = false;
 
 		check_expr(c, &x, ms->tag);
 		check_assignment(c, &x, NULL, make_string("type match expression"));
-		if (!is_type_pointer(x.type) || !is_type_union(type_deref(x.type))) {
+		if (!check_valid_type_match_type(x.type, &is_union_ptr, &is_any)) {
 			gbString str = type_to_string(x.type);
 			defer (gb_string_free(str));
 			error(ast_node_token(x.expr),
-			      "Expected a pointer to a union for this type match expression, got `%s`", str);
+			      "Invalid type for this type match expression, got `%s`", str);
 			break;
 		}
-		Type *base_union = base_type(type_deref(x.type));
 
 
 		// NOTE(bill): Check for multiple defaults
@@ -787,27 +800,39 @@ void check_stmt(Checker *c, AstNode *node, u32 flags) {
 			}
 			ast_node(cc, CaseClause, stmt);
 
+			// TODO(bill): Make robust
+			Type *bt = base_type(type_deref(x.type));
+
+
 			AstNode *type_expr = cc->list.count > 0 ? cc->list[0] : NULL;
-			Type *tag_type = NULL;
+			Type *case_type = NULL;
 			if (type_expr != NULL) { // Otherwise it's a default expression
 				Operand y = {};
 				check_expr_or_type(c, &y, type_expr);
-				b32 tag_type_found = false;
-				for (isize i = 0; i < base_union->Record.field_count; i++) {
-					Entity *f = base_union->Record.fields[i];
-					if (are_types_identical(f->type, y.type)) {
-						tag_type_found = true;
-						break;
+
+				if (is_union_ptr) {
+					GB_ASSERT(is_type_union(bt));
+					b32 tag_type_found = false;
+					for (isize i = 0; i < bt->Record.field_count; i++) {
+						Entity *f = bt->Record.fields[i];
+						if (are_types_identical(f->type, y.type)) {
+							tag_type_found = true;
+							break;
+						}
 					}
+					if (!tag_type_found) {
+						gbString type_str = type_to_string(y.type);
+						defer (gb_string_free(type_str));
+						error(ast_node_token(y.expr),
+						      "Unknown tag type, got `%s`", type_str);
+						continue;
+					}
+					case_type = y.type;
+				} else if (is_any) {
+					case_type = y.type;
+				} else {
+					GB_PANIC("Unknown type to type match statement");
 				}
-				if (!tag_type_found) {
-					gbString type_str = type_to_string(y.type);
-					defer (gb_string_free(type_str));
-					error(ast_node_token(y.expr),
-					      "Unknown tag type, got `%s`", type_str);
-					continue;
-				}
-				tag_type = y.type;
 
 				HashKey key = hash_pointer(y.type);
 				auto *found = map_get(&seen, key);
@@ -823,14 +848,19 @@ void check_stmt(Checker *c, AstNode *node, u32 flags) {
 					break;
 				}
 				map_set(&seen, key, cast(b32)true);
-
 			}
 
 			check_open_scope(c, stmt);
-			if (tag_type != NULL) {
+			if (case_type != NULL) {
+				add_type_info_type(c, case_type);
+
 				// NOTE(bill): Dummy type
-				Type *tag_ptr_type = make_type_pointer(c->allocator, tag_type);
-				Entity *tag_var = make_entity_variable(c->allocator, c->context.scope, ms->var->Ident, tag_ptr_type);
+				Type *tt = case_type;
+				if (is_union_ptr) {
+					tt = make_type_pointer(c->allocator, case_type);
+					add_type_info_type(c, tt);
+				}
+				Entity *tag_var = make_entity_variable(c->allocator, c->context.scope, ms->var->Ident, tt);
 				tag_var->flags |= EntityFlag_Used;
 				add_entity(c, c->context.scope, ms->var, tag_var);
 				add_entity_use(c, ms->var, tag_var);

+ 10 - 11
src/checker/types.cpp

@@ -396,7 +396,14 @@ gb_global Type *t_context              = NULL;
 gb_global Type *t_context_ptr          = NULL;
 
 
-
+Type *get_enum_base_type(Type *t) {
+	Type *bt = base_type(t);
+	if (bt->kind == Type_Record && bt->Record.kind == TypeRecord_Enum) {
+		GB_ASSERT(bt->Record.enum_base != NULL);
+		return bt->Record.enum_base;
+	}
+	return t;
+}
 
 b32 is_type_named(Type *t) {
 	if (t->kind == Type_Basic) {
@@ -457,7 +464,7 @@ b32 is_type_untyped(Type *t) {
 	return false;
 }
 b32 is_type_ordered(Type *t) {
-	t = base_type(t);
+	t = base_type(get_enum_base_type(t));
 	if (t->kind == Type_Basic) {
 		return (t->Basic.flags & BasicFlag_Ordered) != 0;
 	}
@@ -581,14 +588,6 @@ b32 is_type_raw_union(Type *t) {
 	return (t->kind == Type_Record && t->Record.kind == TypeRecord_RawUnion);
 }
 
-Type *get_enum_base_type(Type *t) {
-	Type *bt = base_type(t);
-	if (is_type_enum(bt)) {
-		return bt->Record.enum_base;
-	}
-	return t;
-}
-
 b32 is_type_any(Type *t) {
 	t = base_type(t);
 	return (t->kind == Type_Basic && t->Basic.kind == Basic_any);
@@ -626,7 +625,7 @@ b32 type_has_nil(Type *t) {
 
 
 b32 is_type_comparable(Type *t) {
-	t = base_type(t);
+	t = base_type(get_enum_base_type(t));
 	switch (t->kind) {
 	case Type_Basic:
 		return t->kind != Basic_UntypedNil;

+ 1 - 1
src/parser.cpp

@@ -2514,7 +2514,7 @@ AstNode *parse_type_case_clause(AstFile *f) {
 	Token token = f->curr_token;
 	AstNodeArray clause = make_ast_node_array(f);
 	if (allow_token(f, Token_case)) {
-		array_add(&clause, parse_expr(f, false));
+		array_add(&clause, parse_type(f));
 	} else {
 		expect_token(f, Token_default);
 	}

+ 57 - 39
src/ssa.cpp

@@ -1908,7 +1908,6 @@ ssaValue *ssa_emit_conv(ssaProcedure *proc, ssaValue *value, Type *t) {
 		return value;
 	}
 
-
 	Type *src = base_type(get_enum_base_type(src_type));
 	Type *dst = base_type(get_enum_base_type(t));
 
@@ -2628,6 +2627,9 @@ ssaValue *ssa_build_single_expr(ssaProcedure *proc, AstNode *expr, TypeAndValue
 	case_end;
 
 	case_ast_node(be, BinaryExpr, expr);
+		ssaValue *left = ssa_build_expr(proc, be->left);
+		Type *type = default_type(tv->type);
+
 		switch (be->op.kind) {
 		case Token_Add:
 		case Token_Sub:
@@ -2639,11 +2641,10 @@ ssaValue *ssa_build_single_expr(ssaProcedure *proc, AstNode *expr, TypeAndValue
 		case Token_Xor:
 		case Token_AndNot:
 		case Token_Shl:
-		case Token_Shr:
-			return ssa_emit_arith(proc, be->op.kind,
-			                      ssa_build_expr(proc, be->left),
-			                      ssa_build_expr(proc, be->right),
-			                      tv->type);
+		case Token_Shr: {
+			ssaValue *right = ssa_build_expr(proc, be->right);
+			return ssa_emit_arith(proc, be->op.kind, left, right, type);
+		}
 
 
 		case Token_CmpEq:
@@ -2652,11 +2653,9 @@ ssaValue *ssa_build_single_expr(ssaProcedure *proc, AstNode *expr, TypeAndValue
 		case Token_LtEq:
 		case Token_Gt:
 		case Token_GtEq: {
-			ssaValue *left  = ssa_build_expr(proc, be->left);
 			ssaValue *right = ssa_build_expr(proc, be->right);
-
 			ssaValue *cmp = ssa_emit_comp(proc, be->op.kind, left, right);
-			return ssa_emit_conv(proc, cmp, default_type(tv->type));
+			return ssa_emit_conv(proc, cmp, type);
 		} break;
 
 		case Token_CmpAnd:
@@ -2665,19 +2664,19 @@ ssaValue *ssa_build_single_expr(ssaProcedure *proc, AstNode *expr, TypeAndValue
 
 		case Token_as:
 			ssa_emit_comment(proc, make_string("cast - as"));
-			return ssa_emit_conv(proc, ssa_build_expr(proc, be->left), tv->type);
+			return ssa_emit_conv(proc, left, type);
 
 		case Token_transmute:
 			ssa_emit_comment(proc, make_string("cast - transmute"));
-			return ssa_emit_transmute(proc, ssa_build_expr(proc, be->left), tv->type);
+			return ssa_emit_transmute(proc, left, type);
 
 		case Token_down_cast:
 			ssa_emit_comment(proc, make_string("cast - down_cast"));
-			return ssa_emit_down_cast(proc, ssa_build_expr(proc, be->left), tv->type);
+			return ssa_emit_down_cast(proc, left, type);
 
 		case Token_union_cast:
 			ssa_emit_comment(proc, make_string("cast - union_cast"));
-			return ssa_emit_union_cast(proc, ssa_build_expr(proc, be->left), tv->type);
+			return ssa_emit_union_cast(proc, left, type);
 
 		default:
 			GB_PANIC("Invalid binary expression");
@@ -4284,14 +4283,17 @@ void ssa_build_stmt(ssaProcedure *proc, AstNode *node) {
 		gbAllocator allocator = proc->module->allocator;
 
 		ssaValue *parent = ssa_build_expr(proc, ms->tag);
-		Type *union_type = type_deref(ssa_type(parent));
-		GB_ASSERT(is_type_union(union_type));
-
-		ssa_emit_comment(proc, make_string("get union's tag"));
-		ssaValue *tag_index = ssa_emit_union_tag_ptr(proc, parent);
-		tag_index = ssa_emit_load(proc, tag_index);
-
-		ssaValue *data = ssa_emit_conv(proc, parent, t_rawptr);
+		b32 is_union_ptr = false;
+		b32 is_any = false;
+		GB_ASSERT(check_valid_type_match_type(ssa_type(parent), &is_union_ptr, &is_any));
+
+		ssaValue *tag_index = NULL;
+		ssaValue *union_data = NULL;
+		if (is_union_ptr) {
+			ssa_emit_comment(proc, make_string("get union's tag"));
+			tag_index = ssa_emit_load(proc, ssa_emit_union_tag_ptr(proc, parent));
+			union_data = ssa_emit_conv(proc, parent, t_rawptr);
+		}
 
 		ssaBlock *start_block = ssa_add_block(proc, node, "type-match.case.first");
 		ssa_emit_jump(proc, start_block);
@@ -4301,12 +4303,12 @@ void ssa_build_stmt(ssaProcedure *proc, AstNode *node) {
 
 		ast_node(body, BlockStmt, ms->body);
 
-
 		String tag_var_name = ms->var->Ident.string;
 
 		AstNodeArray default_stmts = {};
 		ssaBlock *default_block = NULL;
 
+
 		isize case_count = body->stmts.count;
 		for_array(i, body->stmts) {
 			AstNode *clause = body->stmts[i];
@@ -4325,27 +4327,43 @@ void ssa_build_stmt(ssaProcedure *proc, AstNode *node) {
 			Scope *scope = *map_get(&proc->module->info->scopes, hash_pointer(clause));
 			Entity *tag_var_entity = current_scope_lookup_entity(scope, tag_var_name);
 			GB_ASSERT_MSG(tag_var_entity != NULL, "%.*s", LIT(tag_var_name));
-			ssaValue *tag_var = ssa_add_local(proc, tag_var_entity);
-			ssaValue *data_ptr = ssa_emit_conv(proc, data, tag_var_entity->type);
-			ssa_emit_store(proc, tag_var, data_ptr);
-
 
-
-			Type *bt = type_deref(tag_var_entity->type);
-			ssaValue *index = NULL;
-			Type *ut = base_type(union_type);
-			GB_ASSERT(ut->Record.kind == TypeRecord_Union);
-			for (isize field_index = 1; field_index < ut->Record.field_count; field_index++) {
-				Entity *f = base_type(union_type)->Record.fields[field_index];
-				if (are_types_identical(f->type, bt)) {
-					index = ssa_make_const_int(allocator, field_index);
-					break;
+			ssaBlock *next_cond = NULL;
+			ssaValue *cond = NULL;
+
+			if (is_union_ptr) {
+				Type *bt = type_deref(tag_var_entity->type);
+				ssaValue *index = NULL;
+				Type *ut = base_type(type_deref(ssa_type(parent)));
+				GB_ASSERT(ut->Record.kind == TypeRecord_Union);
+				for (isize field_index = 1; field_index < ut->Record.field_count; field_index++) {
+					Entity *f = ut->Record.fields[field_index];
+					if (are_types_identical(f->type, bt)) {
+						index = ssa_make_const_int(allocator, field_index);
+						break;
+					}
 				}
+				GB_ASSERT(index != NULL);
+
+				ssaValue *tag_var = ssa_add_local(proc, tag_var_entity);
+				ssaValue *data_ptr = ssa_emit_conv(proc, union_data, tag_var_entity->type);
+				ssa_emit_store(proc, tag_var, data_ptr);
+
+				cond = ssa_emit_comp(proc, Token_CmpEq, tag_index, index);
+			} else if (is_any) {
+				Type *type = tag_var_entity->type;
+				ssaValue *any_data = ssa_emit_struct_ev(proc, parent, 1);
+				ssaValue *data = ssa_emit_conv(proc, any_data, make_type_pointer(proc->module->allocator, type));
+				ssa_module_add_value(proc->module, tag_var_entity, data);
+
+				ssaValue *any_ti  = ssa_emit_struct_ev(proc, parent, 0);
+				ssaValue *case_ti = ssa_type_info(proc, type);
+				cond = ssa_emit_comp(proc, Token_CmpEq, any_ti, case_ti);
+			} else {
+				GB_PANIC("Invalid type for type match statement");
 			}
-			GB_ASSERT(index != NULL);
 
-			ssaBlock *next_cond = ssa_add_block(proc, clause, "type-match.case.next");
-			ssaValue *cond = ssa_emit_comp(proc, Token_CmpEq, tag_index, index);
+			next_cond = ssa_add_block(proc, clause, "type-match.case.next");
 			ssa_emit_if(proc, cond, body, next_cond);
 			proc->curr_block = next_cond;