|
@@ -517,6 +517,33 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
|
|
return matrix_vector;
|
|
return matrix_vector;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
|
|
|
|
+ Type *mt = base_type(m.type);
|
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
|
+
|
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
|
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
|
|
|
+
|
|
|
|
+ auto columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
|
+
|
|
|
|
+ LLVMValueRef vector = lb_matrix_to_vector(p, m);
|
|
|
|
+
|
|
|
|
+ unsigned mask_elems_index = 0;
|
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count*column_count);
|
|
|
|
+ for (unsigned j = 0; j < column_count; j++) {
|
|
|
|
+ for (unsigned i = 0; i < row_count; i++) {
|
|
|
|
+ unsigned offset = stride*j + i;
|
|
|
|
+ mask_elems[mask_elems_index++] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
|
|
|
|
+ LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
|
|
|
|
+ return trimmed_vector;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+
|
|
lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
|
|
lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
|
|
if (is_type_array(m.type)) {
|
|
if (is_type_array(m.type)) {
|
|
// no-op
|
|
// no-op
|
|
@@ -573,6 +600,46 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
|
|
return lb_addr_load(p, res);
|
|
return lb_addr_load(p, res);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) {
|
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
+ LLVMValueRef res_ptr = res.addr.value;
|
|
|
|
+ unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
|
|
|
|
+ LLVMSetAlignment(res_ptr, alignment);
|
|
|
|
+
|
|
|
|
+ res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
|
|
|
|
+ LLVMBuildStore(p->builder, vector, res_ptr);
|
|
|
|
+
|
|
|
|
+ return lb_addr_load(p, res);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) {
|
|
|
|
+ if (is_type_array(m.type)) {
|
|
|
|
+ // no-op
|
|
|
|
+ m.type = type;
|
|
|
|
+ return m;
|
|
|
|
+ }
|
|
|
|
+ Type *mt = base_type(m.type);
|
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
|
+
|
|
|
|
+ if (lb_matrix_elem_simple(mt)) {
|
|
|
|
+ LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m);
|
|
|
|
+ return lb_matrix_cast_vector_to_type(p, vector, type);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
+
|
|
|
|
+ i64 row_count = mt->Matrix.row_count;
|
|
|
|
+ i64 column_count = mt->Matrix.column_count;
|
|
|
|
+ for (i64 j = 0; j < column_count; j++) {
|
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
|
+ lbValue src = lb_emit_matrix_ev(p, m, i, j);
|
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
|
+ lb_emit_store(p, dst, src);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return lb_addr_load(p, res);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
|
|
lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
|
|
lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
|
|
Type *mt = base_type(type);
|
|
Type *mt = base_type(type);
|
|
@@ -737,16 +804,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
|
|
vector = llvm_vector_add(p, vector, product);
|
|
vector = llvm_vector_add(p, vector, product);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
- lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
- LLVMValueRef res_ptr = res.addr.value;
|
|
|
|
- unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
|
|
|
|
- LLVMSetAlignment(res_ptr, alignment);
|
|
|
|
-
|
|
|
|
- res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
|
|
|
|
- LLVMBuildStore(p->builder, vector, res_ptr);
|
|
|
|
|
|
|
|
- return lb_addr_load(p, res);
|
|
|
|
|
|
+ return lb_matrix_cast_vector_to_type(p, vector, type);
|
|
}
|
|
}
|
|
|
|
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|