Browse Source

Add `llvm_vector_reduce_add`

gingerBill 3 years ago
parent
commit
1bfbed0e02
2 changed files with 43 additions and 1 deletions
  1. 2 1
      src/llvm_backend_expr.cpp
  2. 41 0
      src/llvm_backend_utility.cpp

+ 2 - 1
src/llvm_backend_expr.cpp

@@ -619,9 +619,10 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 	Type *elem = mt->Matrix.elem;
 	Type *elem = mt->Matrix.elem;
 	LLVMTypeRef elem_type = lb_type(p->module, elem);
 	LLVMTypeRef elem_type = lb_type(p->module, elem);
 	
 	
-	unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 	
 	
 	if (lb_matrix_elem_simple(mt)) {
 	if (lb_matrix_elem_simple(mt)) {
+		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
+		
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count; gb_unused(row_count);
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count; gb_unused(row_count);
 		unsigned column_count = cast(unsigned)mt->Matrix.column_count;
 		unsigned column_count = cast(unsigned)mt->Matrix.column_count;
 		auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
 		auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);

+ 41 - 0
src/llvm_backend_utility.cpp

@@ -1544,4 +1544,45 @@ LLVMValueRef llvm_splat(lbProcedure *p, LLVMValueRef value, unsigned count) {
 	}
 	}
 	LLVMValueRef mask = llvm_mask_zero(p->module, count);
 	LLVMValueRef mask = llvm_mask_zero(p->module, count);
 	return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, "");
 	return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, "");
+}
+
+LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
+	LLVMTypeRef type = LLVMTypeOf(value);
+	GB_ASSERT(LLVMGetTypeKind(type) == LLVMVectorTypeKind);
+	LLVMTypeRef elem = LLVMGetElementType(type);
+	
+	char const *name = nullptr;
+	i32 value_offset = 0;
+	i32 value_count  = 0;
+	
+	switch (LLVMGetTypeKind(elem)) {
+	case LLVMHalfTypeKind:
+	case LLVMFloatTypeKind:
+	case LLVMDoubleTypeKind:
+		name = "llvm.vector.reduce.fadd";
+		value_offset = 0;
+		value_count = 2;
+		break;
+	case LLVMIntegerTypeKind:
+		name = "llvm.vector.reduce.add";
+		value_offset = 1;
+		value_count = 1;
+		break;
+	default:
+		GB_PANIC("invalid vector type %s", LLVMPrintTypeToString(type));
+		break;
+	}
+	
+	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+	
+	LLVMTypeRef types[1] = {};
+	types[0] = elem;
+	
+	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
+	LLVMValueRef values[2] = {};
+	values[0] = LLVMConstNull(elem);
+	values[1] = value;
+	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, "");
+	return call;
 }
 }