Browse Source

Write a `log(n)` fallback for `llvm_vector_reduce_add`

This may be what LLVM does at any rate
gingerBill 3 years ago
parent
commit
3794d2417d
2 changed files with 72 additions and 6 deletions
  1. 11 1
      src/common.cpp
  2. 61 5
      src/llvm_backend_utility.cpp

+ 11 - 1
src/common.cpp

@@ -443,7 +443,17 @@ u64 ceil_log2(u64 x) {
 	return cast(u64)(bit_set_count(x) - 1 - y);
 	return cast(u64)(bit_set_count(x) - 1 - y);
 }
 }
 
 
-
+// Returns the largest power of two that is less than or equal to `n`.
+// An input of 0 yields 0.
+u32 prev_pow2(u32 n) {
+	if (n != 0) {
+		// Smear the highest set bit into every lower bit position
+		// (shifts of 1, 2, 4, 8 and 16 cover all 32 bits).
+		for (u32 shift = 1; shift <= 16; shift <<= 1) {
+			n |= n >> shift;
+		}
+		// `n` is now all ones from its top bit downwards; subtracting the
+		// lower half of those ones leaves only the top bit.
+		n -= n >> 1;
+	}
+	return n;
+}
 i32 prev_pow2(i32 n) {
 i32 prev_pow2(i32 n) {
 	if (n <= 0) {
 	if (n <= 0) {
 		return 0;
 		return 0;

+ 61 - 5
src/llvm_backend_utility.cpp

@@ -1563,6 +1563,40 @@ LLVMValueRef llvm_vector_broadcast(lbProcedure *p, LLVMValueRef value, unsigned
 	return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, "");
 	return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, "");
 }
 }
 
 
+// Reduces a vector `value` to a single scalar by repeatedly splitting it into
+// a lower and an upper half, combining the halves lane-wise with `op_code`,
+// and continuing on the (half as long) result until one lane remains.
+// Lane 0 of the final vector is extracted and returned.
+//
+// The vector length must be a power of two (asserted below); a length of 1
+// simply extracts lane 0.
+//
+// NOTE(review): `llvm_mask_iota` is not defined in this view. Its arguments
+// are taken to be (module, start, count), which is how the other callers in
+// this file use it (e.g. `llvm_mask_iota(p->module, len_pow_2, len-len_pow_2)`
+// in `llvm_vector_reduce_add`) - confirm against its definition.
+LLVMValueRef llvm_vector_shuffle_reduction(lbProcedure *p, LLVMValueRef value, LLVMOpcode op_code) {
+	LLVMValueRef v_zero32 = lb_const_int(p->module, t_u32, 0).value;
+	unsigned len = LLVMGetVectorSize(LLVMTypeOf(value));
+	if (len == 1) {
+		return LLVMBuildExtractElement(p->builder, value, v_zero32, "");
+	}
+	GB_ASSERT((len & (len-1)) == 0);
+	
+	for (unsigned i = len; i != 1; i >>= 1) {
+		// Both halves must have the same lane count so that the binary op
+		// below receives operands of identical vector type. The upper mask
+		// therefore starts at `half` and is `half` lanes long (previously it
+		// was built with a count of `i`, giving `rhs` twice as many lanes as
+		// `lhs`).
+		unsigned half = i/2;
+		LLVMValueRef lhs_mask = llvm_mask_iota(p->module, 0, half);
+		LLVMValueRef rhs_mask = llvm_mask_iota(p->module, half, half);
+		LLVMValueRef lhs = LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), lhs_mask, "");
+		LLVMValueRef rhs = LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), rhs_mask, "");
+		
+		value = LLVMBuildBinOp(p->builder, op_code, lhs, rhs, "");
+	}
+	return LLVMBuildExtractElement(p->builder, value, v_zero32, "");
+}
+
+// Widens a vector whose lane count is not a power of two by shuffling it
+// together with an all-zero vector of the same type, so the extra lanes come
+// from the zero operand. Vectors whose length already is a power of two
+// (including length 1) are returned unchanged.
+//
+// NOTE(review): `next_pow2` is defined outside this view; presumably it
+// returns the smallest power of two >= its argument - confirm.
+LLVMValueRef llvm_vector_expand_to_power_of_two(lbProcedure *p, LLVMValueRef value) {
+	LLVMTypeRef vector_type = LLVMTypeOf(value);
+	unsigned len = LLVMGetVectorSize(vector_type);
+
+	// `len & (len-1)` is zero exactly when at most one bit of `len` is set,
+	// which also covers the single-lane case.
+	bool already_pow2 = (len & (len-1)) == 0;
+	if (already_pow2) {
+		return value;
+	}
+
+	unsigned expanded_len = cast(unsigned)next_pow2(cast(i64)len);
+	LLVMValueRef mask = llvm_mask_iota(p->module, 0, expanded_len);
+	LLVMValueRef zero_lanes = LLVMConstNull(vector_type);
+	return LLVMBuildShuffleVector(p->builder, value, zero_lanes, mask, "");
+}
+
 LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 	LLVMTypeRef type = LLVMTypeOf(value);
 	LLVMTypeRef type = LLVMTypeOf(value);
 	GB_ASSERT(LLVMGetTypeKind(type) == LLVMVectorTypeKind);
 	GB_ASSERT(LLVMGetTypeKind(type) == LLVMVectorTypeKind);
