Jelajahi Sumber

SSA - Basic block optimizations

Ginger Bill 9 tahun lalu
induk
melakukan
2e0b260d3a
5 mengubah file dengan 164 tambahan dan 72 penghapusan
  1. 2 6
      src/array.cpp
  2. 1 1
      src/checker/stmt.cpp
  3. 19 19
      src/codegen/print_llvm.cpp
  4. 141 45
      src/codegen/ssa.cpp
  5. 1 1
      src/main.cpp

+ 2 - 6
src/array.cpp

@@ -8,16 +8,12 @@ struct Array {
 	isize capacity;
 
 	T &operator[](isize index) {
-		if (count > 0) {
-			GB_ASSERT_MSG(0 <= index && index < count, "Index out of bounds");
-		}
+		GB_ASSERT_MSG(0 <= index && index < count, "Index out of bounds");
 		return data[index];
 	}
 
 	T const &operator[](isize index) const {
-		if (count > 0) {
-			GB_ASSERT_MSG(0 <= index && index < count, "Index out of bounds");
-		}
+		GB_ASSERT_MSG(0 <= index && index < count, "Index out of bounds");
 		return data[index];
 	}
 };

+ 1 - 1
src/checker/stmt.cpp

@@ -1279,7 +1279,7 @@ void check_stmt(Checker *c, AstNode *node, u32 flags) {
 			}
 			ast_node(cc, CaseClause, stmt);
 
-			AstNode *type_expr = cc->list[0];
+			AstNode *type_expr = cc->list.count > 0 ? cc->list[0] : NULL;
 			Type *tag_type = NULL;
 			if (type_expr != NULL) { // Otherwise it's a default expression
 				Operand y = {};

+ 19 - 19
src/codegen/print_llvm.cpp

@@ -567,7 +567,7 @@ void ssa_print_value(ssaFileBuffer *f, ssaModule *m, ssaValue *value, Type *type
 		ssa_print_encoded_global(f, value->Proc.name, (value->Proc.tags & (ProcTag_foreign|ProcTag_link_name)) != 0);
 		break;
 	case ssaValue_Instr:
-		ssa_fprintf(f, "%%%d", value->id);
+		ssa_fprintf(f, "%%%d", value->index);
 		break;
 	}
 }
@@ -591,7 +591,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 
 	case ssaInstr_Local: {
 		Type *type = instr->Local.entity->type;
-		ssa_fprintf(f, "%%%d = alloca ", value->id);
+		ssa_fprintf(f, "%%%d = alloca ", value->index);
 		ssa_print_type(f, m, type);
 		ssa_fprintf(f, ", align %lld\n", type_align_of(m->sizes, m->allocator, type));
 	} break;
@@ -602,11 +602,11 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 		ssa_print_type(f, m, type);
 		ssa_fprintf(f, " zeroinitializer, ");
 		ssa_print_type(f, m, type);
