Browse Source

Make `lb_emit_matrix_mul_vector` use SIMD if possible

gingerBill 3 years ago
parent
commit
0fd525d778
2 changed files with 97 additions and 3 deletions
  1. 65 3
      src/llvm_backend_expr.cpp
  2. 32 0
      src/llvm_backend_utility.cpp

+ 65 - 3
src/llvm_backend_expr.cpp

@@ -567,11 +567,10 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
 	GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
 	GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
 	GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
 	GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
 		
 		
-	if (!lb_matrix_elem_simple(xt)) {
-		goto slow_form;
+	if (lb_matrix_elem_simple(xt)) {
+		// TODO(bill): SIMD version
 	}
 	}
 	
 	
-slow_form:
 	{
 	{
 		Type *elem = xt->Matrix.elem;	
 		Type *elem = xt->Matrix.elem;	
 		
 		
@@ -618,6 +617,69 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
 	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
 	
 	
 	Type *elem = mt->Matrix.elem;
 	Type *elem = mt->Matrix.elem;
+	LLVMTypeRef elem_type = lb_type(p->module, elem);
+	
+	unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
+	
+	if (lb_matrix_elem_simple(mt)) {
+		unsigned row_count = cast(unsigned)mt->Matrix.row_count; gb_unused(row_count);
+		unsigned column_count = cast(unsigned)mt->Matrix.column_count;
+		auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
+		auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
+		
+		unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
+		LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
+		
+		LLVMValueRef lhs_ptr = lb_address_from_load_or_generate_local(p, lhs).value;
+		LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, lhs_ptr, LLVMPointerType(total_matrix_type, 0), "");
+		LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
+		
+		
+		for (unsigned column_index = 0; column_index < column_count; column_index++) {
+			LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count);
+			LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, "");
+			m_columns[column_index] = column;
+		}
+		
+		for (unsigned row_index = 0; row_index < column_count; row_index++) {
+			LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value;
+			LLVMValueRef row = llvm_splat(p, value, row_count);
+			v_rows[row_index] = row;
+		}
+		
+		GB_ASSERT(column_count > 0);
+		
+		LLVMValueRef vector = nullptr;
+		if (is_type_float(elem)) {
+			for (i64 i = 0; i < column_count; i++) {
+				LLVMValueRef product = LLVMBuildFMul(p->builder, m_columns[i], v_rows[i], "");
+				if (i == 0) {
+					vector = product;
+				} else {
+					vector = LLVMBuildFAdd(p->builder, vector, product, "");
+				}
+			}
+		} else {
+			for (i64 i = 0; i < column_count; i++) {
+				LLVMValueRef product = LLVMBuildMul(p->builder, m_columns[i], v_rows[i], "");
+				if (i == 0) {
+					vector = product;
+				} else {
+					vector = LLVMBuildAdd(p->builder, vector, product, "");
+				}
+			}
+		}
+
+		lbAddr res = lb_add_local_generated(p, type, true);
+		LLVMValueRef res_ptr = res.addr.value;
+		unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
+		LLVMSetAlignment(res_ptr, alignment);
+		
+		res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
+		LLVMBuildStore(p->builder, vector, res_ptr);
+		
+		return lb_addr_load(p, res);
+	}
 	
 	
 	lbAddr res = lb_add_local_generated(p, type, true);
 	lbAddr res = lb_add_local_generated(p, type, true);
 	
 	

+ 32 - 0
src/llvm_backend_utility.cpp

@@ -1512,4 +1512,36 @@ lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t
 		lbValue y = lb_emit_arith(p, Token_Add, x, c, t);
 		lbValue y = lb_emit_arith(p, Token_Add, x, c, t);
 		return y;
 		return y;
 	}
 	}
+}
+
+LLVMValueRef llvm_mask_iota(lbModule *m, unsigned start, unsigned count) { // Returns a constant vector of `count` consecutive integers: start, start+1, ..., start+count-1.
+	auto iota = slice_make<LLVMValueRef>(temporary_allocator(), count); // scratch array for the element constants, taken from the temporary allocator
+	for (unsigned i = 0; i < count; i++) {
+		iota[i] = lb_const_int(m, t_u32, start+i).value; // element i is the constant made by lb_const_int with type t_u32 and value start+i
+	}
+	return LLVMConstVector(iota.data, count); // in this commit the result is passed as the shuffle mask that picks one column out of the flattened matrix vector
+}
+
+LLVMValueRef llvm_mask_zero(lbModule *m, unsigned count) { // Returns the null (all-zero) constant of a `count`-element vector whose element type is lb_type(m, t_u32).
+	return LLVMConstNull(LLVMVectorType(lb_type(m, t_u32), count)); // llvm_splat passes this as a shuffle mask, so every lane index is 0
+}
+
+LLVMValueRef llvm_splat(lbProcedure *p, LLVMValueRef value, unsigned count) { // Returns a `count`-lane vector in which every lane is the scalar `value`.
+	GB_ASSERT(count > 0); // a zero-lane splat is rejected up front
+	if (LLVMIsConstant(value)) { // constant input: the result is built from constant expressions only, without going through p->builder
+		LLVMValueRef single = LLVMConstVector(&value, 1); // one-element constant vector holding `value`
+		if (count == 1) {
+			return single; // a single lane needs no shuffle
+		}
+		LLVMValueRef mask = llvm_mask_zero(p->module, count); // `count` mask indices, all 0
+		return LLVMConstShuffleVector(single, LLVMGetUndef(LLVMTypeOf(single)), mask); // each result lane reads element 0 of `single`; the undef second operand is never indexed by the mask
+	}
+	
+	LLVMTypeRef single_type = LLVMVectorType(LLVMTypeOf(value), 1); // <1 x T> where T is the type of `value`
+	LLVMValueRef single = LLVMBuildBitCast(p->builder, value, single_type, ""); // non-constant input: bitcast the scalar to a one-element vector
+	if (count == 1) {
+		return single; // a single lane needs no shuffle
+	}
+	LLVMValueRef mask = llvm_mask_zero(p->module, count); // `count` mask indices, all 0
+	return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, ""); // same broadcast as the constant path, emitted as an instruction
 }
 }