|
@@ -495,21 +495,70 @@ bool lb_matrix_elem_simple(Type *t) {
|
|
|
case Basic_f16le:
|
|
|
case Basic_f16be:
|
|
|
// TODO(bill): determine when this is fine
|
|
|
- return false;
|
|
|
+ return true;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
|
|
|
+ Type *mt = base_type(matrix.type);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+ LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem);
|
|
|
+
|
|
|
+ unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
+ LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
|
|
+
|
|
|
+ LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value;
|
|
|
+ LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), "");
|
|
|
+ LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
|
|
+ return matrix_vector;
|
|
|
+}
|
|
|
+
|
|
|
lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
|
|
|
if (is_type_array(m.type)) {
|
|
|
+ // no-op
|
|
|
m.type = type;
|
|
|
return m;
|
|
|
}
|
|
|
Type *mt = base_type(m.type);
|
|
|
GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
|
|
|
+ if (lb_matrix_elem_simple(mt)) {
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
|
|
+
|
|
|
+ auto rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
+
|
|
|
+ LLVMValueRef vector = lb_matrix_to_vector(p, m);
|
|
|
+ for (unsigned i = 0; i < row_count; i++) {
|
|
|
+ for (unsigned j = 0; j < column_count; j++) {
|
|
|
+ unsigned offset = stride*j + i;
|
|
|
+ mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
+ }
|
|
|
+
|
|
|
+ // transpose mask
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count);
|
|
|
+ LLVMValueRef row = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
|
|
|
+ rows[i] = row;
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+ for_array(i, rows) {
|
|
|
+ LLVMValueRef row = rows[i];
|
|
|
+ lbValue dst_row_ptr = lb_emit_matrix_epi(p, res.addr, 0, i);
|
|
|
+ LLVMValueRef ptr = dst_row_ptr.value;
|
|
|
+ ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(row), 0), "");
|
|
|
+ LLVMBuildStore(p->builder, row, ptr);
|
|
|
+ }
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+ }
|
|
|
+
|
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
|
|
i64 row_count = mt->Matrix.row_count;
|
|
@@ -556,21 +605,6 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type)
|
|
|
|
|
|
}
|
|
|
|
|
|
-
|
|
|
-LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
|
|
|
- Type *mt = base_type(matrix.type);
|
|
|
- GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
- LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem);
|
|
|
-
|
|
|
- unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
- LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
|
|
-
|
|
|
- LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value;
|
|
|
- LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), "");
|
|
|
- LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
|
|
- return matrix_vector;
|
|
|
-}
|
|
|
-
|
|
|
lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
Type *xt = base_type(lhs.type);
|
|
|
Type *yt = base_type(rhs.type);
|
|
@@ -594,12 +628,11 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
|
|
|
auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
|
|
|
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
|
|
|
|
|
|
-
|
|
|
LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
|
|
|
LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
|
|
|
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner);
|
|
|
for (unsigned i = 0; i < outer_rows; i++) {
|
|
|
- auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), inner);
|
|
|
for (unsigned j = 0; j < inner; j++) {
|
|
|
unsigned offset = x_stride*j + i;
|
|
|
mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
|
|
@@ -616,8 +649,6 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
|
|
|
LLVMValueRef column = LLVMBuildShuffleVector(p->builder, y_vector, LLVMGetUndef(LLVMTypeOf(y_vector)), mask, "");
|
|
|
y_columns[i] = column;
|
|
|
}
|
|
|
-
|
|
|
-
|
|
|
|
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
for_array(i, x_rows) {
|
|
@@ -760,8 +791,8 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
|
|
|
LLVMValueRef matrix_vector = lb_matrix_to_vector(p, rhs);
|
|
|
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
for (unsigned row_index = 0; row_index < row_count; row_index++) {
|
|
|
- auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), column_count);
|
|
|
for (unsigned column_index = 0; column_index < column_count; column_index++) {
|
|
|
unsigned offset = row_index + column_index*stride;
|
|
|
mask_elems[column_index] = lb_const_int(p->module, t_u32, offset).value;
|