瀏覽代碼

Reduce stack usage of some type `switch` `case`s

gingerBill 2 年之前
父節點
當前提交
7d4e9497eb
共有 1 個文件被更改,包括 66 次插入6 次删除
  1. 66 6
      src/llvm_backend_stmt.cpp

+ 66 - 6
src/llvm_backend_stmt.cpp

@@ -1397,6 +1397,52 @@ gb_internal void lb_build_type_switch_stmt(lbProcedure *p, AstTypeSwitchStmt *ss
 		switch_instr = LLVMBuildSwitch(p->builder, tag.value, else_block->block, cast(unsigned)num_cases);
 	}
 
+	bool all_by_reference = false;
+	for (Ast *clause : body->stmts) {
+		ast_node(cc, CaseClause, clause);
+		if (cc->list.count != 1) {
+			continue;
+		}
+		Entity *case_entity = implicit_entity_of_node(clause);
+		all_by_reference |= (case_entity->flags & EntityFlag_Value) == 0;
+		break;
+	}
+
+	// NOTE(bill, 2023-02-17): In the case of a pass by value, the value does need to be copied
+	// to prevent errors such as these:
+	//
+	//	switch v in some_union {
+	//	case i32:
+	//		fmt.println(v) // 'i32'
+	//		some_union = f32(123)
+	//		fmt.println(v) // if `v` is an implicit reference, then the data is now completely corrupted
+	//	case f32:
+	//		fmt.println(v)
+	//	}
+	//
+	lbAddr backing_data = {};
+	if (!all_by_reference) {
+		bool variants_found = false;
+		i64 max_size = 0;
+		i64 max_align = 1;
+		for (Ast *clause : body->stmts) {
+			ast_node(cc, CaseClause, clause);
+			if (cc->list.count != 1) {
+				continue;
+			}
+			Entity *case_entity = implicit_entity_of_node(clause);
+			max_size = gb_max(max_size, type_size_of(case_entity->type));
+			max_align = gb_max(max_align, type_align_of(case_entity->type));
+			variants_found = true;
+		}
+		if (variants_found) {
+			Type *t = alloc_type_array(t_u8, max_size);
+			backing_data = lb_add_local(p, t, nullptr, false, true);
+			GB_ASSERT(lb_try_update_alignment(backing_data.addr, cast(unsigned)max_align));
+		}
+	}
+	lbValue backing_ptr = backing_data.addr;
+
 	for (Ast *clause : body->stmts) {
 		ast_node(cc, CaseClause, clause);
 		lb_open_scope(p, cc->scope);
@@ -1427,8 +1473,6 @@ gb_internal void lb_build_type_switch_stmt(lbProcedure *p, AstTypeSwitchStmt *ss
 
 		Entity *case_entity = implicit_entity_of_node(clause);
 
-		lbValue value = parent_value;
-
 		lb_start_block(p, body);
 
 		bool by_reference = (case_entity->flags & EntityFlag_Value) == 0;
@@ -1444,13 +1488,29 @@ gb_internal void lb_build_type_switch_stmt(lbProcedure *p, AstTypeSwitchStmt *ss
 			Type *ct = case_entity->type;
 			Type *ct_ptr = alloc_type_pointer(ct);
 
-			value = lb_emit_conv(p, data, ct_ptr);
-			if (!by_reference) {
-				value = lb_emit_load(p, value);
+			lbValue ptr = {};
+
+			if (backing_data.addr.value) { // by value
+				GB_ASSERT(!by_reference);
+				// make a copy of the case value
+				lb_mem_copy_non_overlapping(p,
+				                            backing_ptr, // dst
+				                            data,        // src
+				                            lb_const_int(p->module, t_int, type_size_of(case_entity->type)));
+				ptr = lb_emit_conv(p, backing_ptr, ct_ptr);
+
+			} else { // by reference
+				GB_ASSERT(by_reference);
+				ptr = lb_emit_conv(p, data, ct_ptr);
 			}
+			GB_ASSERT(are_types_identical(case_entity->type, type_deref(ptr.type)));
+			lb_add_entity(p->module, case_entity, ptr);
+			lb_add_debug_local_variable(p, ptr.value, case_entity->type, case_entity->token);
+		} else {
+			// TODO(bill): is the correct expected behaviour?
+			lb_store_type_case_implicit(p, clause, parent_value);
 		}
 
-		lb_store_type_case_implicit(p, clause, value);
 		lb_type_case_body(p, ss->label, clause, body, done);
 	}