Parcourir la source

Enforce an `icmp` when casting to `i1` to correct behaviour for booleans which are not 0 or 1

gingerBill il y a 2 ans
Parent
commit
46bb9bc5c7
1 fichiers modifiés avec 32 ajouts et 45 suppressions
  1. 32 45
      src/llvm_backend_expr.cpp

+ 32 - 45
src/llvm_backend_expr.cpp

@@ -1,20 +1,18 @@
 gb_internal lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise);
 
-gb_internal lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, Ast *left, Ast *right, Type *type) {
+gb_internal lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, Ast *left, Ast *right, Type *final_type) {
 	lbModule *m = p->module;
 
 	lbBlock *rhs  = lb_create_block(p, "logical.cmp.rhs");
 	lbBlock *done = lb_create_block(p, "logical.cmp.done");
 
-	type = default_type(type);
-
 	lbValue short_circuit = {};
 	if (op == Token_CmpAnd) {
 		lb_build_cond(p, left, rhs, done);
-		short_circuit = lb_const_bool(m, type, false);
+		short_circuit = lb_const_bool(m, t_llvm_bool, false);
 	} else if (op == Token_CmpOr) {
 		lb_build_cond(p, left, done, rhs);
-		short_circuit = lb_const_bool(m, type, true);
+		short_circuit = lb_const_bool(m, t_llvm_bool, true);
 	}
 
 	if (rhs->preds.count == 0) {
@@ -25,7 +23,7 @@ gb_internal lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, As
 	if (done->preds.count == 0) {
 		lb_start_block(p, rhs);
 		if (lb_is_expr_untyped_const(right)) {
-			return lb_expr_untyped_const_to_typed(m, right, type);
+			return lb_expr_untyped_const_to_typed(m, right, default_type(final_type));
 		}
 		return lb_build_expr(p, right);
 	}
@@ -43,10 +41,11 @@ gb_internal lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, As
 	lb_start_block(p, rhs);
 	lbValue edge = {};
 	if (lb_is_expr_untyped_const(right)) {
-		edge = lb_expr_untyped_const_to_typed(m, right, type);
+		edge = lb_expr_untyped_const_to_typed(m, right, t_llvm_bool);
 	} else {
-		edge = lb_build_expr(p, right);
+		edge = lb_emit_conv(p, lb_build_expr(p, right), t_llvm_bool);
 	}
+	GB_ASSERT(edge.type == t_llvm_bool);
 
 	incoming_values[done->preds.count] = edge.value;
 	incoming_blocks[done->preds.count] = p->curr_block->block;
@@ -54,7 +53,7 @@ gb_internal lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, As
 	lb_emit_jump(p, done);
 	lb_start_block(p, done);	
 	
-	LLVMTypeRef dst_type = lb_type(m, type);
+	LLVMTypeRef dst_type = lb_type(m, t_llvm_bool);
 	LLVMValueRef phi = nullptr;
 	
 	GB_ASSERT(incoming_values.count == incoming_blocks.count);
@@ -67,48 +66,36 @@ gb_internal lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, As
 			break;
 		}
 	}
+
+	lbValue res = {};
 	
 	if (phi_type == nullptr) {
 		phi = LLVMBuildPhi(p->builder, dst_type, "");
 		LLVMAddIncoming(phi, incoming_values.data, incoming_blocks.data, cast(unsigned)incoming_values.count);
-		lbValue res = {};
-		res.type = type;
 		res.value = phi;
-		return res;
-	}
-	
-	for_array(i, incoming_values) {
-		LLVMValueRef incoming_value = incoming_values[i];
-		LLVMTypeRef incoming_type = LLVMTypeOf(incoming_value);
-		
-		if (phi_type != incoming_type) {
-			GB_ASSERT_MSG(LLVMIsConstant(incoming_value), "%s vs %s", LLVMPrintTypeToString(phi_type), LLVMPrintTypeToString(incoming_type));
-			bool ok = !!LLVMConstIntGetZExtValue(incoming_value);
-			incoming_values[i] = LLVMConstInt(phi_type, ok, false);
+		res.type = t_llvm_bool;
+	} else {
+		for_array(i, incoming_values) {
+			LLVMValueRef incoming_value = incoming_values[i];
+			LLVMTypeRef incoming_type = LLVMTypeOf(incoming_value);
+
+			if (phi_type != incoming_type) {
+				GB_ASSERT_MSG(LLVMIsConstant(incoming_value), "%s vs %s", LLVMPrintTypeToString(phi_type), LLVMPrintTypeToString(incoming_type));
+				bool ok = !!LLVMConstIntGetZExtValue(incoming_value);
+				incoming_values[i] = LLVMConstInt(phi_type, ok, false);
+			}
+
 		}
-		
+
+		// NOTE(bill): this now only uses i1 for the logic to prevent issues with corrupted booleans which are not of value 0 or 1 (e.g. 2)
+		// Doing this may produce slightly worse code as a result but it will be correct behaviour
+
+		phi = LLVMBuildPhi(p->builder, phi_type, "");
+		LLVMAddIncoming(phi, incoming_values.data, incoming_blocks.data, cast(unsigned)incoming_values.count);
+		res.value = phi;
+		res.type = t_llvm_bool;
 	}
-	
-	phi = LLVMBuildPhi(p->builder, phi_type, "");
-	LLVMAddIncoming(phi, incoming_values.data, incoming_blocks.data, cast(unsigned)incoming_values.count);
-	
-	LLVMTypeRef i1 = LLVMInt1TypeInContext(m->ctx);
-	if ((phi_type == i1) ^ (dst_type == i1)) {
-		if (phi_type == i1) {
-			phi = LLVMBuildZExt(p->builder, phi, dst_type, "");
-		} else {
-			phi = LLVMBuildTruncOrBitCast(p->builder, phi, dst_type, "");
-		}
-	} else if (lb_sizeof(phi_type) < lb_sizeof(dst_type)) {
-		phi = LLVMBuildZExt(p->builder, phi, dst_type, "");
-	} else {
-		phi = LLVMBuildTruncOrBitCast(p->builder, phi, dst_type, "");	
-	}		
-	
-	lbValue res = {};
-	res.type = type;
-	res.value = phi;
-	return res;
+	return lb_emit_conv(p, res, default_type(final_type));
 }
 
 
@@ -1566,7 +1553,7 @@ gb_internal lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) {
 	// bool <-> llvm bool
 	if (is_type_boolean(src) && dst == t_llvm_bool) {
 		lbValue res = {};
-		res.value = LLVMBuildTrunc(p->builder, value.value, lb_type(m, dst), "");
+		res.value = LLVMBuildICmp(p->builder, LLVMIntNE, value.value, LLVMConstNull(lb_type(m, src)), "");
 		res.type = t;
 		return res;
 	}