Browse Source

Add trivial `switch` statement check to use a jump table

gingerBill 2 years ago
parent
commit
184563bbe1
4 changed files with 115 additions and 20 deletions
  1. 1 1
      src/tilde_backend.hpp
  2. 1 1
      src/tilde_const.cpp
  3. 4 6
      src/tilde_expr.cpp
  4. 109 12
      src/tilde_stmt.cpp

+ 1 - 1
src/tilde_backend.hpp

@@ -260,7 +260,7 @@ gb_internal TB_DebugType *cg_debug_type(cgModule *m, Type *type);
 
 gb_internal String cg_get_entity_name(cgModule *m, Entity *e);
 
-gb_internal cgValue cg_typeid(cgModule *m, Type *t);
+gb_internal cgValue cg_typeid(cgProcedure *m, Type *t);
 
 gb_internal cgValue cg_emit_ptr_offset(cgProcedure *p, cgValue ptr, cgValue index);
 gb_internal cgValue cg_emit_array_ep(cgProcedure *p, cgValue s, cgValue index);

+ 1 - 1
src/tilde_const.cpp

@@ -63,7 +63,7 @@ gb_internal cgValue cg_const_value(cgModule *m, cgProcedure *p, Type *type, Exac
 		return cg_const_nil(p, type);
 
 	case ExactValue_Typeid:
-		return cg_typeid(m, value.value_typeid);
+		return cg_typeid(p, value.value_typeid);
 
 	case ExactValue_Procedure:
 		{

+ 4 - 6
src/tilde_expr.cpp

@@ -120,7 +120,7 @@ gb_internal cgAddr cg_build_addr_from_entity(cgProcedure *p, Entity *e, Ast *exp
 	return cg_addr(v);
 }
 
-gb_internal cgValue cg_typeid(cgModule *m, Type *t) {
+gb_internal cgValue cg_typeid(cgProcedure *p, Type *t) {
 	GB_ASSERT("TODO(bill): cg_typeid");
 	return {};
 }
@@ -1747,10 +1747,10 @@ gb_internal cgValue cg_build_binary_expr(cgProcedure *p, Ast *expr) {
 			cgValue right = {};
 
 			if (be->left->tav.mode == Addressing_Type) {
-				left = cg_typeid(p->module, be->left->tav.type);
+				left = cg_typeid(p, be->left->tav.type);
 			}
 			if (be->right->tav.mode == Addressing_Type) {
-				right = cg_typeid(p->module, be->right->tav.type);
+				right = cg_typeid(p, be->right->tav.type);
 			}
 			if (left.node == nullptr)  left  = cg_build_expr(p, be->left);
 			if (right.node == nullptr) right = cg_build_expr(p, be->right);
@@ -1944,8 +1944,6 @@ gb_internal cgValue cg_build_expr(cgProcedure *p, Ast *expr) {
 
 
 gb_internal cgValue cg_build_expr_internal(cgProcedure *p, Ast *expr) {
-	cgModule *m = p->module;
-
 	expr = unparen_expr(expr);
 
 	TokenPos expr_pos = ast_token(expr).pos;
@@ -1964,7 +1962,7 @@ gb_internal cgValue cg_build_expr_internal(cgProcedure *p, Ast *expr) {
 		return cg_const_value(p, type, tv.value);
 	} else if (tv.mode == Addressing_Type) {
 		// NOTE(bill, 2023-01-16): is this correct? I hope so at least
-		return cg_typeid(m, tv.type);
+		return cg_typeid(p, tv.type);
 	}
 
 	switch (expr->kind) {

+ 109 - 12
src/tilde_stmt.cpp

@@ -945,6 +945,57 @@ gb_internal void cg_build_for_stmt(cgProcedure *p, Ast *node) {
 	}
 	tb_inst_set_control(p->func, done);
 }
+
+gb_internal bool cg_switch_stmt_can_be_trivial_jump_table(AstSwitchStmt *ss) {
+	if (ss->tag == nullptr) {
+		return false;
+	}
+	bool is_typeid = false;
+	TypeAndValue tv = type_and_value_of_expr(ss->tag);
+	if (is_type_integer(core_type(tv.type))) {
+		if (type_size_of(tv.type) > 8) {
+			return false;
+		}
+		// okay
+	} else if (is_type_typeid(tv.type)) {
+		// okay
+		is_typeid = true;
+	} else {
+		return false;
+	}
+
+	ast_node(body, BlockStmt, ss->body);
+	for (Ast *clause : body->stmts) {
+		ast_node(cc, CaseClause, clause);
+
+		if (cc->list.count == 0) {
+			continue;
+		}
+
+		for (Ast *expr : cc->list) {
+			expr = unparen_expr(expr);
+			if (is_ast_range(expr)) {
+				return false;
+			}
+			if (expr->tav.mode == Addressing_Type) {
+				GB_ASSERT(is_typeid);
+				continue;
+			}
+			tv = type_and_value_of_expr(expr);
+			if (tv.mode != Addressing_Constant) {
+				return false;
+			}
+			if (!is_type_integer(core_type(tv.type))) {
+				return false;
+			}
+		}
+
+	}
+
+	return true;
+}
+
+
 gb_internal void cg_build_switch_stmt(cgProcedure *p, Ast *node) {
 	ast_node(ss, SwitchStmt, node);
 	cg_scope_open(p, ss->scope);
@@ -970,41 +1021,85 @@ gb_internal void cg_build_switch_stmt(cgProcedure *p, Ast *node) {
 	Scope *  default_scope = nullptr;
 	TB_Node *fall = nullptr;
 
-	auto body_blocks = slice_make<TB_Node *>(permanent_allocator(), body->stmts.count);
+
+	auto body_regions = slice_make<TB_Node *>(permanent_allocator(), body->stmts.count);
 	auto body_scopes = slice_make<Scope *>(permanent_allocator(), body->stmts.count);
 	for_array(i, body->stmts) {
 		Ast *clause = body->stmts[i];
 		ast_node(cc, CaseClause, clause);
 
-		body_blocks[i] = tb_inst_region_with_name(p->func, -1, cc->list.count == 0 ? "switch_default_body" : "switch_case_body");
+		body_regions[i] = tb_inst_region_with_name(p->func, -1, cc->list.count == 0 ? "switch_default_body" : "switch_case_body");
 		body_scopes[i] = cc->scope;
 		if (cc->list.count == 0) {
-			default_block = body_blocks[i];
+			default_block = body_regions[i];
 			default_scope = cc->scope;
 		}
 	}
 
+	bool is_trivial = cg_switch_stmt_can_be_trivial_jump_table(ss);
+	if (is_trivial) {
+		isize key_count = 0;
+		for (Ast *clause : body->stmts) {
+			ast_node(cc, CaseClause, clause);
+			key_count += cc->list.count;
+		}
+		TB_SwitchEntry *keys = gb_alloc_array(temporary_allocator(), TB_SwitchEntry, key_count);
+		isize key_index = 0;
+		for_array(i, body->stmts) {
+			Ast *clause = body->stmts[i];
+			ast_node(cc, CaseClause, clause);
+
+			TB_Node *region = body_regions[i];
+			for (Ast *expr : cc->list) {
+				i64 key = 0;
+				expr = unparen_expr(expr);
+				GB_ASSERT(!is_ast_range(expr));
+				if (expr->tav.mode == Addressing_Type) {
+					GB_PANIC("TODO(bill): cg_typeid as i64");
+					// key = cg_typeid(p, expr->tav.value.value_typeid);
+				} else {
+					auto tv = type_and_value_of_expr(expr);
+					GB_ASSERT(tv.mode == Addressing_Constant);
+					key = exact_value_to_i64(tv.value);
+				}
+				keys[key_index++] = {key, region};
+			}
+		}
+		GB_ASSERT(key_index == key_count);
+
+		TB_Node *end_block = done;
+		if (default_block) {
+			end_block = default_block;
+		}
+
+		TB_DataType dt = cg_data_type(tag.type);
+		GB_ASSERT(tag.kind == cgValue_Value);
+		GB_ASSERT(!TB_IS_VOID_TYPE(dt));
+
+		tb_inst_branch(p->func, dt, tag.node, end_block, key_count, keys);
+	}
+
 	for_array(i, body->stmts) {
 		Ast *clause = body->stmts[i];
 		ast_node(cc, CaseClause, clause);
 
-		TB_Node *body = body_blocks[i];
+		TB_Node *body_region = body_regions[i];
 		Scope *body_scope = body_scopes[i];
 		fall = done;
 		if (i+1 < case_count) {
-			fall = body_blocks[i+1];
+			fall = body_regions[i+1];
 		}
 
 		if (cc->list.count == 0) {
 			// default case
 			default_stmts = cc->stmts;
 			default_fall  = fall;
-			default_block = body;
+			GB_ASSERT(default_block == body_region);
 			continue;
 		}
 
 		TB_Node *next_cond = nullptr;
-		for (Ast *expr : cc->list) {
+		if (!is_trivial) for (Ast *expr : cc->list) {
 			expr = unparen_expr(expr);
 
 			next_cond = tb_inst_region_with_name(p->func, -1, "switch_case_next");
@@ -1028,7 +1123,7 @@ gb_internal void cg_build_switch_stmt(cgProcedure *p, Ast *node) {
 			} else {
 				if (expr->tav.mode == Addressing_Type) {
 					GB_ASSERT(is_type_typeid(tag.type));
-					cgValue e = cg_typeid(p->module, expr->tav.type);
+					cgValue e = cg_typeid(p, expr->tav.type);
 					e = cg_emit_conv(p, e, tag.type);
 					cond = cg_emit_comp(p, Token_CmpEq, tag, e);
 				} else {
@@ -1037,16 +1132,16 @@ gb_internal void cg_build_switch_stmt(cgProcedure *p, Ast *node) {
 			}
 
 			GB_ASSERT(cond.kind == cgValue_Value);
-			tb_inst_if(p->func, cond.node, body, next_cond);
+			tb_inst_if(p->func, cond.node, body_region, next_cond);
 			tb_inst_set_control(p->func, next_cond);
 		}
 
-		tb_inst_set_control(p->func, body);
+		tb_inst_set_control(p->func, body_region);
 
 		cg_push_target_list(p, ss->label, done, nullptr, fall);
 		cg_scope_open(p, body_scope);
 		cg_build_stmt_list(p, cc->stmts);
-		cg_scope_close(p, cgDeferExit_Default, body);
+		cg_scope_close(p, cgDeferExit_Default, body_region);
 		cg_pop_target_list(p);
 
 		tb_inst_goto(p->func, done);
@@ -1054,7 +1149,9 @@ gb_internal void cg_build_switch_stmt(cgProcedure *p, Ast *node) {
 	}
 
 	if (default_block != nullptr) {
-		tb_inst_goto(p->func, default_block);
+		if (!is_trivial) {
+			tb_inst_goto(p->func, default_block);
+		}
 		tb_inst_set_control(p->func, default_block);
 
 		cg_push_target_list(p, ss->label, done, nullptr, default_fall);