Browse Source

Add `matrix_flatten` - `matrix[R, C]T` -> `[R*C]T`

gingerBill 3 years ago
parent
commit
d3abc1a2b4
4 changed files with 106 additions and 9 deletions
  1. 30 0
      src/check_builtin.cpp
  2. 2 0
      src/checker_builtin_procs.hpp
  3. 68 9
      src/llvm_backend_expr.cpp
  4. 6 0
      src/llvm_backend_proc.cpp

+ 30 - 0
src/check_builtin.cpp

@@ -2131,6 +2131,36 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 		break;
 		break;
 	}
 	}
 	
 	
+	case BuiltinProc_matrix_flatten: {
+		Operand x = {};
+		check_expr(c, &x, ce->args[0]);
+		if (x.mode == Addressing_Invalid) {
+			return false;
+		}
+		if (!is_operand_value(x)) {
+			error(call, "'%.*s' expects a matrix or array", LIT(builtin_name));
+			return false;
+		}
+		Type *t = base_type(x.type);
+		if (!is_type_matrix(t) && !is_type_array(t)) {
+			gbString s = type_to_string(x.type);
+			error(call, "'%.*s' expects a matrix or array, got %s", LIT(builtin_name), s);
+			gb_string_free(s);
+			return false;
+		}
+		
+		operand->mode = Addressing_Value;
+		if (is_type_array(t)) {
+			// Do nothing
+			operand->type = x.type;			
+		} else {
+			GB_ASSERT(t->kind == Type_Matrix);
+			operand->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count*t->Matrix.column_count);
+		}
+		operand->type = check_matrix_type_hint(operand->type, type_hint);
+		break;
+	}
+	
 
 
 	case BuiltinProc_simd_vector: {
 	case BuiltinProc_simd_vector: {
 		Operand x = {};
 		Operand x = {};

+ 2 - 0
src/checker_builtin_procs.hpp

@@ -38,6 +38,7 @@ enum BuiltinProcId {
 	BuiltinProc_transpose,
 	BuiltinProc_transpose,
 	BuiltinProc_outer_product,
 	BuiltinProc_outer_product,
 	BuiltinProc_hadamard_product,
 	BuiltinProc_hadamard_product,
+	BuiltinProc_matrix_flatten,
 
 
 	BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures
 	BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures
 
 
@@ -282,6 +283,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 	{STR_LIT("transpose"),        1, false, Expr_Expr, BuiltinProcPkg_builtin},
 	{STR_LIT("transpose"),        1, false, Expr_Expr, BuiltinProcPkg_builtin},
 	{STR_LIT("outer_product"),    2, false, Expr_Expr, BuiltinProcPkg_builtin},
 	{STR_LIT("outer_product"),    2, false, Expr_Expr, BuiltinProcPkg_builtin},
 	{STR_LIT("hadamard_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin},
 	{STR_LIT("hadamard_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin},
+	{STR_LIT("matrix_flatten"),   1, false, Expr_Expr, BuiltinProcPkg_builtin},
 
 
 	{STR_LIT(""),                 0, true,  Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE
 	{STR_LIT(""),                 0, true,  Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE
 
 

+ 68 - 9
src/llvm_backend_expr.cpp

@@ -517,6 +517,33 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
 	return matrix_vector;
 	return matrix_vector;
 }
 }
 
 
+LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
+	Type *mt = base_type(m.type);
+	GB_ASSERT(mt->kind == Type_Matrix);
+	
+	unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
+	unsigned row_count = cast(unsigned)mt->Matrix.row_count;
+	unsigned column_count = cast(unsigned)mt->Matrix.column_count;
+	
+	auto columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
+	
+	LLVMValueRef vector = lb_matrix_to_vector(p, m);
+	
+	unsigned mask_elems_index = 0;
+	auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count*column_count);
+	for (unsigned j = 0; j < column_count; j++) {
+		for (unsigned i = 0; i < row_count; i++) {
+			unsigned offset = stride*j + i;
+			mask_elems[mask_elems_index++] = lb_const_int(p->module, t_u32, offset).value;
+		}
+	}
+	
+	LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
+	LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
+	return trimmed_vector;
+}
+
+
 lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
 lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
 	if (is_type_array(m.type)) {
 	if (is_type_array(m.type)) {
 		// no-op
 		// no-op
@@ -573,6 +600,46 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
 	return lb_addr_load(p, res);
 	return lb_addr_load(p, res);
 }
 }
 
 
+lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) {
+	lbAddr res = lb_add_local_generated(p, type, true);
+	LLVMValueRef res_ptr = res.addr.value;
+	unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
+	LLVMSetAlignment(res_ptr, alignment);
+	
+	res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
+	LLVMBuildStore(p->builder, vector, res_ptr);
+	
+	return lb_addr_load(p, res);
+}
+
+lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) {
+	if (is_type_array(m.type)) {
+		// no-op
+		m.type = type;
+		return m;
+	}
+	Type *mt = base_type(m.type);
+	GB_ASSERT(mt->kind == Type_Matrix);
+	
+	if (lb_matrix_elem_simple(mt)) {
+		LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m);
+		return lb_matrix_cast_vector_to_type(p, vector, type);
+	}
+	
+	lbAddr res = lb_add_local_generated(p, type, true);
+	
+	i64 row_count = mt->Matrix.row_count;
+	i64 column_count = mt->Matrix.column_count;
+	for (i64 j = 0; j < column_count; j++) {
+		for (i64 i = 0; i < row_count; i++) {
+			lbValue src = lb_emit_matrix_ev(p, m, i, j);
+			lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+			lb_emit_store(p, dst, src);
+		}
+	}
+	return lb_addr_load(p, res);
+}
+
 
 
 lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
 lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
 	Type *mt = base_type(type);
 	Type *mt = base_type(type);
@@ -737,16 +804,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 				vector = llvm_vector_add(p, vector, product);
 				vector = llvm_vector_add(p, vector, product);
 			}
 			}
 		}
 		}
-
-		lbAddr res = lb_add_local_generated(p, type, true);
-		LLVMValueRef res_ptr = res.addr.value;
-		unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
-		LLVMSetAlignment(res_ptr, alignment);
-		
-		res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
-		LLVMBuildStore(p->builder, vector, res_ptr);
 		
 		
-		return lb_addr_load(p, res);
+		return lb_matrix_cast_vector_to_type(p, vector, type);
 	}
 	}
 	
 	
 	lbAddr res = lb_add_local_generated(p, type, true);
 	lbAddr res = lb_add_local_generated(p, type, true);

+ 6 - 0
src/llvm_backend_proc.cpp

@@ -1280,6 +1280,12 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
 			GB_ASSERT(is_type_matrix(tv.type));
 			GB_ASSERT(is_type_matrix(tv.type));
 			return lb_emit_arith_matrix(p, Token_Mul, a, b, tv.type, true);
 			return lb_emit_arith_matrix(p, Token_Mul, a, b, tv.type, true);
 		}
 		}
+		
+	case BuiltinProc_matrix_flatten:
+		{
+			lbValue m = lb_build_expr(p, ce->args[0]);
+			return lb_emit_matrix_flatten(p, m, tv.type);
+		}
 
 
 
 
 	// "Intrinsics"
 	// "Intrinsics"