|
@@ -509,12 +509,16 @@ LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) {
|
|
GB_ASSERT(mt->kind == Type_Matrix);
|
|
GB_ASSERT(mt->kind == Type_Matrix);
|
|
GB_ASSERT(lb_matrix_elem_simple(mt));
|
|
GB_ASSERT(lb_matrix_elem_simple(mt));
|
|
|
|
|
|
- unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
|
|
|
+
|
|
|
|
+ i64 stride = matrix_type_stride_in_elems(mt);
|
|
|
|
+ i64 rows = mt->Matrix.row_count;
|
|
|
|
+ i64 columns = mt->Matrix.column_count;
|
|
|
|
+ unsigned elem_count = cast(unsigned)(rows*columns);
|
|
|
|
|
|
Type *elem = mt->Matrix.elem;
|
|
Type *elem = mt->Matrix.elem;
|
|
LLVMTypeRef elem_type = lb_type(m, elem);
|
|
LLVMTypeRef elem_type = lb_type(m, elem);
|
|
|
|
|
|
- LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
|
|
|
|
|
|
+ LLVMTypeRef vector_type = LLVMVectorType(elem_type, elem_count);
|
|
LLVMTypeRef types[] = {vector_type};
|
|
LLVMTypeRef types[] = {vector_type};
|
|
|
|
|
|
char const *name = "llvm.matrix.column.major.load";
|
|
char const *name = "llvm.matrix.column.major.load";
|
|
@@ -524,44 +528,18 @@ LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) {
|
|
|
|
|
|
lbValue ptr = lb_address_from_load_or_generate_local(p, lhs);
|
|
lbValue ptr = lb_address_from_load_or_generate_local(p, lhs);
|
|
ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
|
|
ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
|
|
-
|
|
|
|
|
|
+
|
|
LLVMValueRef values[5] = {};
|
|
LLVMValueRef values[5] = {};
|
|
values[0] = ptr.value;
|
|
values[0] = ptr.value;
|
|
- values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
|
|
|
|
- values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
|
|
|
|
- values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
|
|
|
|
- values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
|
|
|
|
-
|
|
|
|
- return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
-}
|
|
|
|
-LLVMValueRef llvm_matrix_column_major_load_from_ptr(lbProcedure *p, lbValue ptr) {
|
|
|
|
- lbModule *m = p->module;
|
|
|
|
-
|
|
|
|
- Type *mt = base_type(type_deref(ptr.type));
|
|
|
|
- GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
|
- GB_ASSERT(lb_matrix_elem_simple(mt));
|
|
|
|
-
|
|
|
|
- unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
|
-
|
|
|
|
- Type *elem = mt->Matrix.elem;
|
|
|
|
- LLVMTypeRef elem_type = lb_type(m, elem);
|
|
|
|
-
|
|
|
|
- LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
|
|
|
|
- LLVMTypeRef types[] = {vector_type};
|
|
|
|
-
|
|
|
|
- char const *name = "llvm.matrix.column.major.load";
|
|
|
|
- unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
|
|
|
|
- GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
|
|
|
|
- LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
|
|
|
|
-
|
|
|
|
- LLVMValueRef values[5] = {};
|
|
|
|
- values[0] = lb_emit_matrix_epi(p, ptr, 0, 0).value;
|
|
|
|
- values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
|
|
|
|
|
|
+ values[1] = lb_const_int(m, t_u64, stride).value;
|
|
values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
|
|
values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
|
|
values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
|
|
values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
|
|
values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
|
|
values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
|
|
|
|
|
|
- return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
|
|
+ LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
+ gb_printf_err("%s\n", LLVMPrintValueToString(call));
|
|
|
|
+ // LLVMAddAttributeAtIndex(call, 0, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
|
|
|
|
+ return call;
|
|
}
|
|
}
|
|
|
|
|
|
void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) {
|
|
void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) {
|
|
@@ -571,12 +549,7 @@ void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef ve
|
|
GB_ASSERT(mt->kind == Type_Matrix);
|
|
GB_ASSERT(mt->kind == Type_Matrix);
|
|
GB_ASSERT(lb_matrix_elem_simple(mt));
|
|
GB_ASSERT(lb_matrix_elem_simple(mt));
|
|
|
|
|
|
- unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
|
-
|
|
|
|
- Type *elem = mt->Matrix.elem;
|
|
|
|
- LLVMTypeRef elem_type = lb_type(m, elem);
|
|
|
|
-
|
|
|
|
- LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
|
|
|
|
|
|
+ LLVMTypeRef vector_type = LLVMTypeOf(vector_value);
|
|
LLVMTypeRef types[] = {vector_type};
|
|
LLVMTypeRef types[] = {vector_type};
|
|
|
|
|
|
char const *name = "llvm.matrix.column.major.store";
|
|
char const *name = "llvm.matrix.column.major.store";
|
|
@@ -587,56 +560,26 @@ void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef ve
|
|
lbValue ptr = lb_addr_get_ptr(p, addr);
|
|
lbValue ptr = lb_addr_get_ptr(p, addr);
|
|
ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
|
|
ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
|
|
|
|
|
|
- GB_ASSERT(LLVMTypeOf(vector_value) == vector_type);
|
|
|
|
unsigned vector_size = LLVMGetVectorSize(vector_type);
|
|
unsigned vector_size = LLVMGetVectorSize(vector_type);
|
|
GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
|
|
GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
|
|
|
|
|
|
- LLVMValueRef values[6] = {};
|
|
|
|
- values[0] = vector_value;
|
|
|
|
- values[1] = ptr.value;
|
|
|
|
- values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
|
|
|
|
- values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
|
|
|
|
- values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
|
|
|
|
- values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
|
|
|
|
-
|
|
|
|
- LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
-void llvm_matrix_column_major_store_to_raw_ptr(lbProcedure *p, Type *mt, lbValue ptr, LLVMValueRef vector_value) {
|
|
|
|
- lbModule *m = p->module;
|
|
|
|
-
|
|
|
|
- mt = base_type(mt);
|
|
|
|
- GB_ASSERT(mt->kind == Type_Matrix);
|
|
|
|
- GB_ASSERT(lb_matrix_elem_simple(mt));
|
|
|
|
-
|
|
|
|
- unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt);
|
|
|
|
-
|
|
|
|
- Type *elem = mt->Matrix.elem;
|
|
|
|
- LLVMTypeRef elem_type = lb_type(m, elem);
|
|
|
|
-
|
|
|
|
- LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count);
|
|
|
|
- LLVMTypeRef types[] = {vector_type};
|
|
|
|
-
|
|
|
|
- char const *name = "llvm.matrix.column.major.store";
|
|
|
|
- unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
|
|
|
|
- GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
|
|
|
|
- LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
|
|
|
|
-
|
|
|
|
- GB_ASSERT(LLVMTypeOf(vector_value) == vector_type);
|
|
|
|
- unsigned vector_size = LLVMGetVectorSize(vector_type);
|
|
|
|
- GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
|
|
|
|
|
|
+ i64 stride = matrix_type_stride_in_elems(mt);
|
|
|
|
|
|
LLVMValueRef values[6] = {};
|
|
LLVMValueRef values[6] = {};
|
|
values[0] = vector_value;
|
|
values[0] = vector_value;
|
|
values[1] = ptr.value;
|
|
values[1] = ptr.value;
|
|
- values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width
|
|
|
|
|
|
+ values[2] = lb_const_int(m, t_u64, stride).value;
|
|
values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
|
|
values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
|
|
values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
|
|
values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
|
|
values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
|
|
values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
|
|
|
|
|
|
- LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
|
|
+ LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
+ gb_printf_err("%s\n", LLVMPrintValueToString(call));
|
|
|
|
+ // LLVMAddAttributeAtIndex(call, 1, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
|
|
|
|
+ gb_unused(call);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+
|
|
LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) {
|
|
LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) {
|
|
lbModule *m = p->module;
|
|
lbModule *m = p->module;
|
|
|
|
|
|
@@ -648,6 +591,7 @@ LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b
|
|
LLVMTypeRef elem_type = LLVMGetElementType(a_type);
|
|
LLVMTypeRef elem_type = LLVMGetElementType(a_type);
|
|
|
|
|
|
LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns));
|
|
LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns));
|
|
|
|
+
|
|
LLVMTypeRef types[] = {res_vector_type, a_type, b_type};
|
|
LLVMTypeRef types[] = {res_vector_type, a_type, b_type};
|
|
|
|
|
|
char const *name = "llvm.matrix.multiply";
|
|
char const *name = "llvm.matrix.multiply";
|
|
@@ -662,7 +606,9 @@ LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b
|
|
values[3] = lb_const_int(m, t_u32, inner).value;
|
|
values[3] = lb_const_int(m, t_u32, inner).value;
|
|
values[4] = lb_const_int(m, t_u32, outer_columns).value;
|
|
values[4] = lb_const_int(m, t_u32, outer_columns).value;
|
|
|
|
|
|
- return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
|
|
+ LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
|
|
|
|
+ gb_printf_err("%s\n", LLVMPrintValueToString(call));
|
|
|
|
+ return call;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -684,19 +630,13 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
|
|
// TODO(bill): LLVM ERROR: Do not know how to split the result of this operator!
|
|
// TODO(bill): LLVM ERROR: Do not know how to split the result of this operator!
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
|
|
- lbValue res_ptr = lb_addr_get_ptr(p, res);
|
|
|
|
- res_ptr = lb_emit_matrix_epi(p, res_ptr, 0, 0);
|
|
|
|
-
|
|
|
|
- lbValue lhs_ptr = lb_address_from_load_or_generate_local(p, lhs);
|
|
|
|
- lbValue rhs_ptr = lb_address_from_load_or_generate_local(p, rhs);
|
|
|
|
- LLVMValueRef a = llvm_matrix_column_major_load_from_ptr(p, lhs_ptr);
|
|
|
|
- LLVMValueRef b = llvm_matrix_column_major_load_from_ptr(p, rhs_ptr);
|
|
|
|
- LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count);
|
|
|
|
-
|
|
|
|
- llvm_matrix_column_major_store_to_raw_ptr(p, type, res_ptr, c);
|
|
|
|
|
|
+ LLVMValueRef a = llvm_matrix_column_major_load(p, lhs); gb_unused(a);
|
|
|
|
+ LLVMValueRef b = llvm_matrix_column_major_load(p, rhs); gb_unused(b);
|
|
|
|
+ LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count); gb_unused(c);
|
|
|
|
+ llvm_matrix_column_major_store(p, res, c);
|
|
|
|
|
|
return lb_addr_load(p, res);
|
|
return lb_addr_load(p, res);
|
|
- }
|
|
|
|
|
|
+ }
|
|
|
|
|
|
slow_form:
|
|
slow_form:
|
|
{
|
|
{
|
|
@@ -704,18 +644,21 @@ slow_form:
|
|
|
|
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
|
|
- for (i64 i = 0; i < xt->Matrix.row_count; i++) {
|
|
|
|
- for (i64 j = 0; j < yt->Matrix.column_count; j++) {
|
|
|
|
- for (i64 k = 0; k < xt->Matrix.column_count; k++) {
|
|
|
|
|
|
+ i64 outer_rows = xt->Matrix.row_count;
|
|
|
|
+ i64 inner = xt->Matrix.column_count;
|
|
|
|
+ i64 outer_columns = yt->Matrix.column_count;
|
|
|
|
+
|
|
|
|
+ for (i64 j = 0; j < outer_columns; j++) {
|
|
|
|
+ for (i64 i = 0; i < outer_rows; i++) {
|
|
|
|
+ for (i64 k = 0; k < inner; k++) {
|
|
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
|
|
|
+ lbValue d0 = lb_emit_load(p, dst);
|
|
|
|
|
|
lbValue a = lb_emit_matrix_ev(p, lhs, i, k);
|
|
lbValue a = lb_emit_matrix_ev(p, lhs, i, k);
|
|
lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
|
|
lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
|
|
lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
|
|
lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
|
|
- lbValue d = lb_emit_load(p, dst);
|
|
|
|
- lbValue e = lb_emit_arith(p, Token_Add, d, c, elem);
|
|
|
|
- lb_emit_store(p, dst, e);
|
|
|
|
-
|
|
|
|
|
|
+ lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
|
|
|
|
+ lb_emit_store(p, dst, d);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -724,6 +667,72 @@ slow_form:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
|
+ Type *mt = base_type(lhs.type);
|
|
|
|
+ Type *vt = base_type(rhs.type);
|
|
|
|
+
|
|
|
|
+ GB_ASSERT(is_type_matrix(mt));
|
|
|
|
+ GB_ASSERT(is_type_array_like(vt));
|
|
|
|
+
|
|
|
|
+ i64 vector_count = get_array_type_count(vt);
|
|
|
|
+
|
|
|
|
+ GB_ASSERT(mt->Matrix.column_count == vector_count);
|
|
|
|
+ GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
|
|
|
+
|
|
|
|
+ Type *elem = mt->Matrix.elem;
|
|
|
|
+
|
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
+
|
|
|
|
+ for (i64 i = 0; i < mt->Matrix.row_count; i++) {
|
|
|
|
+ for (i64 j = 0; j < mt->Matrix.column_count; j++) {
|
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, 0);
|
|
|
|
+ lbValue d0 = lb_emit_load(p, dst);
|
|
|
|
+
|
|
|
|
+ lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
|
|
|
|
+ lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
|
|
|
|
+ lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
|
|
|
|
+ lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
|
|
|
|
+ lb_emit_store(p, dst, d);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return lb_addr_load(p, res);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
|
|
|
+ Type *mt = base_type(rhs.type);
|
|
|
|
+ Type *vt = base_type(lhs.type);
|
|
|
|
+
|
|
|
|
+ GB_ASSERT(is_type_matrix(mt));
|
|
|
|
+ GB_ASSERT(is_type_array_like(vt));
|
|
|
|
+
|
|
|
|
+ i64 vector_count = get_array_type_count(vt);
|
|
|
|
+
|
|
|
|
+ GB_ASSERT(mt->Matrix.row_count == vector_count);
|
|
|
|
+ GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt)));
|
|
|
|
+
|
|
|
|
+ Type *elem = mt->Matrix.elem;
|
|
|
|
+
|
|
|
|
+ lbAddr res = lb_add_local_generated(p, type, true);
|
|
|
|
+
|
|
|
|
+ for (i64 j = 0; j < mt->Matrix.column_count; j++) {
|
|
|
|
+ for (i64 k = 0; k < mt->Matrix.row_count; k++) {
|
|
|
|
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, 0, j);
|
|
|
|
+ lbValue d0 = lb_emit_load(p, dst);
|
|
|
|
+
|
|
|
|
+ lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
|
|
|
|
+ lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
|
|
|
|
+ lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
|
|
|
|
+ lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
|
|
|
|
+ lb_emit_store(p, dst, d);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return lb_addr_load(p, res);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
|
|
lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
|
|
lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
|
|
GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
|
|
GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
|
|
@@ -735,7 +744,12 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue
|
|
if (xt->kind == Type_Matrix) {
|
|
if (xt->kind == Type_Matrix) {
|
|
if (yt->kind == Type_Matrix) {
|
|
if (yt->kind == Type_Matrix) {
|
|
return lb_emit_matrix_mul(p, lhs, rhs, type);
|
|
return lb_emit_matrix_mul(p, lhs, rhs, type);
|
|
|
|
+ } else if (is_type_array_like(yt)) {
|
|
|
|
+ return lb_emit_matrix_mul_vector(p, lhs, rhs, type);
|
|
}
|
|
}
|
|
|
|
+ } else if (is_type_array_like(xt)) {
|
|
|
|
+ GB_ASSERT(yt->kind == Type_Matrix);
|
|
|
|
+ return lb_emit_vector_mul_matrix(p, lhs, rhs, type);
|
|
}
|
|
}
|
|
|
|
|
|
} else {
|
|
} else {
|
|
@@ -1036,6 +1050,13 @@ lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) {
|
|
ast_node(be, BinaryExpr, expr);
|
|
ast_node(be, BinaryExpr, expr);
|
|
|
|
|
|
TypeAndValue tv = type_and_value_of_expr(expr);
|
|
TypeAndValue tv = type_and_value_of_expr(expr);
|
|
|
|
+
|
|
|
|
+ if (is_type_matrix(be->left->tav.type) || is_type_matrix(be->right->tav.type)) {
|
|
|
|
+ lbValue left = lb_build_expr(p, be->left);
|
|
|
|
+ lbValue right = lb_build_expr(p, be->right);
|
|
|
|
+ return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
|
|
switch (be->op.kind) {
|
|
switch (be->op.kind) {
|
|
case Token_Add:
|
|
case Token_Add:
|