|
@@ -557,6 +557,20 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type)
|
|
|
}
|
|
|
|
|
|
|
|
|
+LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
|
|
|
+ Type *mt = base_type(matrix.type);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+ LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem);
|
|
|
+
|
|
|
+ unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
+ LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
|
|
+
|
|
|
+ LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value;
|
|
|
+ LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), "");
|
|
|
+ LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
|
|
+ return matrix_vector;
|
|
|
+}
|
|
|
+
|
|
|
lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
Type *xt = base_type(lhs.type);
|
|
|
Type *yt = base_type(rhs.type);
|
|
@@ -567,31 +581,72 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
|
|
|
GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
|
|
|
GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
|
|
|
|
|
|
+ Type *elem = xt->Matrix.elem;
|
|
|
+
|
|
|
+ unsigned outer_rows = cast(unsigned)xt->Matrix.row_count;
|
|
|
+ unsigned inner = cast(unsigned)xt->Matrix.column_count;
|
|
|
+ unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
|
|
|
+
|
|
|
if (lb_matrix_elem_simple(xt)) {
|
|
|
- // TODO(bill): SIMD version
|
|
|
+ unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
|
|
|
+ unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
|
|
|
+
|
|
|
+ auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
|
|
|
+ auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
|
|
|
+
|
|
|
+
|
|
|
+ LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
|
|
|
+ LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
|
|
|
+
|
|
|
+ for (unsigned i = 0; i < outer_rows; i++) {
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), inner);
|
|
|
+ for (unsigned j = 0; j < inner; j++) {
|
|
|
+ unsigned offset = x_stride*j + i;
|
|
|
+ mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
+ }
|
|
|
+
|
|
|
+ // transpose mask
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner);
|
|
|
+ LLVMValueRef row = LLVMBuildShuffleVector(p->builder, x_vector, LLVMGetUndef(LLVMTypeOf(x_vector)), mask, "");
|
|
|
+ x_rows[i] = row;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (unsigned i = 0; i < outer_columns; i++) {
|
|
|
+ LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner);
|
|
|
+ LLVMValueRef column = LLVMBuildShuffleVector(p->builder, y_vector, LLVMGetUndef(LLVMTypeOf(y_vector)), mask, "");
|
|
|
+ y_columns[i] = column;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+ for_array(i, x_rows) {
|
|
|
+ LLVMValueRef x_row = x_rows[i];
|
|
|
+ for_array(j, y_columns) {
|
|
|
+ LLVMValueRef y_column = y_columns[j];
|
|
|
+ LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column);
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
+ LLVMBuildStore(p->builder, elem, dst.value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return lb_addr_load(p, res);
|
|
|
}
|
|
|
|
|
|
{
|
|
|
- Type *elem = xt->Matrix.elem;
|
|
|
-
|
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
|
|
- i64 outer_rows = xt->Matrix.row_count;
|
|
|
- i64 inner = xt->Matrix.column_count;
|
|
|
- i64 outer_columns = yt->Matrix.column_count;
|
|
|
-
|
|
|
auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);
|
|
|
|
|
|
- for (i64 j = 0; j < outer_columns; j++) {
|
|
|
- for (i64 i = 0; i < outer_rows; i++) {
|
|
|
+ for (unsigned j = 0; j < outer_columns; j++) {
|
|
|
+ for (unsigned i = 0; i < outer_rows; i++) {
|
|
|
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
- for (i64 k = 0; k < inner; k++) {
|
|
|
+ for (unsigned k = 0; k < inner; k++) {
|
|
|
inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k);
|
|
|
inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j);
|
|
|
}
|
|
|
|
|
|
- lbValue sum = lb_emit_load(p, dst);
|
|
|
- for (i64 k = 0; k < inner; k++) {
|
|
|
+ lbValue sum = lb_const_nil(p->module, elem);
|
|
|
+ for (unsigned k = 0; k < inner; k++) {
|
|
|
lbValue a = inners[k][0];
|
|
|
lbValue b = inners[k][1];
|
|
|
sum = lb_emit_mul_add(p, a, b, sum, elem);
|
|
@@ -617,7 +672,6 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
|
|
|
|
|
Type *elem = mt->Matrix.elem;
|
|
|
- LLVMTypeRef elem_type = lb_type(p->module, elem);
|
|
|
|
|
|
if (lb_matrix_elem_simple(mt)) {
|
|
|
unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
@@ -627,13 +681,7 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
|
|
|
- unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
- LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
|
|
-
|
|
|
- LLVMValueRef lhs_ptr = lb_address_from_load_or_generate_local(p, lhs).value;
|
|
|
- LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, lhs_ptr, LLVMPointerType(total_matrix_type, 0), "");
|
|
|
- LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
|
|
-
|
|
|
+ LLVMValueRef matrix_vector = lb_matrix_to_vector(p, lhs);
|
|
|
|
|
|
for (unsigned column_index = 0; column_index < column_count; column_index++) {
|
|
|
LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count);
|
|
@@ -650,23 +698,12 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
GB_ASSERT(column_count > 0);
|
|
|
|
|
|
LLVMValueRef vector = nullptr;
|
|
|
- if (is_type_float(elem)) {
|
|
|
- for (i64 i = 0; i < column_count; i++) {
|
|
|
- LLVMValueRef product = LLVMBuildFMul(p->builder, m_columns[i], v_rows[i], "");
|
|
|
- if (i == 0) {
|
|
|
- vector = product;
|
|
|
- } else {
|
|
|
- vector = LLVMBuildFAdd(p->builder, vector, product, "");
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- for (i64 i = 0; i < column_count; i++) {
|
|
|
- LLVMValueRef product = LLVMBuildMul(p->builder, m_columns[i], v_rows[i], "");
|
|
|
- if (i == 0) {
|
|
|
- vector = product;
|
|
|
- } else {
|
|
|
- vector = LLVMBuildAdd(p->builder, vector, product, "");
|
|
|
- }
|
|
|
+ for (i64 i = 0; i < column_count; i++) {
|
|
|
+ LLVMValueRef product = llvm_vector_mul(p, m_columns[i], v_rows[i]);
|
|
|
+ if (i == 0) {
|
|
|
+ vector = product;
|
|
|
+ } else {
|
|
|
+ vector = llvm_vector_add(p, vector, product);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -712,7 +749,6 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
|
|
|
|
|
Type *elem = mt->Matrix.elem;
|
|
|
- LLVMTypeRef elem_type = lb_type(p->module, elem);
|
|
|
|
|
|
if (lb_matrix_elem_simple(mt)) {
|
|
|
unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
@@ -722,13 +758,8 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
|
|
|
- unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
- LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
|
|
-
|
|
|
- LLVMValueRef matrix_ptr = lb_address_from_load_or_generate_local(p, rhs).value;
|
|
|
- LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, matrix_ptr, LLVMPointerType(total_matrix_type, 0), "");
|
|
|
- LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
|
|
-
|
|
|
+ LLVMValueRef matrix_vector = lb_matrix_to_vector(p, rhs);
|
|
|
+
|
|
|
for (unsigned row_index = 0; row_index < row_count; row_index++) {
|
|
|
auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), column_count);
|
|
|
for (unsigned column_index = 0; column_index < column_count; column_index++) {
|
|
@@ -751,23 +782,12 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
GB_ASSERT(row_count > 0);
|
|
|
|
|
|
LLVMValueRef vector = nullptr;
|
|
|
- if (is_type_float(elem)) {
|
|
|
- for (i64 i = 0; i < row_count; i++) {
|
|
|
- LLVMValueRef product = LLVMBuildFMul(p->builder, v_rows[i], m_columns[i], "");
|
|
|
- if (i == 0) {
|
|
|
- vector = product;
|
|
|
- } else {
|
|
|
- vector = LLVMBuildFAdd(p->builder, vector, product, "");
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- for (i64 i = 0; i < row_count; i++) {
|
|
|
- LLVMValueRef product = LLVMBuildMul(p->builder, v_rows[i], m_columns[i], "");
|
|
|
- if (i == 0) {
|
|
|
- vector = product;
|
|
|
- } else {
|
|
|
- vector = LLVMBuildAdd(p->builder, vector, product, "");
|
|
|
- }
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
+ LLVMValueRef product = llvm_vector_mul(p, v_rows[i], m_columns[i]);
|
|
|
+ if (i == 0) {
|
|
|
+ vector = product;
|
|
|
+ } else {
|
|
|
+ vector = llvm_vector_add(p, vector, product);
|
|
|
}
|
|
|
}
|
|
|
|