Browse Source

Correct matrix builtins for `#row_major`

gingerBill 1 year ago
parent
commit
18fb665bf6
3 changed files with 12 additions and 5 deletions
  1. 2 2
      src/check_builtin.cpp
  2. 9 2
      src/check_expr.cpp
  3. 1 1
      src/types.cpp

+ 2 - 2
src/check_builtin.cpp

@@ -3488,7 +3488,7 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 			}
 		} else {
 			GB_ASSERT(t->kind == Type_Matrix);
-			operand->type = alloc_type_matrix(t->Matrix.elem, t->Matrix.column_count, t->Matrix.row_count);
+			operand->type = alloc_type_matrix(t->Matrix.elem, t->Matrix.column_count, t->Matrix.row_count, nullptr, nullptr, t->Matrix.is_row_major);
 		}
 		operand->type = check_matrix_type_hint(operand->type, type_hint);
 		break;
@@ -3556,7 +3556,7 @@ gb_internal bool check_builtin_procedure(CheckerContext *c, Operand *operand, As
 		}
 		
 		operand->mode = Addressing_Value;
-		operand->type = alloc_type_matrix(elem, xt->Array.count, yt->Array.count);	
+		operand->type = alloc_type_matrix(elem, xt->Array.count, yt->Array.count, nullptr, nullptr, false);
 		operand->type = check_matrix_type_hint(operand->type, type_hint);
 		break;
 	}

+ 9 - 2
src/check_expr.cpp

@@ -3397,6 +3397,13 @@ gb_internal Type *check_matrix_type_hint(Type *matrix, Type *type_hint) {
 		Type *th = base_type(type_hint);
 		if (are_types_identical(th, xt)) {
 			return type_hint;
+		} else if (xt->kind == Type_Matrix && th->kind == Type_Matrix) {
+			if (!are_types_identical(xt->Matrix.elem, th->Matrix.elem)) {
+				// ignore
+			} if (xt->Matrix.row_count == th->Matrix.row_count &&
+			      xt->Matrix.column_count == th->Matrix.column_count) {
+				return type_hint;
+			}
 		} else if (xt->kind == Type_Matrix && th->kind == Type_Array) {
 			if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) {
 				// ignore
@@ -3461,7 +3468,7 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand
 				if (xt->Matrix.row_count == yt->Array.count) {
 					x->type = y->type;
 				} else {
-					x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1);
+					x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1, nullptr, nullptr, xt->Matrix.is_row_major);
 				}
 				goto matrix_success;
 			}
@@ -3492,7 +3499,7 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand
 				if (yt->Matrix.column_count == xt->Array.count) {
 					x->type = x->type;
 				} else {
-					x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count);
+					x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count, nullptr, nullptr, yt->Matrix.is_row_major);
 				}
 				goto matrix_success;
 			} else if (are_types_identical(yt->Matrix.elem, xt)) {

+ 1 - 1
src/types.cpp

@@ -1003,7 +1003,7 @@ gb_internal Type *alloc_type_array(Type *elem, i64 count, Type *generic_count =
 	return t;
 }
 
-gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, Type *generic_row_count = nullptr, Type *generic_column_count = nullptr, bool is_row_major = false) {
+gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, Type *generic_row_count, Type *generic_column_count, bool is_row_major) {
 	if (generic_row_count != nullptr || generic_column_count != nullptr) {
 		Type *t = alloc_type(Type_Matrix);
 		t->Matrix.elem                 = elem;