|
@@ -619,11 +619,10 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
Type *elem = mt->Matrix.elem;
|
|
|
LLVMTypeRef elem_type = lb_type(p->module, elem);
|
|
|
|
|
|
-
|
|
|
if (lb_matrix_elem_simple(mt)) {
|
|
|
unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
|
|
|
- unsigned row_count = cast(unsigned)mt->Matrix.row_count; gb_unused(row_count);
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
|
|
auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
@@ -709,10 +708,79 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
|
|
|
|
i64 vector_count = get_array_type_count(vt);
|
|
|
|
|
|
- GB_ASSERT(mt->Matrix.row_count == vector_count);
|
|
|
+ GB_ASSERT(vector_count == mt->Matrix.row_count);
|
|
|
GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
|
|
|
|
|
Type *elem = mt->Matrix.elem;
|
|
|
+ LLVMTypeRef elem_type = lb_type(p->module, elem);
|
|
|
+
|
|
|
+ if (lb_matrix_elem_simple(mt)) {
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
+
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count; gb_unused(column_count);
|
|
|
+ auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
+ auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
+
|
|
|
+ unsigned total_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
+ LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
|
|
+
|
|
|
+ LLVMValueRef matrix_ptr = lb_address_from_load_or_generate_local(p, rhs).value;
|
|
|
+ LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, matrix_ptr, LLVMPointerType(total_matrix_type, 0), "");
|
|
|
+ LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
|
|
+
|
|
|
+ for (unsigned row_index = 0; row_index < row_count; row_index++) {
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), column_count);
|
|
|
+ for (unsigned column_index = 0; column_index < column_count; column_index++) {
|
|
|
+ unsigned offset = row_index + column_index*stride;
|
|
|
+ mask_elems[column_index] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
+ }
|
|
|
+
|
|
|
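+ // shuffle mask: gathers row `row_index` by picking one element from each column (columns are `stride` elements apart)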
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count);
|
|
|
+ LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, "");
|
|
|
+ m_columns[row_index] = column;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (unsigned column_index = 0; column_index < row_count; column_index++) {
|
|
|
+ LLVMValueRef value = lb_emit_struct_ev(p, lhs, column_index).value;
|
|
|
+ LLVMValueRef row = llvm_splat(p, value, column_count);
|
|
|
+ v_rows[column_index] = row;
|
|
|
+ }
|
|
|
+
|
|
|
+ GB_ASSERT(row_count > 0);
|
|
|
+
|
|
|
+ LLVMValueRef vector = nullptr;
|
|
|
+ if (is_type_float(elem)) {
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
+ LLVMValueRef product = LLVMBuildFMul(p->builder, v_rows[i], m_columns[i], "");
|
|
|
+ if (i == 0) {
|
|
|
+ vector = product;
|
|
|
+ } else {
|
|
|
+ vector = LLVMBuildFAdd(p->builder, vector, product, "");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
+ LLVMValueRef product = LLVMBuildMul(p->builder, v_rows[i], m_columns[i], "");
|
|
|
+ if (i == 0) {
|
|
|
+ vector = product;
|
|
|
+ } else {
|
|
|
+ vector = LLVMBuildAdd(p->builder, vector, product, "");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+ LLVMValueRef res_ptr = res.addr.value;
|
|
|
+ unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
|
|
|
+ LLVMSetAlignment(res_ptr, alignment);
|
|
|
+
|
|
|
+ res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
|
|
|
+ LLVMBuildStore(p->builder, vector, res_ptr);
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+ }
|
|
|
|
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|