Browse Source

Allow `case nil` within a type switch statement (experimental idea)

gingerBill 2 years ago
parent
commit
93f7d3bfb9
2 changed files with 40 additions and 7 deletions
  1. 26 2
      src/check_stmt.cpp
  2. 14 5
      src/llvm_backend_stmt.cpp

+ 26 - 2
src/check_stmt.cpp

@@ -1184,6 +1184,8 @@ gb_internal void check_type_switch_stmt(CheckerContext *ctx, Ast *node, u32 mod_
 		return;
 	}
 
+
+	Ast *nil_seen = nullptr;
 	PtrSet<Type *> seen = {};
 	defer (ptr_set_destroy(&seen));
 
@@ -1194,6 +1196,7 @@ gb_internal void check_type_switch_stmt(CheckerContext *ctx, Ast *node, u32 mod_
 		}
 		ast_node(cc, CaseClause, stmt);
 
+		bool saw_nil = false;
 		// TODO(bill): Make robust
 		Type *bt = base_type(type_deref(x.type));
 
@@ -1202,6 +1205,25 @@ gb_internal void check_type_switch_stmt(CheckerContext *ctx, Ast *node, u32 mod_
 			if (type_expr != nullptr) { // Otherwise it's a default expression
 				Operand y = {};
 				check_expr_or_type(ctx, &y, type_expr);
+
+				if (is_operand_nil(y)) {
+					if (!type_has_nil(type_deref(x.type))) {
+						error(type_expr, "'nil' case is not allowed for the type '%s'", type_to_string(type_deref(x.type)));
+						continue;
+					}
+					saw_nil = true;
+
+					if (nil_seen) {
+						ERROR_BLOCK();
+						error(type_expr, "'nil' case has already been handled previously");
+						error_line("\t 'nil' was already previously seen at %s", token_pos_to_string(ast_token(nil_seen).pos));
+					} else {
+						nil_seen = type_expr;
+					}
+					case_type = y.type;
+					continue;
+				}
+
 				if (y.mode != Addressing_Type) {
 					gbString str = expr_to_string(type_expr);
 					error(type_expr, "Expected a type as a case, got %s", str);
@@ -1255,14 +1277,16 @@ gb_internal void check_type_switch_stmt(CheckerContext *ctx, Ast *node, u32 mod_
 			is_reference = true;
 		}
 
-		if (cc->list.count > 1) {
+		if (cc->list.count > 1 || saw_nil) {
 			case_type = nullptr;
 		}
 		if (case_type == nullptr) {
 			case_type = x.type;
 		}
 		if (switch_kind == TypeSwitch_Any) {
-			add_type_info_type(ctx, case_type);
+			if (!is_type_untyped(case_type)) {
+				add_type_info_type(ctx, case_type);
+			}
 		}
 
 		check_open_scope(ctx, stmt);

+ 14 - 5
src/llvm_backend_stmt.cpp

@@ -1423,9 +1423,11 @@ gb_internal void lb_build_type_switch_stmt(lbProcedure *p, AstTypeSwitchStmt *ss
 				continue;
 			}
 			Entity *case_entity = implicit_entity_of_node(clause);
-			max_size = gb_max(max_size, type_size_of(case_entity->type));
-			max_align = gb_max(max_align, type_align_of(case_entity->type));
-			variants_found = true;
+			if (!is_type_untyped_nil(case_entity->type)) {
+				max_size = gb_max(max_size, type_size_of(case_entity->type));
+				max_align = gb_max(max_align, type_align_of(case_entity->type));
+				variants_found = true;
+			}
 		}
 		if (variants_found) {
 			Type *t = alloc_type_array(t_u8, max_size);
@@ -1449,6 +1451,8 @@ gb_internal void lb_build_type_switch_stmt(lbProcedure *p, AstTypeSwitchStmt *ss
 		if (p->debug_info != nullptr) {
 			LLVMSetCurrentDebugLocation2(p->builder, lb_debug_location_from_ast(p, clause));
 		}
+
+		bool saw_nil = false;
 		for (Ast *type_expr : cc->list) {
 			Type *case_type = type_of_expr(type_expr);
 			lbValue on_val = {};
@@ -1457,7 +1461,12 @@ gb_internal void lb_build_type_switch_stmt(lbProcedure *p, AstTypeSwitchStmt *ss
 				on_val = lb_const_union_tag(m, ut, case_type);
 
 			} else if (switch_kind == TypeSwitch_Any) {
-				on_val = lb_typeid(m, case_type);
+				if (is_type_untyped_nil(case_type)) {
+					saw_nil = true;
+					on_val = lb_const_nil(m, t_typeid);
+				} else {
+					on_val = lb_typeid(m, case_type);
+				}
 			}
 			GB_ASSERT(on_val.value != nullptr);
 			LLVMAddCase(switch_instr, on_val.value, body->block);
@@ -1469,7 +1478,7 @@ gb_internal void lb_build_type_switch_stmt(lbProcedure *p, AstTypeSwitchStmt *ss
 
 		bool by_reference = (case_entity->flags & EntityFlag_Value) == 0;
 
-		if (cc->list.count == 1) {
+		if (cc->list.count == 1 && !saw_nil) {
 			lbValue data = {};
 			if (switch_kind == TypeSwitch_Union) {
 				data = union_data;