Browse Source

Add builtin `transpose`

gingerBill 3 years ago
parent
commit
7faca7066c
5 changed files with 80 additions and 133 deletions
  1. 32 2
      src/check_builtin.cpp
  2. 21 15
      src/check_expr.cpp
  3. 4 0
      src/checker_builtin_procs.hpp
  4. 17 116
      src/llvm_backend_expr.cpp
  5. 6 0
      src/llvm_backend_proc.cpp

+ 32 - 2
src/check_builtin.cpp

@@ -1966,13 +1966,13 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 			return false;
 		}
 		if (!is_operand_value(x)) {
-			error(call, "'soa_unzip' expects an #soa slice");
+			error(call, "'%.*s' expects an #soa slice", LIT(builtin_name));
 			return false;
 		}
 		Type *t = base_type(x.type);
 		if (!is_type_soa_struct(t) || t->Struct.soa_kind != StructSoa_Slice) {
 			gbString s = type_to_string(x.type);
-			error(call, "'soa_unzip' expects an #soa slice, got %s", s);
+			error(call, "'%.*s' expects an #soa slice, got %s", LIT(builtin_name), s);
 			gb_string_free(s);
 			return false;
 		}
@@ -1987,6 +1987,36 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 		operand->mode = Addressing_Value;
 		break;
 	}
+	
+	case BuiltinProc_transpose: {
+		Operand x = {};
+		check_expr(c, &x, ce->args[0]);
+		if (x.mode == Addressing_Invalid) {
+			return false;
+		}
+		if (!is_operand_value(x)) {
+			error(call, "'%.*s' expects a matrix or array", LIT(builtin_name));
+			return false;
+		}
+		Type *t = base_type(x.type);
+		if (!is_type_matrix(t) && !is_type_array(t)) {
+			gbString s = type_to_string(x.type);
+			error(call, "'%.*s' expects a matrix or array, got %s", LIT(builtin_name), s);
+			gb_string_free(s);
+			return false;
+		}
+		
+		operand->mode = Addressing_Value;
+		if (is_type_array(t)) {
+			// Do nothing
+			operand->type = x.type;			
+		} else {
+			GB_ASSERT(t->kind == Type_Matrix);
+			operand->type = alloc_type_matrix(t->Matrix.elem, t->Matrix.column_count, t->Matrix.row_count);
+		}
+		operand->type = check_matrix_type_hint(operand->type, type_hint);
+		break;
+	}
 
 	case BuiltinProc_simd_vector: {
 		Operand x = {};

+ 21 - 15
src/check_expr.cpp

@@ -2708,6 +2708,25 @@ bool can_use_other_type_as_type_hint(bool use_lhs_as_type_hint, Type *other_type
 	return false;
 }
 
+Type *check_matrix_type_hint(Type *matrix, Type *type_hint) {
+	Type *xt = base_type(matrix);
+	if (type_hint != nullptr) {
+		Type *th = base_type(type_hint);
+		if (are_types_identical(th, xt)) {
+			return type_hint;
+		} else if (xt->kind == Type_Matrix && th->kind == Type_Array) {
+			if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) {
+				// ignore
+			} else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) {
+				return type_hint;
+			} else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) {
+				return type_hint;
+			}
+		}
+	}
+	return matrix;
+}
+
 
 void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand *y, Type *type_hint, bool use_lhs_as_type_hint) {
 	if (!check_binary_op(c, x, op)) {
@@ -2791,21 +2810,8 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand
 	}
 
 matrix_success:
-	if (type_hint != nullptr) {
-		Type *th = base_type(type_hint);
-		if (are_types_identical(th, x->type)) {
-			x->type = type_hint;
-		} else if (x->type->kind == Type_Matrix && th->kind == Type_Array) {
-			Type *xt = x->type;
-			if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) {
-				// ignore
-			} else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) {
-				x->type = type_hint;
-			} else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) {
-				x->type = type_hint;
-			}
-		}
-	}
+	x->type = check_matrix_type_hint(x->type, type_hint);
+	
 	return;
 	
 	

+ 4 - 0
src/checker_builtin_procs.hpp

