|
@@ -331,7 +331,7 @@ bool lb_try_direct_vector_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbVal
|
|
|
z = LLVMBuildFRem(p->builder, x, y, "");
|
|
|
break;
|
|
|
default:
|
|
|
- GB_PANIC("Unsupported vector operation");
|
|
|
+ GB_PANIC("Unsupported vector operation %.*s", LIT(token_strings[op]));
|
|
|
break;
|
|
|
}
|
|
|
|
|
@@ -476,11 +476,545 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+bool lb_is_matrix_simdable(Type *t) {
|
|
|
+ Type *mt = base_type(t);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+
|
|
|
+ Type *elem = core_type(mt->Matrix.elem);
|
|
|
+ if (is_type_complex(elem)) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (is_type_different_to_arch_endianness(elem)) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ switch (build_context.metrics.arch) {
|
|
|
+ case TargetArch_amd64:
|
|
|
+ case TargetArch_arm64:
|
|
|
+ // possible
|
|
|
+ break;
|
|
|
+ case TargetArch_386:
|
|
|
+ case TargetArch_wasm32:
|
|
|
+ // nope
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (elem->kind == Type_Basic) {
|
|
|
+ switch (elem->Basic.kind) {
|
|
|
+ case Basic_f16:
|
|
|
+ case Basic_f16le:
|
|
|
+ case Basic_f16be:
|
|
|
+ switch (build_context.metrics.arch) {
|
|
|
+ case TargetArch_amd64:
|
|
|
+ return false;
|
|
|
+ case TargetArch_arm64:
|
|
|
+ // TODO(bill): determine when this is fine
|
|
|
+ return true;
|
|
|
+ case TargetArch_386:
|
|
|
+ case TargetArch_wasm32:
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return true;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
|
|
|
+ Type *mt = base_type(matrix.type);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+ LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem);
|
|
|
+
|
|
|
+ unsigned total_count = cast(unsigned)matrix_type_total_internal_elems(mt);
|
|
|
+ LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count);
|
|
|
+
|
|
|
+#if 1
|
|
|
+ LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value;
|
|
|
+ LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), "");
|
|
|
+ LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, "");
|
|
|
+ LLVMSetAlignment(matrix_vector, cast(unsigned)type_align_of(mt));
|
|
|
+ return matrix_vector;
|
|
|
+#else
|
|
|
+ LLVMValueRef matrix_vector = LLVMBuildBitCast(p->builder, matrix.value, total_matrix_type, "");
|
|
|
+ return matrix_vector;
|
|
|
+#endif
|
|
|
+}
|
|
|
+
|
|
|
+LLVMValueRef lb_matrix_trimmed_vector_mask(lbProcedure *p, Type *mt) {
|
|
|
+ mt = base_type(mt);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
|
|
+ unsigned mask_elems_index = 0;
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count*column_count);
|
|
|
+ for (unsigned j = 0; j < column_count; j++) {
|
|
|
+ for (unsigned i = 0; i < row_count; i++) {
|
|
|
+ unsigned offset = stride*j + i;
|
|
|
+ mask_elems[mask_elems_index++] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
|
|
|
+ return mask;
|
|
|
+}
|
|
|
+
|
|
|
+LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
|
|
|
+ LLVMValueRef vector = lb_matrix_to_vector(p, m);
|
|
|
+
|
|
|
+ Type *mt = base_type(m.type);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
+ if (stride == row_count) {
|
|
|
+ return vector;
|
|
|
+ }
|
|
|
+
|
|
|
+ LLVMValueRef mask = lb_matrix_trimmed_vector_mask(p, mt);
|
|
|
+ LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
|
|
|
+ return trimmed_vector;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
|
|
|
+ if (is_type_array(m.type)) {
|
|
|
+ // no-op
|
|
|
+ m.type = type;
|
|
|
+ return m;
|
|
|
+ }
|
|
|
+ Type *mt = base_type(m.type);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+
|
|
|
+ if (lb_is_matrix_simdable(mt)) {
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
|
|
+
|
|
|
+ auto rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
+
|
|
|
+ LLVMValueRef vector = lb_matrix_to_vector(p, m);
|
|
|
+ for (unsigned i = 0; i < row_count; i++) {
|
|
|
+ for (unsigned j = 0; j < column_count; j++) {
|
|
|
+ unsigned offset = stride*j + i;
|
|
|
+ mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
+ }
|
|
|
+
|
|
|
+ // transpose mask
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count);
|
|
|
+ LLVMValueRef row = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
|
|
|
+ rows[i] = row;
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+ for_array(i, rows) {
|
|
|
+ LLVMValueRef row = rows[i];
|
|
|
+ lbValue dst_row_ptr = lb_emit_matrix_epi(p, res.addr, 0, i);
|
|
|
+ LLVMValueRef ptr = dst_row_ptr.value;
|
|
|
+ ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(row), 0), "");
|
|
|
+ LLVMBuildStore(p->builder, row, ptr);
|
|
|
+ }
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+
|
|
|
+ i64 row_count = mt->Matrix.row_count;
|
|
|
+ i64 column_count = mt->Matrix.column_count;
|
|
|
+ for (i64 j = 0; j < column_count; j++) {
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
+ lbValue src = lb_emit_matrix_ev(p, m, i, j);
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, j, i);
|
|
|
+ lb_emit_store(p, dst, src);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+}
|
|
|
+
|
|
|
+lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) {
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+ LLVMValueRef res_ptr = res.addr.value;
|
|
|
+ unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
|
|
|
+ LLVMSetAlignment(res_ptr, alignment);
|
|
|
+
|
|
|
+ res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
|
|
|
+ LLVMBuildStore(p->builder, vector, res_ptr);
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+}
|
|
|
+
|
|
|
+lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) {
|
|
|
+ if (is_type_array(m.type)) {
|
|
|
+ // no-op
|
|
|
+ m.type = type;
|
|
|
+ return m;
|
|
|
+ }
|
|
|
+ Type *mt = base_type(m.type);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+
|
|
|
+ if (lb_is_matrix_simdable(mt)) {
|
|
|
+ LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m);
|
|
|
+ return lb_matrix_cast_vector_to_type(p, vector, type);
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+
|
|
|
+ i64 row_count = mt->Matrix.row_count;
|
|
|
+ i64 column_count = mt->Matrix.column_count;
|
|
|
+ for (i64 j = 0; j < column_count; j++) {
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
+ lbValue src = lb_emit_matrix_ev(p, m, i, j);
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
+ lb_emit_store(p, dst, src);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
|
|
|
+ Type *mt = base_type(type);
|
|
|
+ Type *at = base_type(a.type);
|
|
|
+ Type *bt = base_type(b.type);
|
|
|
+ GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
+ GB_ASSERT(at->kind == Type_Array);
|
|
|
+ GB_ASSERT(bt->kind == Type_Array);
|
|
|
+
|
|
|
+
|
|
|
+ i64 row_count = mt->Matrix.row_count;
|
|
|
+ i64 column_count = mt->Matrix.column_count;
|
|
|
+
|
|
|
+ GB_ASSERT(row_count == at->Array.count);
|
|
|
+ GB_ASSERT(column_count == bt->Array.count);
|
|
|
+
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+
|
|
|
+ for (i64 j = 0; j < column_count; j++) {
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
+ lbValue x = lb_emit_struct_ev(p, a, cast(i32)i);
|
|
|
+ lbValue y = lb_emit_struct_ev(p, b, cast(i32)j);
|
|
|
+ lbValue src = lb_emit_arith(p, Token_Mul, x, y, mt->Matrix.elem);
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
+ lb_emit_store(p, dst, src);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
+ // TODO(bill): Handle edge case for f16 types on x86(-64) platforms
|
|
|
+
|
|
|
+ Type *xt = base_type(lhs.type);
|
|
|
+ Type *yt = base_type(rhs.type);
|
|
|
+
|
|
|
+ GB_ASSERT(is_type_matrix(type));
|
|
|
+ GB_ASSERT(is_type_matrix(xt));
|
|
|
+ GB_ASSERT(is_type_matrix(yt));
|
|
|
+ GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
|
|
|
+ GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
|
|
|
+
|
|
|
+ Type *elem = xt->Matrix.elem;
|
|
|
+
|
|
|
+ unsigned outer_rows = cast(unsigned)xt->Matrix.row_count;
|
|
|
+ unsigned inner = cast(unsigned)xt->Matrix.column_count;
|
|
|
+ unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
|
|
|
+
|
|
|
+ if (lb_is_matrix_simdable(xt)) {
|
|
|
+ unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
|
|
|
+ unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
|
|
|
+
|
|
|
+ auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
|
|
|
+ auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
|
|
|
+
|
|
|
+ LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
|
|
|
+ LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
|
|
|
+
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner);
|
|
|
+ for (unsigned i = 0; i < outer_rows; i++) {
|
|
|
+ for (unsigned j = 0; j < inner; j++) {
|
|
|
+ unsigned offset = x_stride*j + i;
|
|
|
+ mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
+ }
|
|
|
+
|
|
|
+ // transpose mask
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner);
|
|
|
+ LLVMValueRef row = LLVMBuildShuffleVector(p->builder, x_vector, LLVMGetUndef(LLVMTypeOf(x_vector)), mask, "");
|
|
|
+ x_rows[i] = row;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (unsigned i = 0; i < outer_columns; i++) {
|
|
|
+ LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner);
|
|
|
+ LLVMValueRef column = LLVMBuildShuffleVector(p->builder, y_vector, LLVMGetUndef(LLVMTypeOf(y_vector)), mask, "");
|
|
|
+ y_columns[i] = column;
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+ for_array(i, x_rows) {
|
|
|
+ LLVMValueRef x_row = x_rows[i];
|
|
|
+ for_array(j, y_columns) {
|
|
|
+ LLVMValueRef y_column = y_columns[j];
|
|
|
+ LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column);
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
+ LLVMBuildStore(p->builder, elem, dst.value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+ }
|
|
|
+
|
|
|
+ {
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+
|
|
|
+ auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);
|
|
|
+
|
|
|
+ for (unsigned j = 0; j < outer_columns; j++) {
|
|
|
+ for (unsigned i = 0; i < outer_rows; i++) {
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
+ for (unsigned k = 0; k < inner; k++) {
|
|
|
+ inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k);
|
|
|
+ inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j);
|
|
|
+ }
|
|
|
+
|
|
|
+ lbValue sum = lb_const_nil(p->module, elem);
|
|
|
+ for (unsigned k = 0; k < inner; k++) {
|
|
|
+ lbValue a = inners[k][0];
|
|
|
+ lbValue b = inners[k][1];
|
|
|
+ sum = lb_emit_mul_add(p, a, b, sum, elem);
|
|
|
+ }
|
|
|
+ lb_emit_store(p, dst, sum);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
+ // TODO(bill): Handle edge case for f16 types on x86(-64) platforms
|
|
|
+
|
|
|
+ Type *mt = base_type(lhs.type);
|
|
|
+ Type *vt = base_type(rhs.type);
|
|
|
+
|
|
|
+ GB_ASSERT(is_type_matrix(mt));
|
|
|
+ GB_ASSERT(is_type_array_like(vt));
|
|
|
+
|
|
|
+ i64 vector_count = get_array_type_count(vt);
|
|
|
+
|
|
|
+ GB_ASSERT(mt->Matrix.column_count == vector_count);
|
|
|
+ GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
|
|
+
|
|
|
+ Type *elem = mt->Matrix.elem;
|
|
|
+
|
|
|
+ if (lb_is_matrix_simdable(mt)) {
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
+
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count;
|
|
|
+ auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
+ auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
+
|
|
|
+ LLVMValueRef matrix_vector = lb_matrix_to_vector(p, lhs);
|
|
|
+
|
|
|
+ for (unsigned column_index = 0; column_index < column_count; column_index++) {
|
|
|
+ LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count);
|
|
|
+ LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, "");
|
|
|
+ m_columns[column_index] = column;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (unsigned row_index = 0; row_index < column_count; row_index++) {
|
|
|
+ LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value;
|
|
|
+ LLVMValueRef row = llvm_vector_broadcast(p, value, row_count);
|
|
|
+ v_rows[row_index] = row;
|
|
|
+ }
|
|
|
+
|
|
|
+ GB_ASSERT(column_count > 0);
|
|
|
+
|
|
|
+ LLVMValueRef vector = nullptr;
|
|
|
+ for (i64 i = 0; i < column_count; i++) {
|
|
|
+ if (i == 0) {
|
|
|
+ vector = llvm_vector_mul(p, m_columns[i], v_rows[i]);
|
|
|
+ } else {
|
|
|
+ vector = llvm_vector_mul_add(p, m_columns[i], v_rows[i], vector);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return lb_matrix_cast_vector_to_type(p, vector, type);
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+
|
|
|
+ for (i64 i = 0; i < mt->Matrix.row_count; i++) {
|
|
|
+ for (i64 j = 0; j < mt->Matrix.column_count; j++) {
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, 0);
|
|
|
+ lbValue d0 = lb_emit_load(p, dst);
|
|
|
+
|
|
|
+ lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
|
|
|
+ lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
|
|
|
+ lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
|
|
|
+ lb_emit_store(p, dst, c);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+}
|
|
|
+
|
|
|
+lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
+ // TODO(bill): Handle edge case for f16 types on x86(-64) platforms
|
|
|
+
|
|
|
+ Type *mt = base_type(rhs.type);
|
|
|
+ Type *vt = base_type(lhs.type);
|
|
|
+
|
|
|
+ GB_ASSERT(is_type_matrix(mt));
|
|
|
+ GB_ASSERT(is_type_array_like(vt));
|
|
|
+
|
|
|
+ i64 vector_count = get_array_type_count(vt);
|
|
|
+
|
|
|
+ GB_ASSERT(vector_count == mt->Matrix.row_count);
|
|
|
+ GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
|
|
+
|
|
|
+ Type *elem = mt->Matrix.elem;
|
|
|
+
|
|
|
+ if (lb_is_matrix_simdable(mt)) {
|
|
|
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
|
|
|
+
|
|
|
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
|
|
|
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count; gb_unused(column_count);
|
|
|
+ auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
+ auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count);
|
|
|
+
|
|
|
+ LLVMValueRef matrix_vector = lb_matrix_to_vector(p, rhs);
|
|
|
+
|
|
|
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
|
|
+ for (unsigned row_index = 0; row_index < row_count; row_index++) {
|
|
|
+ for (unsigned column_index = 0; column_index < column_count; column_index++) {
|
|
|
+ unsigned offset = row_index + column_index*stride;
|
|
|
+ mask_elems[column_index] = lb_const_int(p->module, t_u32, offset).value;
|
|
|
+ }
|
|
|
+
|
|
|
+ // transpose mask
|
|
|
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count);
|
|
|
+ LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, "");
|
|
|
+ m_columns[row_index] = column;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (unsigned column_index = 0; column_index < row_count; column_index++) {
|
|
|
+ LLVMValueRef value = lb_emit_struct_ev(p, lhs, column_index).value;
|
|
|
+ LLVMValueRef row = llvm_vector_broadcast(p, value, column_count);
|
|
|
+ v_rows[column_index] = row;
|
|
|
+ }
|
|
|
+
|
|
|
+ GB_ASSERT(row_count > 0);
|
|
|
+
|
|
|
+ LLVMValueRef vector = nullptr;
|
|
|
+ for (i64 i = 0; i < row_count; i++) {
|
|
|
+ if (i == 0) {
|
|
|
+ vector = llvm_vector_mul(p, v_rows[i], m_columns[i]);
|
|
|
+ } else {
|
|
|
+ vector = llvm_vector_mul_add(p, v_rows[i], m_columns[i], vector);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+ LLVMValueRef res_ptr = res.addr.value;
|
|
|
+ unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
|
|
|
+ LLVMSetAlignment(res_ptr, alignment);
|
|
|
+
|
|
|
+ res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
|
|
|
+ LLVMBuildStore(p->builder, vector, res_ptr);
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+ }
|
|
|
+
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
+
|
|
|
+ for (i64 j = 0; j < mt->Matrix.column_count; j++) {
|
|
|
+ for (i64 k = 0; k < mt->Matrix.row_count; k++) {
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, 0, j);
|
|
|
+ lbValue d0 = lb_emit_load(p, dst);
|
|
|
+
|
|
|
+ lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
|
|
|
+ lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
|
|
|
+ lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
|
|
|
+ lb_emit_store(p, dst, c);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return lb_addr_load(p, res);
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false) {
|
|
|
+ GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
|
|
|
+
|
|
|
+
|
|
|
+ if (op == Token_Mul && !component_wise) {
|
|
|
+ Type *xt = base_type(lhs.type);
|
|
|
+ Type *yt = base_type(rhs.type);
|
|
|
+
|
|
|
+ if (xt->kind == Type_Matrix) {
|
|
|
+ if (yt->kind == Type_Matrix) {
|
|
|
+ return lb_emit_matrix_mul(p, lhs, rhs, type);
|
|
|
+ } else if (is_type_array_like(yt)) {
|
|
|
+ return lb_emit_matrix_mul_vector(p, lhs, rhs, type);
|
|
|
+ }
|
|
|
+ } else if (is_type_array_like(xt)) {
|
|
|
+ GB_ASSERT(yt->kind == Type_Matrix);
|
|
|
+ return lb_emit_vector_mul_matrix(p, lhs, rhs, type);
|
|
|
+ }
|
|
|
+
|
|
|
+ } else {
|
|
|
+ if (is_type_matrix(lhs.type)) {
|
|
|
+ rhs = lb_emit_conv(p, rhs, lhs.type);
|
|
|
+ } else {
|
|
|
+ lhs = lb_emit_conv(p, lhs, rhs.type);
|
|
|
+ }
|
|
|
+
|
|
|
+ Type *xt = base_type(lhs.type);
|
|
|
+ Type *yt = base_type(rhs.type);
|
|
|
+
|
|
|
+ GB_ASSERT_MSG(are_types_identical(xt, yt), "%s %.*s %s", type_to_string(lhs.type), LIT(token_strings[op]), type_to_string(rhs.type));
|
|
|
+ GB_ASSERT(xt->kind == Type_Matrix);
|
|
|
+ // element-wise arithmetic
|
|
|
+ // pretend it is an array
|
|
|
+ lbValue array_lhs = lhs;
|
|
|
+ lbValue array_rhs = rhs;
|
|
|
+ Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_internal_elems(xt));
|
|
|
+ GB_ASSERT(type_size_of(array_type) == type_size_of(xt));
|
|
|
+
|
|
|
+ array_lhs.type = array_type;
|
|
|
+ array_rhs.type = array_type;
|
|
|
+
|
|
|
+ if (token_is_comparison(op)) {
|
|
|
+ lbValue res = lb_emit_comp(p, op, array_lhs, array_rhs);
|
|
|
+ return lb_emit_conv(p, res, type);
|
|
|
+ } else {
|
|
|
+ lbValue array = lb_emit_arith(p, op, array_lhs, array_rhs, array_type);
|
|
|
+ array.type = type;
|
|
|
+ return array;
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ GB_PANIC("TODO: lb_emit_arith_matrix");
|
|
|
+
|
|
|
+ return {};
|
|
|
+}
|
|
|
+
|
|
|
|
|
|
|
|
|
lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) {
|
|
|
return lb_emit_arith_array(p, op, lhs, rhs, type);
|
|
|
+ } else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) {
|
|
|
+ return lb_emit_arith_matrix(p, op, lhs, rhs, type);
|
|
|
} else if (is_type_complex(type)) {
|
|
|
lhs = lb_emit_conv(p, lhs, type);
|
|
|
rhs = lb_emit_conv(p, rhs, type);
|
|
@@ -749,6 +1283,13 @@ lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) {
|
|
|
ast_node(be, BinaryExpr, expr);
|
|
|
|
|
|
TypeAndValue tv = type_and_value_of_expr(expr);
|
|
|
+
|
|
|
+ if (is_type_matrix(be->left->tav.type) || is_type_matrix(be->right->tav.type)) {
|
|
|
+ lbValue left = lb_build_expr(p, be->left);
|
|
|
+ lbValue right = lb_build_expr(p, be->right);
|
|
|
+ return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type));
|
|
|
+ }
|
|
|
+
|
|
|
|
|
|
switch (be->op.kind) {
|
|
|
case Token_Add:
|
|
@@ -1417,6 +1958,62 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) {
|
|
|
}
|
|
|
return lb_addr_load(p, v);
|
|
|
}
|
|
|
+
|
|
|
+ if (is_type_matrix(dst) && !is_type_matrix(src)) {
|
|
|
+ GB_ASSERT_MSG(dst->Matrix.row_count == dst->Matrix.column_count, "%s <- %s", type_to_string(dst), type_to_string(src));
|
|
|
+
|
|
|
+ Type *elem = base_array_type(dst);
|
|
|
+ lbValue e = lb_emit_conv(p, value, elem);
|
|
|
+ lbAddr v = lb_add_local_generated(p, t, false);
|
|
|
+ for (i64 i = 0; i < dst->Matrix.row_count; i++) {
|
|
|
+ isize j = cast(isize)i;
|
|
|
+ lbValue ptr = lb_emit_matrix_epi(p, v.addr, j, j);
|
|
|
+ lb_emit_store(p, ptr, e);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ return lb_addr_load(p, v);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (is_type_matrix(dst) && is_type_matrix(src)) {
|
|
|
+ GB_ASSERT(dst->kind == Type_Matrix);
|
|
|
+ GB_ASSERT(src->kind == Type_Matrix);
|
|
|
+ lbAddr v = lb_add_local_generated(p, t, true);
|
|
|
+
|
|
|
+ if (is_matrix_square(dst) && is_matrix_square(dst)) {
|
|
|
+ for (i64 j = 0; j < dst->Matrix.column_count; j++) {
|
|
|
+ for (i64 i = 0; i < dst->Matrix.row_count; i++) {
|
|
|
+ if (i < src->Matrix.row_count && j < src->Matrix.column_count) {
|
|
|
+ lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
|
|
|
+ lbValue s = lb_emit_matrix_ev(p, value, i, j);
|
|
|
+ lb_emit_store(p, d, s);
|
|
|
+ } else if (i == j) {
|
|
|
+ lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
|
|
|
+ lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true);
|
|
|
+ lb_emit_store(p, d, s);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count;
|
|
|
+ i64 src_count = src->Matrix.row_count*src->Matrix.column_count;
|
|
|
+ GB_ASSERT(dst_count == src_count);
|
|
|
+
|
|
|
+ for (i64 j = 0; j < src->Matrix.column_count; j++) {
|
|
|
+ for (i64 i = 0; i < src->Matrix.row_count; i++) {
|
|
|
+ lbValue s = lb_emit_matrix_ev(p, value, i, j);
|
|
|
+ i64 index = i + j*src->Matrix.row_count;
|
|
|
+ i64 dst_i = index%dst->Matrix.row_count;
|
|
|
+ i64 dst_j = index/dst->Matrix.row_count;
|
|
|
+ lbValue d = lb_emit_matrix_epi(p, v.addr, dst_i, dst_j);
|
|
|
+ lb_emit_store(p, d, s);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return lb_addr_load(p, v);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
|
|
|
if (is_type_any(dst)) {
|
|
|
if (is_type_untyped_nil(src)) {
|
|
@@ -2481,6 +3078,10 @@ lbValue lb_build_expr(lbProcedure *p, Ast *expr) {
|
|
|
case_ast_node(ie, IndexExpr, expr);
|
|
|
return lb_addr_load(p, lb_build_addr(p, expr));
|
|
|
case_end;
|
|
|
+
|
|
|
+ case_ast_node(ie, MatrixIndexExpr, expr);
|
|
|
+ return lb_addr_load(p, lb_build_addr(p, expr));
|
|
|
+ case_end;
|
|
|
|
|
|
case_ast_node(ia, InlineAsmExpr, expr);
|
|
|
Type *t = type_of_expr(expr);
|
|
@@ -2976,6 +3577,25 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) {
|
|
|
lbValue v = lb_emit_ptr_offset(p, elem, index);
|
|
|
return lb_addr(v);
|
|
|
}
|
|
|
+
|
|
|
+ case Type_Matrix: {
|
|
|
+ lbValue matrix = {};
|
|
|
+ matrix = lb_build_addr_ptr(p, ie->expr);
|
|
|
+ if (deref) {
|
|
|
+ matrix = lb_emit_load(p, matrix);
|
|
|
+ }
|
|
|
+ lbValue index = lb_build_expr(p, ie->index);
|
|
|
+ index = lb_emit_conv(p, index, t_int);
|
|
|
+ lbValue elem = lb_emit_matrix_ep(p, matrix, lb_const_int(p->module, t_int, 0), index);
|
|
|
+ elem = lb_emit_conv(p, elem, alloc_type_pointer(type_of_expr(expr)));
|
|
|
+
|
|
|
+ auto index_tv = type_and_value_of_expr(ie->index);
|
|
|
+ if (index_tv.mode != Addressing_Constant) {
|
|
|
+ lbValue len = lb_const_int(p->module, t_int, t->Matrix.column_count);
|
|
|
+ lb_emit_bounds_check(p, ast_token(ie->index), index, len);
|
|
|
+ }
|
|
|
+ return lb_addr(elem);
|
|
|
+ }
|
|
|
|
|
|
|
|
|
case Type_Basic: { // Basic_string
|
|
@@ -2998,6 +3618,35 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) {
|
|
|
}
|
|
|
}
|
|
|
case_end;
|
|
|
+
|
|
|
+ case_ast_node(ie, MatrixIndexExpr, expr);
|
|
|
+ Type *t = base_type(type_of_expr(ie->expr));
|
|
|
+
|
|
|
+ bool deref = is_type_pointer(t);
|
|
|
+ t = base_type(type_deref(t));
|
|
|
+
|
|
|
+ lbValue m = {};
|
|
|
+ m = lb_build_addr_ptr(p, ie->expr);
|
|
|
+ if (deref) {
|
|
|
+ m = lb_emit_load(p, m);
|
|
|
+ }
|
|
|
+ lbValue row_index = lb_build_expr(p, ie->row_index);
|
|
|
+ lbValue column_index = lb_build_expr(p, ie->column_index);
|
|
|
+ row_index = lb_emit_conv(p, row_index, t_int);
|
|
|
+ column_index = lb_emit_conv(p, column_index, t_int);
|
|
|
+ lbValue elem = lb_emit_matrix_ep(p, m, row_index, column_index);
|
|
|
+
|
|
|
+ auto row_index_tv = type_and_value_of_expr(ie->row_index);
|
|
|
+ auto column_index_tv = type_and_value_of_expr(ie->column_index);
|
|
|
+ if (row_index_tv.mode != Addressing_Constant || column_index_tv.mode != Addressing_Constant) {
|
|
|
+ lbValue row_count = lb_const_int(p->module, t_int, t->Matrix.row_count);
|
|
|
+ lbValue column_count = lb_const_int(p->module, t_int, t->Matrix.column_count);
|
|
|
+ lb_emit_matrix_bounds_check(p, ast_token(ie->row_index), row_index, column_index, row_count, column_count);
|
|
|
+ }
|
|
|
+ return lb_addr(elem);
|
|
|
+
|
|
|
+
|
|
|
+ case_end;
|
|
|
|
|
|
case_ast_node(se, SliceExpr, expr);
|
|
|
|
|
@@ -3246,6 +3895,7 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) {
|
|
|
case Type_Slice: et = bt->Slice.elem; break;
|
|
|
case Type_BitSet: et = bt->BitSet.elem; break;
|
|
|
case Type_SimdVector: et = bt->SimdVector.elem; break;
|
|
|
+ case Type_Matrix: et = bt->Matrix.elem; break;
|
|
|
}
|
|
|
|
|
|
String proc_name = {};
|
|
@@ -3777,7 +4427,104 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) {
|
|
|
}
|
|
|
break;
|
|
|
}
|
|
|
+
|
|
|
+ case Type_Matrix: {
|
|
|
+ if (cl->elems.count > 0) {
|
|
|
+ lb_addr_store(p, v, lb_const_value(p->module, type, exact_value_compound(expr)));
|
|
|
|
|
|
+ auto temp_data = array_make<lbCompoundLitElemTempData>(temporary_allocator(), 0, cl->elems.count);
|
|
|
+
|
|
|
+ // NOTE(bill): Separate value, gep, store into their own chunks
|
|
|
+ for_array(i, cl->elems) {
|
|
|
+ Ast *elem = cl->elems[i];
|
|
|
+
|
|
|
+ if (elem->kind == Ast_FieldValue) {
|
|
|
+ ast_node(fv, FieldValue, elem);
|
|
|
+ if (lb_is_elem_const(fv->value, et)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (is_ast_range(fv->field)) {
|
|
|
+ ast_node(ie, BinaryExpr, fv->field);
|
|
|
+ TypeAndValue lo_tav = ie->left->tav;
|
|
|
+ TypeAndValue hi_tav = ie->right->tav;
|
|
|
+ GB_ASSERT(lo_tav.mode == Addressing_Constant);
|
|
|
+ GB_ASSERT(hi_tav.mode == Addressing_Constant);
|
|
|
+
|
|
|
+ TokenKind op = ie->op.kind;
|
|
|
+ i64 lo = exact_value_to_i64(lo_tav.value);
|
|
|
+ i64 hi = exact_value_to_i64(hi_tav.value);
|
|
|
+ if (op != Token_RangeHalf) {
|
|
|
+ hi += 1;
|
|
|
+ }
|
|
|
+
|
|
|
+ lbValue value = lb_build_expr(p, fv->value);
|
|
|
+
|
|
|
+ for (i64 k = lo; k < hi; k++) {
|
|
|
+ lbCompoundLitElemTempData data = {};
|
|
|
+ data.value = value;
|
|
|
+
|
|
|
+ data.elem_index = cast(i32)matrix_index_to_offset(bt, k);
|
|
|
+ array_add(&temp_data, data);
|
|
|
+ }
|
|
|
+
|
|
|
+ } else {
|
|
|
+ auto tav = fv->field->tav;
|
|
|
+ GB_ASSERT(tav.mode == Addressing_Constant);
|
|
|
+ i64 index = exact_value_to_i64(tav.value);
|
|
|
+
|
|
|
+ lbValue value = lb_build_expr(p, fv->value);
|
|
|
+ lbCompoundLitElemTempData data = {};
|
|
|
+ data.value = lb_emit_conv(p, value, et);
|
|
|
+ data.expr = fv->value;
|
|
|
+
|
|
|
+ data.elem_index = cast(i32)matrix_index_to_offset(bt, index);
|
|
|
+ array_add(&temp_data, data);
|
|
|
+ }
|
|
|
+
|
|
|
+ } else {
|
|
|
+ if (lb_is_elem_const(elem, et)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ lbCompoundLitElemTempData data = {};
|
|
|
+ data.expr = elem;
|
|
|
+ data.elem_index = cast(i32)matrix_index_to_offset(bt, i);
|
|
|
+ array_add(&temp_data, data);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for_array(i, temp_data) {
|
|
|
+ temp_data[i].gep = lb_emit_array_epi(p, lb_addr_get_ptr(p, v), temp_data[i].elem_index);
|
|
|
+ }
|
|
|
+
|
|
|
+ for_array(i, temp_data) {
|
|
|
+ lbValue field_expr = temp_data[i].value;
|
|
|
+ Ast *expr = temp_data[i].expr;
|
|
|
+
|
|
|
+ auto prev_hint = lb_set_copy_elision_hint(p, lb_addr(temp_data[i].gep), expr);
|
|
|
+
|
|
|
+ if (field_expr.value == nullptr) {
|
|
|
+ field_expr = lb_build_expr(p, expr);
|
|
|
+ }
|
|
|
+ Type *t = field_expr.type;
|
|
|
+ GB_ASSERT(t->kind != Type_Tuple);
|
|
|
+ lbValue ev = lb_emit_conv(p, field_expr, et);
|
|
|
+
|
|
|
+ if (!p->copy_elision_hint.used) {
|
|
|
+ temp_data[i].value = ev;
|
|
|
+ }
|
|
|
+
|
|
|
+ lb_reset_copy_elision_hint(p, prev_hint);
|
|
|
+ }
|
|
|
+
|
|
|
+ for_array(i, temp_data) {
|
|
|
+ if (temp_data[i].value.value != nullptr) {
|
|
|
+ lb_emit_store(p, temp_data[i].gep, temp_data[i].value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+
|
|
|
}
|
|
|
|
|
|
return v;
|