Browse Source

Allow casting between square matrices of the same element type

gingerBill 3 years ago
parent
commit
e0b9475378
3 changed files with 56 additions and 17 deletions
  1. 19 0
      src/check_expr.cpp
  2. 5 5
      src/check_type.cpp
  3. 32 12
      src/llvm_backend_expr.cpp

+ 19 - 0
src/check_expr.cpp

@@ -2460,6 +2460,24 @@ bool check_is_castable_to(CheckerContext *c, Operand *operand, Type *y) {
 	if (is_type_quaternion(src) && is_type_quaternion(dst)) {
 	if (is_type_quaternion(src) && is_type_quaternion(dst)) {
 		return true;
 		return true;
 	}
 	}
+	
+	if (is_type_matrix(src) && is_type_matrix(dst)) {
+		GB_ASSERT(src->kind == Type_Matrix);
+		GB_ASSERT(dst->kind == Type_Matrix);
+		if (!are_types_identical(src->Matrix.elem, dst->Matrix.elem)) {
+			return false;
+		}
+		
+		if (src->Matrix.row_count != src->Matrix.column_count) {
+			return false;
+		}
+		
+		if (dst->Matrix.row_count != dst->Matrix.column_count) {
+			return false;
+		}
+		
+		return true;
+	}
 
 
 
 
 	// Cast between pointers
 	// Cast between pointers
@@ -8838,6 +8856,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type
 	case Ast_EnumType:
 	case Ast_EnumType:
 	case Ast_MapType:
 	case Ast_MapType:
 	case Ast_BitSetType:
 	case Ast_BitSetType:
+	case Ast_MatrixType:
 		o->mode = Addressing_Type;
 		o->mode = Addressing_Type;
 		o->type = check_type(c, node);
 		o->type = check_type(c, node);
 		break;
 		break;

+ 5 - 5
src/check_type.cpp

@@ -1154,7 +1154,11 @@ Type *determine_type_from_polymorphic(CheckerContext *ctx, Type *poly_type, Oper
 	bool show_error = modify_type && !ctx->hide_polymorphic_errors;
 	bool show_error = modify_type && !ctx->hide_polymorphic_errors;
 	if (!is_operand_value(operand)) {
 	if (!is_operand_value(operand)) {
 		if (show_error) {
 		if (show_error) {
-			error(operand.expr, "Cannot determine polymorphic type from parameter");
+			gbString pts = type_to_string(poly_type);
+			gbString ots = type_to_string(operand.type);
+			defer (gb_string_free(pts));
+			defer (gb_string_free(ots));
+			error(operand.expr, "Cannot determine polymorphic type from parameter: '%s' to '%s'", ots, pts);
 		}
 		}
 		return t_invalid;
 		return t_invalid;
 	}
 	}
@@ -2839,10 +2843,6 @@ bool check_type_internal(CheckerContext *ctx, Ast *e, Type **type, Type *named_t
 	
 	
 	
 	
 	case_ast_node(mt, MatrixType, e);
 	case_ast_node(mt, MatrixType, e);
-		bool ips = ctx->in_polymorphic_specialization;
-		defer (ctx->in_polymorphic_specialization = ips);
-		ctx->in_polymorphic_specialization = false;
-
 		check_matrix_type(ctx, type, e);
 		check_matrix_type(ctx, type, e);
 		set_base_type(named_type, *type);
 		set_base_type(named_type, *type);
 		return true;
 		return true;

+ 32 - 12
src/llvm_backend_expr.cpp

@@ -476,7 +476,7 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r
 	}
 	}
 }
 }
 
 