-		ssa_fprintf(f, "* %%%d\n", instr->ZeroInit.address->id);
+		ssa_fprintf(f, "* %%%d\n", instr->ZeroInit.address->index);
 	} break;
 
 	case ssaInstr_Store: {
-		Type *type = ssa_type(instr);
+		Type *type = ssa_type(instr->Store.value);
 		ssa_fprintf(f, "store ");
 		ssa_print_type(f, m, type);
 		ssa_fprintf(f, " ");
@@ -620,7 +620,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 
 	case ssaInstr_Load: {
 		Type *type = instr->Load.type;
-		ssa_fprintf(f, "%%%d = load ", value->id);
+		ssa_fprintf(f, "%%%d = load ", value->index);
 		ssa_print_type(f, m, type);
 		ssa_fprintf(f, ", ");
 		ssa_print_type(f, m, type);
@@ -631,7 +631,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 
 	case ssaInstr_GetElementPtr: {
 		Type *et = instr->GetElementPtr.elem_type;
-		ssa_fprintf(f, "%%%d = getelementptr ", value->id);
+		ssa_fprintf(f, "%%%d = getelementptr ", value->index);
 		if (instr->GetElementPtr.inbounds) {
 			ssa_fprintf(f, "inbounds ");
 		}
@@ -653,16 +653,16 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 	} break;
 
 	case ssaInstr_Phi: {
-		ssa_fprintf(f, "%%%d = phi ", value->id);
+		ssa_fprintf(f, "%%%d = phi ", value->index);
 		ssa_print_type(f, m, instr->Phi.type);
-		ssa_fprintf(f, " ", value->id);
+		ssa_fprintf(f, " ", value->index);
 
 		for (isize i = 0; i < instr->Phi.edges.count; i++) {
-			ssaValue *edge = instr->Phi.edges[i];
 			if (i > 0) {
 				ssa_fprintf(f, ", ");
 			}
 
+			ssaValue *edge = instr->Phi.edges[i];
 			ssaBlock *block = NULL;
 			if (instr->parent != NULL &&
 			    i < instr->parent->preds.count) {
@@ -680,7 +680,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 
 	case ssaInstr_ExtractValue: {
 		Type *et = instr->ExtractValue.elem_type;
-		ssa_fprintf(f, "%%%d = extractvalue ", value->id);
+		ssa_fprintf(f, "%%%d = extractvalue ", value->index);
 
 		ssa_print_type(f, m, et);
 		ssa_fprintf(f, " ");
@@ -689,7 +689,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 	} break;
 
 	case ssaInstr_NoOp: {;
-		ssa_fprintf(f, "%%%d = add i32 0, 0\n", value->id);
+		ssa_fprintf(f, "%%%d = add i32 0, 0\n", value->index);
 	} break;
 
 	case ssaInstr_Br: {;
@@ -698,7 +698,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 			ssa_print_type(f, m, t_bool);
 			ssa_fprintf(f, " ");
 			ssa_print_value(f, m, instr->Br.cond, t_bool);
-			ssa_fprintf(f, ", ", instr->Br.cond->id);
+			ssa_fprintf(f, ", ", instr->Br.cond->index);
 		}
 		ssa_fprintf(f, "label ");
 		ssa_fprintf(f, "%%"); ssa_print_block_name(f, instr->Br.true_block);
@@ -727,7 +727,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 
 	case ssaInstr_Conv: {
 		auto *c = &instr->Conv;
-		ssa_fprintf(f, "%%%d = %.*s ", value->id, LIT(ssa_conv_strings[c->kind]));
+		ssa_fprintf(f, "%%%d = %.*s ", value->index, LIT(ssa_conv_strings[c->kind]));
 		ssa_print_type(f, m, c->from);
 		ssa_fprintf(f, " ");
 		ssa_print_value(f, m, c->value, c->from);
@@ -749,7 +749,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 			elem_type = base_type(elem_type->Vector.elem);
 		}
 
-		ssa_fprintf(f, "%%%d = ", value->id);
+		ssa_fprintf(f, "%%%d = ", value->index);
 
 		if (gb_is_between(bo->op, Token__ComparisonBegin+1, Token__ComparisonEnd-1)) {
 			if (is_type_string(elem_type)) {
@@ -852,7 +852,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 		auto *call = &instr->Call;
 		Type *result_type = call->type;
 		if (result_type) {
-			ssa_fprintf(f, "%%%d = ", value->id);
+			ssa_fprintf(f, "%%%d = ", value->index);
 		}
 		ssa_fprintf(f, "call ");
 		if (result_type) {
@@ -887,7 +887,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 	} break;
 
 	case ssaInstr_Select: {
-		ssa_fprintf(f, "%%%d = select i1 ", value->id);
+		ssa_fprintf(f, "%%%d = select i1 ", value->index);
 		ssa_print_value(f, m, instr->Select.cond, t_bool);
 		ssa_fprintf(f, ", ");
 		ssa_print_type(f, m, ssa_type(instr->Select.true_value));
@@ -902,7 +902,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 
 	case ssaInstr_ExtractElement: {
 		Type *vt = ssa_type(instr->ExtractElement.vector);
-		ssa_fprintf(f, "%%%d = extractelement ", value->id);
+		ssa_fprintf(f, "%%%d = extractelement ", value->index);
 
 		ssa_print_type(f, m, vt);
 		ssa_fprintf(f, " ");
@@ -918,7 +918,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 	case ssaInstr_InsertElement: {
 		auto *ie = &instr->InsertElement;
 		Type *vt = ssa_type(ie->vector);
-		ssa_fprintf(f, "%%%d = insertelement ", value->id);
+		ssa_fprintf(f, "%%%d = insertelement ", value->index);
 
 		ssa_print_type(f, m, vt);
 		ssa_fprintf(f, " ");
@@ -940,7 +940,7 @@ void ssa_print_instr(ssaFileBuffer *f, ssaModule *m, ssaValue *value) {
 	case ssaInstr_ShuffleVector: {
 		auto *sv = &instr->ShuffleVector;
 		Type *vt = ssa_type(sv->vector);
-		ssa_fprintf(f, "%%%d = shufflevector ", value->id);
+		ssa_fprintf(f, "%%%d = shufflevector ", value->index);
 
 		ssa_print_type(f, m, vt);
 		ssa_fprintf(f, " ");

+ 141 - 45
src/codegen/ssa.cpp

@@ -76,6 +76,7 @@ struct ssaBlock {
 	String label;
 	ssaProcedure *parent;
 	b32 added;
+	b32 is_dead;
 
 	Array<ssaValue *> instrs;
 	Array<ssaValue *> locals;
@@ -320,8 +321,7 @@ enum ssaValueKind {
 
 struct ssaValue {
 	ssaValueKind kind;
-	i32 id;
-
+	i32 index;
 	union {
 		struct {
 			Type *     type;
@@ -503,12 +503,10 @@ void ssa_destroy_module(ssaModule *m) {
 
 
 Type *ssa_type(ssaValue *value);
-Type *ssa_type(ssaInstr *instr) {
+Type *ssa_instr_type(ssaInstr *instr) {
 	switch (instr->kind) {
 	case ssaInstr_Local:
 		return instr->Local.type;
-	case ssaInstr_Store:
-		return ssa_type(instr->Store.value);
 	case ssaInstr_Load:
 		return instr->Load.type;
 	case ssaInstr_GetElementPtr:
@@ -528,8 +526,9 @@ Type *ssa_type(ssaInstr *instr) {
 	case ssaInstr_Call: {
 		Type *pt = base_type(instr->Call.type);
 		if (pt != NULL) {
-			if (pt->kind == Type_Tuple && pt->Tuple.variable_count == 1)
+			if (pt->kind == Type_Tuple && pt->Tuple.variable_count == 1) {
 				return pt->Tuple.variables[0]->type;
+			}
 			return pt;
 		}
 		return NULL;
@@ -565,11 +564,17 @@ Type *ssa_type(ssaValue *value) {
 	case ssaValue_Proc:
 		return value->Proc.type;
 	case ssaValue_Instr:
-		return ssa_type(&value->Instr);
+		return ssa_instr_type(&value->Instr);
 	}
 	return NULL;
 }
 
+void ssa_set_instr_parent(ssaValue *instr, ssaBlock *parent) {
+	if (instr->kind == ssaValue_Instr) {
+		instr->Instr.parent = parent;
+	}
+}
+
 ssaDebugInfo *ssa_add_debug_info_file(ssaProcedure *proc, AstFile *file) {
 	if (!proc->module->generate_debug_info) {
 		return NULL;
@@ -1092,6 +1097,35 @@ ssaBlock *ssa_add_block(ssaProcedure *proc, AstNode *node, char *label) {
 	return block;
 }
 
+void ssa_block_replace_pred(ssaBlock *b, ssaBlock *from, ssaBlock *to) {
+	for_array(i, b->preds) {
+		ssaBlock *pred = b->preds[i];
+		if (pred == from) {
+			b->preds[i] = to;
+		}
+	}
+}
+
+void ssa_block_replace_succ(ssaBlock *b, ssaBlock *from, ssaBlock *to) {
+	for_array(i, b->succs) {
+		ssaBlock *succ = b->succs[i];
+		if (succ == from) {
+			b->succs[i] = to;
+		}
+	}
+}
+
+b32 ssa_block_has_phi(ssaBlock *b) {
+	return b->instrs[0]->Instr.kind == ssaInstr_Phi;
+}
+
+
+
+
+
+
+
+
 void ssa_build_stmt(ssaProcedure *proc, AstNode *s);
 void ssa_emit_no_op(ssaProcedure *proc);
 void ssa_emit_jump(ssaProcedure *proc, ssaBlock *block);
@@ -1246,7 +1280,53 @@ void ssa_begin_procedure_body(ssaProcedure *proc) {
 	}
 }
 
-void ssa_remove_pred(ssaBlock *a, ssaBlock *b) {
+
+b32 ssa_is_instr_jump(ssaValue *v) {
+	if (v->kind != ssaValue_Instr) {
+		return false;
+	}
+	ssaInstr *i = &v->Instr;
+	if (i->kind != ssaInstr_Br) {
+		return false;
+	}
+
+	return i->Br.false_block == NULL;
+}
+
+Array<ssaValue *> ssa_get_block_phi_nodes(ssaBlock *b) {
+	Array<ssaValue *> phis = {};
+	for_array(i, b->instrs) {
+		ssaInstr *instr = &b->instrs[i]->Instr;
+		if (instr->kind != ssaInstr_Phi) {
+			phis = b->instrs;
+			phis.count = i;
+			return phis;
+		}
+	}
+	return phis;
+}
+
+void ssa_remove_pred(ssaBlock *b, ssaBlock *p) {
+	auto phis = ssa_get_block_phi_nodes(b);
+
+	isize i = 0;
+	for_array(j, b->preds) {
+		ssaBlock *pred = b->preds[j];
+		if (pred != p) {
+			b->preds[i] = b->preds[j];
+			for_array(k, phis) {
+				auto *phi = &phis[k]->Instr.Phi;
+				phi->edges[i] = phi->edges[j];
+			}
+			i++;
+		}
+	}
+
+	b->preds.count = i;
+	for_array(k, phis) {
+		auto *phi = &phis[k]->Instr.Phi;
+		phi->edges.count = i;
+	}
 
 }
 
@@ -1254,12 +1334,12 @@ void ssa_remove_dead_blocks(ssaProcedure *proc) {
 	isize j = 0;
 	for_array(i, proc->blocks) {
 		ssaBlock *b = proc->blocks[i];
-		if (b != NULL) {
-			// NOTE(bill): Swap order
-			b->index = j;
-			proc->blocks[j] = b;
-			j++;
+		if (b == NULL || b->is_dead) {
+			continue;
 		}
+		// NOTE(bill): Swap order
+		b->index = j;
+		proc->blocks[j++] = b;
 	}
 
 	proc->blocks.count = j;
@@ -1292,41 +1372,70 @@ void ssa_remove_unreachable_blocks(ssaProcedure *proc) {
 			for_array(j, b->succs) {
 				ssaBlock *c = b->succs[j];
 				if (c->index == BLACK) {
-					// ssa_remove_pred(c, b);
+					ssa_remove_pred(c, b);
 				}
 			}
 			// NOTE(bill): Mark as empty but don't actually free it
 			// As it's been allocated with an arena
-			proc->blocks[i] = NULL;
+			b->is_dead = true;
 		}
 	}
 	ssa_remove_dead_blocks(proc);
 }
 
+b32 ssa_opt_block_fusion(ssaProcedure *proc, ssaBlock *a) {
+	if (a->succs.count != 1) {
+		return false;
+	}
+	ssaBlock *b = a->succs[0];
+	if (b->preds.count != 1) {
+		return false;
+	}
+
+	if (ssa_block_has_phi(b)) {
+		return false;
+	}
+
+	array_pop(&a->instrs); // Remove branch at end
+	for_array(i, b->instrs) {
+		array_add(&a->instrs, b->instrs[i]);
+		ssa_set_instr_parent(b->instrs[i], a);
+	}
+
+	array_clear(&a->succs);
+	for_array(i, b->succs) {
+		array_add(&a->succs, b->succs[i]);
+	}
+
+	// Fix preds links
+	for_array(i, b->succs) {
+		ssa_block_replace_pred(b->succs[i], b, a);
+	}
+
+	proc->blocks[b->index]->is_dead = true;
+	return true;
+}
 
 void ssa_optimize_blocks(ssaProcedure *proc) {
 	ssa_remove_unreachable_blocks(proc);
 
-#if 0
-	b32 changed = false;
-	do {
+#if 1
+	b32 changed = true;
+	while (changed) {
+		changed = false;
 		for_array(i, proc->blocks) {
 			ssaBlock *b = proc->blocks[i];
 			if (b == NULL) {
 				continue;
 			}
+			GB_ASSERT(b->index == i);
 
-			// if (ssa_fuse_blocks(proc, b)) {
-			// 	changed = true;
-			// }
-
-			if (ssa_jump_threading(proc, b)) {
-				// x -> y -> z becomes x -> z if y is just a jump
+			if (ssa_opt_block_fusion(proc, b)) {
 				changed = true;
-				continue;
 			}
+			// TODO(bill): other simple block optimizations
 		}
-	} while (changed);
+	}
 #endif
 
 	ssa_remove_dead_blocks(proc);
@@ -1346,33 +1455,20 @@ void ssa_end_procedure_body(ssaProcedure *proc) {
 
 	ssa_optimize_blocks(proc);
 
-
-// Number blocks and registers
-	i32 reg_id = 0;
+// Number registers
+	i32 reg_index = 0;
 	for_array(i, proc->blocks) {
 		ssaBlock *b = proc->blocks[i];
+		b->index = i;
 		for_array(j, b->instrs) {
 			ssaValue *value = b->instrs[j];
 			GB_ASSERT(value->kind == ssaValue_Instr);
 			ssaInstr *instr = &value->Instr;
-			// NOTE(bill): Ignore non-returning instructions
-			switch (instr->kind) {
-			case ssaInstr_Comment:
-			case ssaInstr_ZeroInit:
-			case ssaInstr_Store:
-			case ssaInstr_Br:
-			case ssaInstr_Ret:
-			case ssaInstr_Unreachable:
-			case ssaInstr_StartupRuntime:
+			if (ssa_instr_type(instr) == NULL) { // NOTE(bill): Ignore non-returning instructions
 				continue;
-			case ssaInstr_Call:
-				if (instr->Call.type == NULL) {
-					continue;
-				}
-				break;
 			}
-			value->id = reg_id;
-			reg_id++;
+			value->index = reg_index;
+			reg_index++;
 		}
 	}
 }

+ 1 - 1
src/main.cpp

@@ -1,4 +1,4 @@
-#define DISPLAY_TIMING
+// #define DISPLAY_TIMING
 
 #include "common.cpp"
 #include "unicode.cpp"