Bläddra i källkod

Support rank-2 arrays (matrix-like) for `transpose`

gingerBill 3 år sedan
förälder
incheckning
19aec13a10
3 ändrade filer med 76 tillägg och 2 borttagningar
  1. 36 2
      src/check_builtin.cpp
  2. 21 0
      src/llvm_backend_expr.cpp
  3. 19 0
      src/types.cpp

+ 36 - 2
src/check_builtin.cpp

@@ -2183,9 +2183,43 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 		}
 		
 		operand->mode = Addressing_Value;
-		if (is_type_array(t)) {
+		if (t->kind == Type_Array) {
+			i32 rank = type_math_rank(t);
 			// Do nothing
-			operand->type = x.type;			
+			operand->type = x.type;
+			if (rank > 2) {
+				gbString s = type_to_string(x.type);
+				error(call, "'%.*s' expects a matrix or array with a rank of 2, got %s of rank %d", LIT(builtin_name), s, rank);
+				gb_string_free(s);
+				return false;
+			} else if (rank == 2) {
+				Type *inner = base_type(t->Array.elem);
+				GB_ASSERT(inner->kind == Type_Array);
+				Type *elem = inner->Array.elem;
+				Type *array_inner = alloc_type_array(elem, t->Array.count);
+				Type *array_outer = alloc_type_array(array_inner, inner->Array.count);
+				operand->type = array_outer;
+
+				i64 elements = t->Array.count*inner->Array.count;
+				i64 size = type_size_of(operand->type);
+				if (!is_type_valid_for_matrix_elems(elem)) {
+					gbString s = type_to_string(x.type);
+					error(call, "'%.*s' expects a matrix or array with a base element type of an integer, float, or complex number, got %s", LIT(builtin_name), s);
+					gb_string_free(s);
+				} else if (elements > MATRIX_ELEMENT_COUNT_MAX) {
+					gbString s = type_to_string(x.type);
+					error(call, "'%.*s' expects a matrix or array with a maximum of %d elements, got %s with %lld elements", LIT(builtin_name), MATRIX_ELEMENT_COUNT_MAX, s, elements);
+					gb_string_free(s);
+				} else if (elements > MATRIX_ELEMENT_COUNT_MAX) {
+					gbString s = type_to_string(x.type);
+					error(call, "'%.*s' expects a matrix or array with non-zero elements, got %s", LIT(builtin_name), MATRIX_ELEMENT_COUNT_MAX, s);
+					gb_string_free(s);
+				} else if (size > MATRIX_ELEMENT_MAX_SIZE) {
+					gbString s = type_to_string(x.type);
+					error(call, "Too large of a type for '%.*s', got %s of size %lld, maximum size %d", LIT(builtin_name), s, cast(long long)size, MATRIX_ELEMENT_MAX_SIZE);
+					gb_string_free(s);
+				}
+			}
 		} else {
 			GB_ASSERT(t->kind == Type_Matrix);
 			operand->type = alloc_type_matrix(t->Matrix.elem, t->Matrix.column_count, t->Matrix.row_count);

+ 21 - 0
src/llvm_backend_expr.cpp

@@ -580,6 +580,27 @@ LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
 
 lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
 	if (is_type_array(m.type)) {
+		i32 rank = type_math_rank(m.type);
+		if (rank == 2) {
+			lbAddr addr = lb_add_local_generated(p, type, false);
+			lbValue dst = addr.addr;
+			lbValue src = m;
+			i32 n = cast(i32)get_array_type_count(m.type);
+			i32 m = cast(i32)get_array_type_count(type);
+			// m.type == [n][m]T
+			// type   == [m][n]T
+
+			for (i32 j = 0; j < m; j++) {
+				lbValue dst_col = lb_emit_struct_ep(p, dst, j);
+				for (i32 i = 0; i < n; i++) {
+					lbValue dst_row = lb_emit_struct_ep(p, dst_col, i);
+					lbValue src_col = lb_emit_struct_ev(p, src, i);
+					lbValue src_row = lb_emit_struct_ev(p, src_col, j);
+					lb_emit_store(p, dst_row, src_row);
+				}
+			}
+			return lb_addr_load(p, addr);
+		}
 		// no-op
 		m.type = type;
 		return m;

+ 19 - 0
src/types.cpp

@@ -363,6 +363,7 @@ enum TypeInfoFlag : u32 {
 enum : int {
 	MATRIX_ELEMENT_COUNT_MIN = 1,
 	MATRIX_ELEMENT_COUNT_MAX = 16,
+	MATRIX_ELEMENT_MAX_SIZE = MATRIX_ELEMENT_COUNT_MAX * (2 * 8), // complex128
 };
 
 
@@ -1583,6 +1584,24 @@ Type *core_array_type(Type *t) {
 	}
 }
 
+i32 type_math_rank(Type *t) {
+	i32 rank = 0;
+	for (;;) {
+		t = base_type(t);
+		switch (t->kind) {
+		case Type_Array:
+			rank += 1;
+			t = t->Array.elem;
+			break;
+		case Type_Matrix:
+			rank += 2;
+			t = t->Matrix.elem;
+			break;
+		default:
+			return rank;
+		}
+	}
+}
 
 
 Type *base_complex_elem_type(Type *t) {