@@ -34,6 +34,8 @@ enum BuiltinProcId {
 
 	BuiltinProc_soa_zip,
 	BuiltinProc_soa_unzip,
+	
+	BuiltinProc_transpose,
 
 	BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures
 
@@ -274,6 +276,8 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 
 	{STR_LIT("soa_zip"),          1, true,  Expr_Expr, BuiltinProcPkg_builtin},
 	{STR_LIT("soa_unzip"),        1, false, Expr_Expr, BuiltinProcPkg_builtin},
+	
+	{STR_LIT("transpose"),        1, false, Expr_Expr, BuiltinProcPkg_builtin},
 
 	{STR_LIT(""),                 0, true,  Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE
 

+ 17 - 116
src/llvm_backend_expr.cpp

@@ -502,116 +502,29 @@ bool lb_matrix_elem_simple(Type *t) {
 	return true;
 }
 
-LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) {
-	lbModule *m = p->module;
-	
-	Type *mt = base_type(lhs.type);
-	GB_ASSERT(mt->kind == Type_Matrix);
-	GB_ASSERT(lb_matrix_elem_simple(mt));
-	
-	
-	i64 stride = matrix_type_stride_in_elems(mt);
-	i64 rows = mt->Matrix.row_count;
-	i64 columns = mt->Matrix.column_count;
-	unsigned elem_count = cast(unsigned)(rows*columns);
-	
-	Type *elem = mt->Matrix.elem;
-	LLVMTypeRef elem_type = lb_type(m, elem);
-	
-	LLVMTypeRef vector_type = LLVMVectorType(elem_type, elem_count);
-	LLVMTypeRef types[] = {vector_type};
-	
-	char const *name = "llvm.matrix.column.major.load";
-	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
-	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
-	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-	
-	lbValue ptr = lb_address_from_load_or_generate_local(p, lhs);
-	ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
-		
-	LLVMValueRef values[5] = {};
-	values[0] = ptr.value;
-	values[1] = lb_const_int(m, t_u64, stride).value; 
-	values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
-	values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
-	values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
-	
-	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
-	gb_printf_err("%s\n", LLVMPrintValueToString(call));
-	// LLVMAddAttributeAtIndex(call, 0, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
-	return call;
-}
-
-void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) {
-	lbModule *m = p->module;
-	
-	Type *mt = base_type(lb_addr_type(addr));
+lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
+	if (is_type_array(m.type)) {
+		m.type = type;
+		return m;
+	}
+	Type *mt = base_type(m.type);
 	GB_ASSERT(mt->kind == Type_Matrix);
-	GB_ASSERT(lb_matrix_elem_simple(mt));
-	
-	LLVMTypeRef vector_type = LLVMTypeOf(vector_value);
-	LLVMTypeRef types[] = {vector_type};
-	
-	char const *name = "llvm.matrix.column.major.store";
-	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
-	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
-	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-	
-	lbValue ptr = lb_addr_get_ptr(p, addr);
-	ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
-	
-	unsigned vector_size = LLVMGetVectorSize(vector_type);
-	GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
 	
-	i64 stride = matrix_type_stride_in_elems(mt);
-	
-	LLVMValueRef values[6] = {};
-	values[0] = vector_value;
-	values[1] = ptr.value;
-	values[2] = lb_const_int(m, t_u64, stride).value;
-	values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
-	values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
-	values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
+	lbAddr res = lb_add_local_generated(p, type, true);
 	
-	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
-	gb_printf_err("%s\n", LLVMPrintValueToString(call));
-	// LLVMAddAttributeAtIndex(call, 1, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
-	gb_unused(call);
-}
-
+	i64 row_count = mt->Matrix.row_count;
+	i64 column_count = mt->Matrix.column_count;
+	for (i64 j = 0; j < column_count; j++) {
+		for (i64 i = 0; i < row_count; i++) {
+			lbValue src = lb_emit_matrix_ev(p, m, i, j);
+			lbValue dst = lb_emit_matrix_epi(p, res.addr, j, i);
+			lb_emit_store(p, dst, src);
+		}
+	}
+	return lb_addr_load(p, res);
 
-LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) {
-	lbModule *m = p->module;
-	
-	LLVMTypeRef a_type = LLVMTypeOf(a);
-	LLVMTypeRef b_type = LLVMTypeOf(b);
-	
-	GB_ASSERT(LLVMGetElementType(a_type) == LLVMGetElementType(b_type));
-	
-	LLVMTypeRef elem_type = LLVMGetElementType(a_type);
-	
-	LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns));
-	
-	LLVMTypeRef types[] = {res_vector_type, a_type, b_type};
-	
-	char const *name = "llvm.matrix.multiply";
-	unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
-	GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
-	LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-
-	LLVMValueRef values[5] = {};
-	values[0] = a;
-	values[1] = b;
-	values[2] = lb_const_int(m, t_u32, outer_rows).value;
-	values[3] = lb_const_int(m, t_u32, inner).value;
-	values[4] = lb_const_int(m, t_u32, outer_columns).value;
-	
-	LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
-	gb_printf_err("%s\n", LLVMPrintValueToString(call));
-	return call;
 }
 
-
 lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
 	Type *xt = base_type(lhs.type);
 	Type *yt = base_type(rhs.type);
@@ -626,18 +539,6 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
 		goto slow_form;
 	}
 	
-	if (false) {
-		// TODO(bill): LLVM ERROR: Do not know how to split the result of this operator!
-		lbAddr res = lb_add_local_generated(p, type, true);
-		
-		LLVMValueRef a = llvm_matrix_column_major_load(p, lhs); gb_unused(a);
-		LLVMValueRef b = llvm_matrix_column_major_load(p, rhs); gb_unused(b);
-		LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count); gb_unused(c);
-		llvm_matrix_column_major_store(p, res, c);
-		
-		return lb_addr_load(p, res);
-	} 
-		
 slow_form:
 	{
 		Type *elem = xt->Matrix.elem;	

+ 6 - 0
src/llvm_backend_proc.cpp

@@ -1257,6 +1257,12 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
 		return lb_soa_zip(p, ce, tv);
 	case BuiltinProc_soa_unzip:
 		return lb_soa_unzip(p, ce, tv);
+		
+	case BuiltinProc_transpose:
+		{
+			lbValue m = lb_build_expr(p, ce->args[0]);
+			return lb_emit_matrix_tranpose(p, m, tv.type);
+		}
 
 	// "Intrinsics"