瀏覽代碼

Add `#row_major matrix[R, C]T`

As well as `#column_major matrix[R, C]T` as an alias for just `matrix[R, C]T`.
This is because some libraries require a row-major internal layout but still want to be usable with either row- or column-oriented vectors.
gingerBill 1 年之前
父節點
當前提交
a750fc0ba6

+ 4 - 0
base/runtime/core.odin

@@ -177,6 +177,10 @@ Type_Info_Matrix :: struct {
 	row_count:    int,
 	column_count: int,
 	// Total element count = column_count * elem_stride
+	layout: enum u8 {
+		Column_Major, // array of column vectors
+		Row_Major,    // array of row vectors
+	},
 }
 Type_Info_Soa_Pointer :: struct {
 	elem: ^Type_Info,

+ 10 - 2
core/fmt/fmt.odin

@@ -2396,7 +2396,11 @@ fmt_matrix :: proc(fi: ^Info, v: any, verb: rune, info: runtime.Type_Info_Matrix
 			for col in 0..<info.column_count {
 				if col > 0 { io.write_string(fi.writer, ", ", &fi.n) }
 
-				offset := (row + col*info.elem_stride)*info.elem_size
+				offset: int
+				switch info.layout {
+				case .Column_Major: offset = (row + col*info.elem_stride)*info.elem_size
+				case .Row_Major:    offset = (col + row*info.elem_stride)*info.elem_size
+				}
 
 				data := uintptr(v.data) + uintptr(offset)
 				fmt_arg(fi, any{rawptr(data), info.elem.id}, verb)
@@ -2410,7 +2414,11 @@ fmt_matrix :: proc(fi: ^Info, v: any, verb: rune, info: runtime.Type_Info_Matrix
 			for col in 0..<info.column_count {
 				if col > 0 { io.write_string(fi.writer, ", ", &fi.n) }
 
-				offset := (row + col*info.elem_stride)*info.elem_size
+				offset: int
+				switch info.layout {
+				case .Column_Major: offset = (row + col*info.elem_stride)*info.elem_size
+				case .Row_Major:    offset = (col + row*info.elem_stride)*info.elem_size
+				}
 
 				data := uintptr(v.data) + uintptr(offset)
 				fmt_arg(fi, any{rawptr(data), info.elem.id}, verb)

+ 4 - 0
core/reflect/types.odin

@@ -173,6 +173,7 @@ are_types_identical :: proc(a, b: ^Type_Info) -> bool {
 		y := b.variant.(Type_Info_Matrix) or_return
 		if x.row_count != y.row_count { return false }
 		if x.column_count != y.column_count { return false }
+		if x.layout != y.layout { return false }
 		return are_types_identical(x.elem, y.elem)
 
 	case Type_Info_Bit_Field:
@@ -689,6 +690,9 @@ write_type_writer :: proc(w: io.Writer, ti: ^Type_Info, n_written: ^int = nil) -
 		write_type(w, info.pointer,      &n) or_return
 		
 	case Type_Info_Matrix:
+		if info.layout == .Row_Major {
+			io.write_string(w, "#row_major ",   &n) or_return
+		}
 		io.write_string(w, "matrix[",               &n) or_return
 		io.write_i64(w, i64(info.row_count), 10,    &n) or_return
 		io.write_string(w, ", ",                    &n) or_return

+ 9 - 3
src/check_expr.cpp

@@ -3431,6 +3431,11 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand
 				if (xt->Matrix.column_count != yt->Matrix.row_count) {
 					goto matrix_error;
 				}
+
+				if (xt->Matrix.is_row_major != yt->Matrix.is_row_major) {
+					goto matrix_error;
+				}
+
 				x->mode = Addressing_Value;
 				if (are_types_identical(xt, yt)) {
 					if (!is_type_named(x->type) && is_type_named(y->type)) {
@@ -3438,7 +3443,8 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand
 						x->type = y->type;
 					}
 				} else {
-					x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count);
+					bool is_row_major = xt->Matrix.is_row_major && yt->Matrix.is_row_major;
+					x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count, nullptr, nullptr, is_row_major);
 				}
 				goto matrix_success;
 			} else if (yt->kind == Type_Array) {
@@ -3452,7 +3458,7 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand
 				
 				// Treat arrays as column vectors
 				x->mode = Addressing_Value;
-				if (type_hint == nullptr && xt->Matrix.row_count == yt->Array.count) {
+				if (xt->Matrix.row_count == yt->Array.count) {
 					x->type = y->type;
 				} else {
 					x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1);
@@ -3483,7 +3489,7 @@ gb_internal void check_binary_matrix(CheckerContext *c, Token const &op, Operand
 				
 				// Treat arrays as row vectors
 				x->mode = Addressing_Value;
-				if (type_hint == nullptr && yt->Matrix.column_count == xt->Array.count) {
+				if (yt->Matrix.column_count == xt->Array.count) {
 					x->type = x->type;
 				} else {
 					x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count);

+ 1 - 1
src/check_type.cpp

@@ -2658,7 +2658,7 @@ gb_internal void check_matrix_type(CheckerContext *ctx, Type **type, Ast *node)
 	}
 type_assign:;
 	
-	*type = alloc_type_matrix(elem, row_count, column_count, generic_row, generic_column);
+	*type = alloc_type_matrix(elem, row_count, column_count, generic_row, generic_column, mt->is_row_major);
 	
 	return;
 }

+ 2 - 2
src/llvm_backend_const.cpp

@@ -1302,11 +1302,11 @@ gb_internal lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bo
 				GB_ASSERT_MSG(elem_count == max_count, "%td != %td", elem_count, max_count);
 
 				LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count);
-				
 				for_array(i, cl->elems) {
 					TypeAndValue tav = cl->elems[i]->tav;
 					GB_ASSERT(tav.mode != Addressing_Invalid);
-					i64 offset = matrix_row_major_index_to_offset(type, i);
+					i64 offset = 0;
+					offset = matrix_row_major_index_to_offset(type, i);
 					values[offset] = lb_const_value(m, elem_type, tav.value, allow_local).value;
 				}
 				for (isize i = 0; i < total_count; i++) {

+ 29 - 10
src/llvm_backend_expr.cpp

@@ -684,12 +684,6 @@ gb_internal lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type
 	Type *mt = base_type(m.type);
 	GB_ASSERT(mt->kind == Type_Matrix);
 
-	// TODO(bill): Determine why this fails on Windows sometimes
-	if (false && lb_is_matrix_simdable(mt)) {
-		LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m);
-		return lb_matrix_cast_vector_to_type(p, vector, type);
-	}
-
 	lbAddr res = lb_add_local_generated(p, type, true);
 
 	i64 row_count = mt->Matrix.row_count;
@@ -763,6 +757,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
 	GB_ASSERT(is_type_matrix(yt));
 	GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count);
 	GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem));
+	GB_ASSERT(xt->Matrix.is_row_major == yt->Matrix.is_row_major);
 
 	Type *elem = xt->Matrix.elem;
 
@@ -770,7 +765,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
 	unsigned inner         = cast(unsigned)xt->Matrix.column_count;
 	unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
 
-	if (lb_is_matrix_simdable(xt)) {
+	if (!xt->Matrix.is_row_major && lb_is_matrix_simdable(xt)) {
 		unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
 		unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
 
@@ -812,7 +807,7 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
 		return lb_addr_load(p, res);
 	}
 
-	{
+	if (!xt->Matrix.is_row_major) {
 		lbAddr res = lb_add_local_generated(p, type, true);
 
 		auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);
@@ -835,6 +830,30 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
 			}
 		}
 
+		return lb_addr_load(p, res);
+	} else {
+		lbAddr res = lb_add_local_generated(p, type, true);
+
+		auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner);
+
+		for (unsigned i = 0; i < outer_rows; i++) {
+			for (unsigned j = 0; j < outer_columns; j++) {
+				lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+				for (unsigned k = 0; k < inner; k++) {
+					inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k);
+					inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j);
+				}
+
+				lbValue sum = lb_const_nil(p->module, elem);
+				for (unsigned k = 0; k < inner; k++) {
+					lbValue a = inners[k][0];
+					lbValue b = inners[k][1];
+					sum = lb_emit_mul_add(p, a, b, sum, elem);
+				}
+				lb_emit_store(p, dst, sum);
+			}
+		}
+
 		return lb_addr_load(p, res);
 	}
 }
