Browse Source

Make `lb_emit_matrix_mul` SIMD if possible

gingerBill 3 years ago
parent
commit
d0d9a3a4f4
2 changed files with 110 additions and 63 deletions
  1. 82 62
      src/llvm_backend_expr.cpp
  2. 28 1
      src/llvm_backend_utility.cpp

+ 82 - 62
src/llvm_backend_expr.cpp

@@ -557,6 +557,20 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type)
 }
 }
 
 
 
 
+LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { // Loads an entire matrix value as one flat LLVM vector of its element type.
+	Type *mt = base_type(matrix.type);
+	GB_ASSERT(mt->kind == Type_Matrix); // only matrix-typed values are accepted
+	LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem);
+	// Vector type: `total_count` lanes of the matrix element type. Callers in this change index the lanes with stride-based offsets.
+	unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
+	LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
+	// Get a pointer for the matrix value, view that memory as a pointer to the vector type, then load through it.
+	LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value;
+	LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), "");
+	LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); // NOTE(review): no LLVMSetAlignment follows this load, so it uses LLVM's default alignment for the vector type - confirm the matrix storage is aligned at least that strictly
+	return matrix_vector;
+}
+
 lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
 lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
 	Type *xt = base_type(lhs.type);
 	Type *xt = base_type(lhs.type);
 	Type *yt = base_type(rhs.type);
 	Type *yt = base_type(rhs.type);
@@ -567,31 +581,72 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
 	GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
 	GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
 	GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
 	GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
 		
 		
+	Type *elem = xt->Matrix.elem;
+	
+	unsigned outer_rows    = cast(unsigned)xt->Matrix.row_count;
+	unsigned inner         = cast(unsigned)xt->Matrix.column_count;
+	unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
+		
 	if (lb_matrix_elem_simple(xt)) {
 	if (lb_matrix_elem_simple(xt)) {
-		// TODO(bill): SIMD version
+		unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
+		unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
+		
+		auto x_rows    = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
+		auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
+		
+		
+		LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
+		LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
+		
+		for (unsigned i = 0; i < outer_rows; i++) {
+			auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), inner);
+			for (unsigned j = 0; j < inner; j++) {
+				unsigned offset = x_stride*j + i;
+				mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
+			}
+			
+			// transpose mask
+			LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner);
+			LLVMValueRef row = LLVMBuildShuffleVector(p->builder, x_vector, LLVMGetUndef(LLVMTypeOf(x_vector)), mask, "");
+			x_rows[i] = row;
+		}
+		
+		for (unsigned i = 0; i < outer_columns; i++) {
+			LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner);
+			LLVMValueRef column = LLVMBuildShuffleVector(p->builder, y_vector, LLVMGetUndef(LLVMTypeOf(y_vector)), mask, "");
+			y_columns[i] = column;
+		}
+
+		
+		
+		lbAddr res = lb_add_local_generated(p, type, true);
+		for_array(i, x_rows) {
+			LLVMValueRef x_row = x_rows[i];
+			for_array(j, y_columns) {
+				LLVMValueRef y_column = y_columns[j];
+				LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column);
+				lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+				LLVMBuildStore(p->builder, elem, dst.value);
+			}
+		}		
+		return lb_addr_load(p, res);
 	}
 	}
 	
 	
 	{
 	{
-		Type *elem = xt->Matrix.elem;	
-		
 		lbAddr res = lb_add_local_generated(p, type, true);
 		lbAddr res = lb_add_local_generated(p, type, true);
 		
 		
-		i64 outer_rows    = xt->Matrix.row_count;
-		i64 inner         = xt->Matrix.column_count;
-		i64 outer_columns = yt->Matrix.column_count;
-		
 		auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);
 		auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);
 		
 		
-		for (i64 j = 0; j < outer_columns; j++) {
-			for (i64 i = 0; i < outer_rows; i++) {
+		for (unsigned j = 0; j < outer_columns; j++) {
+			for (unsigned i = 0; i < outer_rows; i++) {
 				lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
 				lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
-				for (i64 k = 0; k < inner; k++) {
+				for (unsigned k = 0; k < inner; k++) {
 					inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k);
 					inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k);
 					inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j);
 					inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j);
 				}
 				}
 				
 				