-bool lb_matrix_elem_simple(Type *t) {
+bool lb_is_matrix_simdable(Type *t) {
 	Type *mt = base_type(t);
 	Type *mt = base_type(t);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	
 	
@@ -555,7 +555,7 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
 	Type *mt = base_type(m.type);
 	Type *mt = base_type(m.type);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	
 	
-	if (lb_matrix_elem_simple(mt)) {
+	if (lb_is_matrix_simdable(mt)) {
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		unsigned row_count    = cast(unsigned)mt->Matrix.row_count;
 		unsigned row_count    = cast(unsigned)mt->Matrix.row_count;
 		unsigned column_count = cast(unsigned)mt->Matrix.column_count;
 		unsigned column_count = cast(unsigned)mt->Matrix.column_count;
@@ -623,7 +623,7 @@ lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) {
 	Type *mt = base_type(m.type);
 	Type *mt = base_type(m.type);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	GB_ASSERT(mt->kind == Type_Matrix);
 	
 	
-	if (lb_matrix_elem_simple(mt)) {
+	if (lb_is_matrix_simdable(mt)) {
 		LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m);
 		LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m);
 		return lb_matrix_cast_vector_to_type(p, vector, type);
 		return lb_matrix_cast_vector_to_type(p, vector, type);
 	}
 	}
@@ -690,7 +690,7 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
 	unsigned inner         = cast(unsigned)xt->Matrix.column_count;
 	unsigned inner         = cast(unsigned)xt->Matrix.column_count;
 	unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
 	unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
 		
 		
-	if (lb_matrix_elem_simple(xt)) {
+	if (lb_is_matrix_simdable(xt)) {
 		unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
 		unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
 		unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
 		unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
 		
 		
@@ -773,7 +773,7 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 	
 	
 	Type *elem = mt->Matrix.elem;
 	Type *elem = mt->Matrix.elem;
 	
 	
-	if (lb_matrix_elem_simple(mt)) {
+	if (lb_is_matrix_simdable(mt)) {
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		
 		
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count;
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count;
@@ -819,9 +819,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 			
 			
 			lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
 			lbValue a = lb_emit_matrix_ev(p, lhs, i, j);
 			lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
 			lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j);
-			lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
-			lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
-			lb_emit_store(p, dst, d);
+			lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
+			lb_emit_store(p, dst, c);
 		}
 		}
 	}
 	}
 	
 	
@@ -842,7 +841,7 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 	
 	
 	Type *elem = mt->Matrix.elem;
 	Type *elem = mt->Matrix.elem;
 	
 	
-	if (lb_matrix_elem_simple(mt)) {
+	if (lb_is_matrix_simdable(mt)) {
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 		
 		
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count;
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count;
@@ -903,9 +902,8 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 			
 			
 			lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
 			lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k);
 			lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
 			lbValue b = lb_emit_matrix_ev(p, rhs, k, j);
-			lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem);
-			lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem);
-			lb_emit_store(p, dst, d);
+			lbValue c = lb_emit_mul_add(p, a, b, d0, elem);
+			lb_emit_store(p, dst, c);
 		}
 		}
 	}
 	}
 	
 	
@@ -1938,6 +1936,28 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) {
 		
 		
 		return lb_addr_load(p, v);
 		return lb_addr_load(p, v);
 	}
 	}
+	
+	if (is_type_matrix(dst) && is_type_matrix(src)) {
+		GB_ASSERT(dst->kind == Type_Matrix);
+		GB_ASSERT(src->kind == Type_Matrix);
+		lbAddr v = lb_add_local_generated(p, t, true);
+		for (i64 j = 0; j < dst->Matrix.column_count; j++) {
+			for (i64 i = 0; i < dst->Matrix.row_count; i++) {
+				if (i < src->Matrix.row_count && j < src->Matrix.column_count) {
+					lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
+					lbValue s = lb_emit_matrix_ev(p, value, i, j);
+					lb_emit_store(p, d, s);
+				} else if (i == j) {
+					lbValue d = lb_emit_matrix_epi(p, v.addr, i, j);
+					lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true);
+					lb_emit_store(p, d, s);
+				}
+			}
+		}
+		return lb_addr_load(p, v);
+	}	
+	
+	
 
 
 	if (is_type_any(dst)) {
 	if (is_type_any(dst)) {
 		if (is_type_untyped_nil(src)) {
 		if (is_type_untyped_nil(src)) {