@@ -1571,11 +1605,11 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 	if (len == 0) {
 	if (len == 0) {
 		return LLVMConstNull(type);
 		return LLVMConstNull(type);
 	}
 	}
-	
+
 	char const *name = nullptr;
 	char const *name = nullptr;
 	i32 value_offset = 0;
 	i32 value_offset = 0;
 	i32 value_count  = 0;
 	i32 value_count  = 0;
-	
+
 	switch (LLVMGetTypeKind(elem)) {
 	switch (LLVMGetTypeKind(elem)) {
 	case LLVMHalfTypeKind:
 	case LLVMHalfTypeKind:
 	case LLVMFloatTypeKind:
 	case LLVMFloatTypeKind:
@@ -1593,7 +1627,7 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 		GB_PANIC("invalid vector type %s", LLVMPrintTypeToString(type));
 		GB_PANIC("invalid vector type %s", LLVMPrintTypeToString(type));
 		break;
 		break;
 	}
 	}
-	
+
 	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
 	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
 	if (id != 0) {
 	if (id != 0) {
 		LLVMTypeRef types[1] = {};
 		LLVMTypeRef types[1] = {};
@@ -1606,9 +1640,9 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 		LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, "");
 		LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, "");
 		return call;
 		return call;
 	}
 	}
-	
+
 	// Manual reduce
 	// Manual reduce
-	
+#if 0
 	LLVMValueRef sum = LLVMBuildExtractElement(p->builder, value, lb_const_int(p->module, t_u32, 0).value, "");
 	LLVMValueRef sum = LLVMBuildExtractElement(p->builder, value, lb_const_int(p->module, t_u32, 0).value, "");
 	for (unsigned i = 0; i < len; i++) {
 	for (unsigned i = 0; i < len; i++) {
 		LLVMValueRef val = LLVMBuildExtractElement(p->builder, value, lb_const_int(p->module, t_u32, i).value, "");
 		LLVMValueRef val = LLVMBuildExtractElement(p->builder, value, lb_const_int(p->module, t_u32, i).value, "");
@@ -1619,6 +1653,28 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 		}
 		}
 	}
 	}
 	return sum;
 	return sum;
+#else
+	LLVMOpcode op_code = LLVMFAdd;
+	if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) {
+		op_code = LLVMAdd;
+	}
+
+	unsigned len_pow_2 = prev_pow2(len);
+	if (len_pow_2 == len) {
+		return llvm_vector_shuffle_reduction(p, value, op_code);
+	} else {
+		LLVMValueRef lower_mask = llvm_mask_iota(p->module, 0, len_pow_2);
+		LLVMValueRef upper_mask = llvm_mask_iota(p->module, len_pow_2, len-len_pow_2);
+		LLVMValueRef lower = LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), lower_mask, "");
+		LLVMValueRef upper = LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), upper_mask, "");
+		upper = llvm_vector_expand_to_power_of_two(p, upper);
+
+		LLVMValueRef lower_reduced = llvm_vector_shuffle_reduction(p, lower, op_code);
+		LLVMValueRef upper_reduced = llvm_vector_shuffle_reduction(p, upper, op_code); 
+
+		return LLVMBuildBinOp(p->builder, op_code, lower_reduced, upper_reduced, "");
+	}
+#endif
 }
 }
 
 
 LLVMValueRef llvm_vector_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) {
 LLVMValueRef llvm_vector_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) {