Browse Source

Fix #2017 mismatched types in binary matrix expression for `flt * (mat * vec)`

gingerBill 2 years ago
parent
commit
4998cf80c1
2 changed files with 22 additions and 6 deletions
  1. 4 1
      src/check_expr.cpp
  2. 18 5
      src/llvm_backend_expr.cpp

+ 4 - 1
src/check_expr.cpp

@@ -3039,8 +3039,8 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand
 		x->type = xt;
 		goto matrix_success;
 	} else {
-		GB_ASSERT(is_type_matrix(yt));
 		GB_ASSERT(!is_type_matrix(xt));
+		GB_ASSERT(is_type_matrix(yt));
 		
 		if (op.kind == Token_Mul) {
 			// NOTE(bill): no need to handle the matrix case here since it should be handled above
@@ -3061,6 +3061,9 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand
 					x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count);
 				}
 				goto matrix_success;
+			} else if (are_types_identical(yt->Matrix.elem, xt)) {
+				x->type = check_matrix_type_hint(y->type, type_hint);
+				return;
 			}
 		}
 		if (!are_types_identical(xt, yt)) {

+ 18 - 5
src/llvm_backend_expr.cpp

@@ -1,4 +1,4 @@
-lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false);
+lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise);
 
 lbValue lb_emit_logical_binary_expr(lbProcedure *p, TokenKind op, Ast *left, Ast *right, Type *type) {
 	lbModule *m = p->module;
@@ -987,7 +987,6 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise) {
 	GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
 
-
 	if (op == Token_Mul && !component_wise) {
 		Type *xt = base_type(lhs.type);
 		Type *yt = base_type(rhs.type);
@@ -1001,8 +1000,22 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue
 		} else if (is_type_array_like(xt)) {
 			GB_ASSERT(yt->kind == Type_Matrix);
 			return lb_emit_vector_mul_matrix(p, lhs, rhs, type);
-		}
+		} else {
+			GB_ASSERT(xt->kind == Type_Basic);
+			GB_ASSERT(yt->kind == Type_Matrix);
+			GB_ASSERT(is_type_matrix(type));
+
+			Type *array_type = alloc_type_array(yt->Matrix.elem, matrix_type_total_internal_elems(yt));
+			GB_ASSERT(type_size_of(array_type) == type_size_of(yt));
 
+			lbValue array_lhs = lb_emit_conv(p, lhs, array_type);
+			lbValue array_rhs = rhs;
+			array_rhs.type = array_type;
+
+			lbValue array = lb_emit_arith(p, op, array_lhs, array_rhs, array_type);
+			array.type = type;
+			return array;
+		}
 	} else {
 		if (is_type_matrix(lhs.type)) {
 			rhs = lb_emit_conv(p, rhs, lhs.type);
@@ -1047,7 +1060,7 @@ lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Ty
 	if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) {
 		return lb_emit_arith_array(p, op, lhs, rhs, type);
 	} else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) {
-		return lb_emit_arith_matrix(p, op, lhs, rhs, type);
+		return lb_emit_arith_matrix(p, op, lhs, rhs, type, false);
 	} else if (is_type_complex(type)) {
 		lhs = lb_emit_conv(p, lhs, type);
 		rhs = lb_emit_conv(p, rhs, type);
@@ -1320,7 +1333,7 @@ lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) {
 	if (is_type_matrix(be->left->tav.type) || is_type_matrix(be->right->tav.type)) {
 		lbValue left = lb_build_expr(p, be->left);
 		lbValue right = lb_build_expr(p, be->right);
-		return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type));
+		return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type), false);
 	}