@@ -855,7 +874,7 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
 
 	Type *elem = mt->Matrix.elem;
 
-	if (lb_is_matrix_simdable(mt)) {
+	if (!mt->Matrix.is_row_major && lb_is_matrix_simdable(mt)) {
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count;
@@ -924,7 +943,7 @@ gb_internal lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbVal
 
 	Type *elem = mt->Matrix.elem;
 
-	if (lb_is_matrix_simdable(mt)) {
+	if (!mt->Matrix.is_row_major && lb_is_matrix_simdable(mt)) {
 		unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
 
 		unsigned row_count = cast(unsigned)mt->Matrix.row_count;

+ 2 - 1
src/llvm_backend_type.cpp

@@ -979,12 +979,13 @@ gb_internal void lb_setup_type_info_data_giant_array(lbModule *m, i64 global_typ
 				tag_type = t_type_info_matrix;
 				i64 ez = type_size_of(t->Matrix.elem);
 
-				LLVMValueRef vals[5] = {
+				LLVMValueRef vals[6] = {
 					get_type_info_ptr(m, t->Matrix.elem),
 					lb_const_int(m, t_int, ez).value,
 					lb_const_int(m, t_int, matrix_type_stride_in_elems(t)).value,
 					lb_const_int(m, t_int, t->Matrix.row_count).value,
 					lb_const_int(m, t_int, t->Matrix.column_count).value,
+					lb_const_int(m, t_u8,  cast(u8)t->Matrix.is_row_major).value,
 				};
 
 				variant_value = llvm_const_named_struct(m, tag_type, vals, gb_count_of(vals));

+ 15 - 7
src/llvm_backend_utility.cpp

@@ -1464,14 +1464,16 @@ gb_internal lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isi
 	Type *t = s.type;
 	GB_ASSERT(is_type_pointer(t));
 	Type *mt = base_type(type_deref(t));
-	if (column == 0) {
-		GB_ASSERT_MSG(is_type_matrix(mt) || is_type_array_like(mt), "%s", type_to_string(mt));
-		return lb_emit_epi(p, s, row);
-	} else if (row == 0 && is_type_array_like(mt)) {
-		return lb_emit_epi(p, s, column);
+
+	if (!mt->Matrix.is_row_major) {
+		if (column == 0) {
+			GB_ASSERT_MSG(is_type_matrix(mt) || is_type_array_like(mt), "%s", type_to_string(mt));
+			return lb_emit_epi(p, s, row);
+		} else if (row == 0 && is_type_array_like(mt)) {
+			return lb_emit_epi(p, s, column);
+		}
 	}
 	
-	
 	GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt));
 	
 	isize offset = matrix_indices_to_offset(mt, row, column);
@@ -1491,7 +1493,13 @@ gb_internal lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lb
 	row = lb_emit_conv(p, row, t_int);
 	column = lb_emit_conv(p, column, t_int);
 	
-	LLVMValueRef index = LLVMBuildAdd(p->builder, row.value, LLVMBuildMul(p->builder, column.value, stride_elems, ""), "");
+	LLVMValueRef index = nullptr;
+
+	if (mt->Matrix.is_row_major) {
+		index = LLVMBuildAdd(p->builder, column.value, LLVMBuildMul(p->builder, row.value, stride_elems, ""), "");
+	} else {
+		index = LLVMBuildAdd(p->builder, row.value, LLVMBuildMul(p->builder, column.value, stride_elems, ""), "");
+	}
 
 	LLVMValueRef indices[2] = {
 		LLVMConstInt(lb_type(p->module, t_int), 0, false),

+ 13 - 0
src/parser.cpp

@@ -2329,6 +2329,19 @@ gb_internal Ast *parse_operand(AstFile *f, bool lhs) {
 				break;
 			}
 			return original_type;
+		} else if (name.string == "row_major" ||
+		           name.string == "column_major") {
+			Ast *original_type = parse_type(f);
+			Ast *type = unparen_expr(original_type);
+			switch (type->kind) {
+			case Ast_MatrixType:
+				type->MatrixType.is_row_major = (name.string == "row_major");
+				break;
+			default:
+				syntax_error(type, "Expected a matrix type after #%.*s, got %.*s", LIT(name.string), LIT(ast_strings[type->kind]));
+				break;
+			}
+			return original_type;
 		} else if (name.string == "partial") {
 			Ast *tag = ast_basic_directive(f, token, name);
 			Ast *original_expr = parse_expr(f, lhs);

+ 1 - 0
src/parser.hpp

@@ -772,6 +772,7 @@ AST_KIND(_TypeBegin, "", bool) \
 		Ast *row_count;    \
 		Ast *column_count; \
 		Ast *elem;         \
+		bool is_row_major; \
 	}) \
 AST_KIND(_TypeEnd,  "", bool)
 

+ 15 - 4
src/types.cpp

@@ -281,6 +281,7 @@ struct TypeProc {
 		Type *generic_row_count;                          \
 		Type *generic_column_count;                       \
 		i64   stride_in_bytes;                            \
+		bool  is_row_major;                               \
 	})                                                        \
 	TYPE_KIND(BitField, struct {                              \
 		Scope *         scope;                            \
@@ -1002,7 +1003,7 @@ gb_internal Type *alloc_type_array(Type *elem, i64 count, Type *generic_count =
 	return t;
 }
 
-gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, Type *generic_row_count = nullptr, Type *generic_column_count = nullptr) {
+gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count, Type *generic_row_count = nullptr, Type *generic_column_count = nullptr, bool is_row_major = false) {
 	if (generic_row_count != nullptr || generic_column_count != nullptr) {
 		Type *t = alloc_type(Type_Matrix);
 		t->Matrix.elem                 = elem;
@@ -1010,12 +1011,14 @@ gb_internal Type *alloc_type_matrix(Type *elem, i64 row_count, i64 column_count,
 		t->Matrix.column_count         = column_count;
 		t->Matrix.generic_row_count    = generic_row_count;
 		t->Matrix.generic_column_count = generic_column_count;
+		t->Matrix.is_row_major         = is_row_major;
 		return t;
 	}
 	Type *t = alloc_type(Type_Matrix);
 	t->Matrix.elem = elem;
 	t->Matrix.row_count = row_count;
 	t->Matrix.column_count = column_count;
+	t->Matrix.is_row_major = is_row_major;
 	return t;
 }
 
@@ -1512,14 +1515,18 @@ gb_internal i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_inde
 	GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count);
 	GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count);
 	i64 stride_elems = matrix_type_stride_in_elems(t);
-	// NOTE(bill): Column-major layout internally
-	return row_index + stride_elems*column_index;
+	if (t->Matrix.is_row_major) {
+		return column_index + stride_elems*row_index;
+	} else {
+		// NOTE(bill): Column-major layout internally
+		return row_index + stride_elems*column_index;
+	}
 }
 
 gb_internal i64 matrix_row_major_index_to_offset(Type *t, i64 index) {
 	t = base_type(t);
 	GB_ASSERT(t->kind == Type_Matrix);
-	
+
 	i64 row_index    = index/t->Matrix.column_count;
 	i64 column_index = index%t->Matrix.column_count;
 	return matrix_indices_to_offset(t, row_index, column_index);
@@ -2690,6 +2697,7 @@ gb_internal bool are_types_identical_internal(Type *x, Type *y, bool check_tuple
 	case Type_Matrix:
 		return x->Matrix.row_count == y->Matrix.row_count &&
 		       x->Matrix.column_count == y->Matrix.column_count &&
+		       x->Matrix.is_row_major == y->Matrix.is_row_major &&
 		       are_types_identical(x->Matrix.elem, y->Matrix.elem);
 
 	case Type_DynamicArray:
@@ -4735,6 +4743,9 @@ gb_internal gbString write_type_to_string(gbString str, Type *type, bool shortha
 		break;
 		
 	case Type_Matrix:
+		if (type->Matrix.is_row_major) {
+			str = gb_string_appendc(str, "#row_major ");
+		}
 		str = gb_string_appendc(str, gb_bprintf("matrix[%d, %d]", cast(int)type->Matrix.row_count, cast(int)type->Matrix.column_count));
 		str = write_type_to_string(str, type->Matrix.elem);
 		break;