-				lbValue sum = lb_emit_load(p, dst);
-				for (i64 k = 0; k < inner; k++) {
+				lbValue sum = lb_const_nil(p->module, elem);
+				for (unsigned k = 0; k < inner; k++) {
 					lbValue a = inners[k][0];
 					lbValue a = inners[k][0];
 					lbValue b = inners[k][1];
 					lbValue b = inners[k][1];
 					sum = lb_emit_mul_add(p, a, b, sum, elem);
 					sum = lb_emit_mul_add(p, a, b, sum, elem);
@@ -617,7 +672,6 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
 	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
 	
 	
 	Type *elem = mt->Matrix.elem;
 	Type *elem = mt->Matrix.elem;
-	LLVMTypeRef elem_type = lb_type(p->module, elem);
 	
 	
 	if (lb_matrix_elem_simple(mt)) {
 	if (lb_matrix_elem_simple(mt)) {
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
@@ -627,13 +681,7 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 		auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
 		auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
 		auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
 		auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
 		
 		
-		unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
-		LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
-		
-		LLVMValueRef lhs_ptr = lb_address_from_load_or_generate_local(p, lhs).value;
-		LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, lhs_ptr, LLVMPointerType(total_matrix_type, 0), "");
-		LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
-		
+		LLVMValueRef matrix_vector = lb_matrix_to_vector(p, lhs);		
 		
 		
 		for (unsigned column_index = 0; column_index < column_count; column_index++) {
 		for (unsigned column_index = 0; column_index < column_count; column_index++) {
 			LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count);
 			LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count);
@@ -650,23 +698,12 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 		GB_ASSERT(column_count > 0);
 		GB_ASSERT(column_count > 0);
 		
 		
 		LLVMValueRef vector = nullptr;
 		LLVMValueRef vector = nullptr;
-		if (is_type_float(elem)) {
-			for (i64 i = 0; i < column_count; i++) {
-				LLVMValueRef product = LLVMBuildFMul(p->builder, m_columns[i], v_rows[i], "");
-				if (i == 0) {
-					vector = product;
-				} else {
-					vector = LLVMBuildFAdd(p->builder, vector, product, "");
-				}
-			}
-		} else {
-			for (i64 i = 0; i < column_count; i++) {
-				LLVMValueRef product = LLVMBuildMul(p->builder, m_columns[i], v_rows[i], "");
-				if (i == 0) {
-					vector = product;
-				} else {
-					vector = LLVMBuildAdd(p->builder, vector, product, "");
-				}
+		for (i64 i = 0; i < column_count; i++) {
+			LLVMValueRef product = llvm_vector_mul(p, m_columns[i], v_rows[i]);
+			if (i == 0) {
+				vector = product;
+			} else {
+				vector = llvm_vector_add(p, vector, product);
 			}
 			}
 		}
 		}
 
 
@@ -712,7 +749,6 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
 	GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
 	
 	
 	Type *elem = mt->Matrix.elem;
 	Type *elem = mt->Matrix.elem;
-	LLVMTypeRef elem_type = lb_type(p->module, elem);
 	
 	
 	if (lb_matrix_elem_simple(mt)) {
 	if (lb_matrix_elem_simple(mt)) {
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
@@ -722,13 +758,8 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 		auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
 		auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
 		auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
 		auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
 		
 		
-		unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
-		LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
-		
-		LLVMValueRef matrix_ptr = lb_address_from_load_or_generate_local(p, rhs).value;
-		LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, matrix_ptr, LLVMPointerType(total_matrix_type, 0), "");
-		LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
-		
+		LLVMValueRef matrix_vector = lb_matrix_to_vector(p, rhs);
+
 		for (unsigned row_index = 0; row_index < row_count; row_index++) {
 		for (unsigned row_index = 0; row_index < row_count; row_index++) {
 			auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), column_count);
 			auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), column_count);
 			for (unsigned column_index = 0; column_index < column_count; column_index++) {
 			for (unsigned column_index = 0; column_index < column_count; column_index++) {
@@ -751,23 +782,12 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 		GB_ASSERT(row_count > 0);
 		GB_ASSERT(row_count > 0);
 		
 		
 		LLVMValueRef vector = nullptr;
 		LLVMValueRef vector = nullptr;
-		if (is_type_float(elem)) {
-			for (i64 i = 0; i < row_count; i++) {
-				LLVMValueRef product = LLVMBuildFMul(p->builder, v_rows[i], m_columns[i], "");
-				if (i == 0) {
-					vector = product;
-				} else {
-					vector = LLVMBuildFAdd(p->builder, vector, product, "");
-				}
-			}
-		} else {
-			for (i64 i = 0; i < row_count; i++) {
-				LLVMValueRef product = LLVMBuildMul(p->builder, v_rows[i], m_columns[i], "");
-				if (i == 0) {
-					vector = product;
-				} else {
-					vector = LLVMBuildAdd(p->builder, vector, product, "");
-				}
+		for (i64 i = 0; i < row_count; i++) {
+			LLVMValueRef product = llvm_vector_mul(p, v_rows[i], m_columns[i]);
+			if (i == 0) {
+				vector = product;
+			} else {
+				vector = llvm_vector_add(p, vector, product);
 			}
 			}
 		}
 		}
 
 

+ 28 - 1
src/llvm_backend_utility.cpp

@@ -1577,7 +1577,7 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
 	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
 	
 	
 	LLVMTypeRef types[1] = {};
 	LLVMTypeRef types[1] = {};
-	types[0] = elem;
+	types[0] = type;
 	
 	
 	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
 	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
 	LLVMValueRef values[2] = {};
 	LLVMValueRef values[2] = {};
@@ -1585,4 +1585,31 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
 	values[1] = value;
 	values[1] = value;
 	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, "");
 	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, "");
 	return call;
 	return call;
+}
+
+LLVMValueRef llvm_vector_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { // Emits `a + b`, choosing the integer or floating-point add from the operands' element type.
+	GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b)); // both operands must have the identical LLVM type
+	// Dispatch on the element type of the operands (callers in this change pass vector values).
+	LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a));
+	
+	if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) {
+		return LLVMBuildAdd(p->builder, a, b, "");
+	}
+	return LLVMBuildFAdd(p->builder, a, b, ""); // every non-integer element kind falls through to the floating-point add
+}
+
+LLVMValueRef llvm_vector_mul(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { // Emits `a * b`, choosing the integer or floating-point multiply from the operands' element type.
+	GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b)); // both operands must have the identical LLVM type
+	// Dispatch on the element type of the operands (callers in this change pass vector values).
+	LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a));
+	
+	if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) {
+		return LLVMBuildMul(p->builder, a, b, "");
+	}
+	return LLVMBuildFMul(p->builder, a, b, ""); // every non-integer element kind falls through to the floating-point multiply
+}
+
+
+LLVMValueRef llvm_vector_dot(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { // Dot product: multiplies `a` and `b` via llvm_vector_mul, then passes the product to llvm_vector_reduce_add.
+	return llvm_vector_reduce_add(p, llvm_vector_mul(p, a, b)); // llvm_vector_mul asserts that both operands share one LLVM type
 }
 }