Jelajahi Sumber

Add builtin `outer_product`

gingerBill 4 tahun lalu
induk
melakukan
68afbb37f4

+ 60 - 0
src/check_builtin.cpp

@@ -2017,6 +2017,66 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 		operand->type = check_matrix_type_hint(operand->type, type_hint);
 		break;
 	}
+	
+	case BuiltinProc_outer_product: {
+		Operand x = {};
+		Operand y = {};
+		check_expr(c, &x, ce->args[0]);
+		if (x.mode == Addressing_Invalid) {
+			return false;
+		}
+		check_expr(c, &y, ce->args[1]);
+		if (y.mode == Addressing_Invalid) {
+			return false;
+		}
+		if (!is_operand_value(x) || !is_operand_value(y)) {
+			error(call, "'%.*s' expects only arrays", LIT(builtin_name));
+			return false;
+		}
+		
+		if (!is_type_array(x.type) && !is_type_array(y.type)) {
+			gbString s1 = type_to_string(x.type);
+			gbString s2 = type_to_string(y.type);
+			error(call, "'%.*s' expects only arrays, got %s and %s", LIT(builtin_name), s1, s2);
+			gb_string_free(s2);
+			gb_string_free(s1);
+			return false;
+		}
+		
+		Type *xt = base_type(x.type);
+		Type *yt = base_type(y.type);
+		GB_ASSERT(xt->kind == Type_Array);
+		GB_ASSERT(yt->kind == Type_Array);
+		if (!are_types_identical(xt->Array.elem, yt->Array.elem)) {
+			gbString s1 = type_to_string(xt->Array.elem);
+			gbString s2 = type_to_string(yt->Array.elem);
+			error(call, "'%.*s' mismatched element types, got %s vs %s", LIT(builtin_name), s1, s2);
+			gb_string_free(s2);
+			gb_string_free(s1);
+			return false;
+		}
+		
+		if (xt->Array.count == 0 || yt->Array.count == 0) {
+			gbString s1 = type_to_string(x.type);
+			gbString s2 = type_to_string(y.type);
+			error(call, "'%.*s' expects only arrays of non-zero length, got %s and %s", LIT(builtin_name), s1, s2);
+			gb_string_free(s2);
+			gb_string_free(s1);
+			return false;
+		}
+		
+		i64 max_count = xt->Array.count*yt->Array.count;
+		if (max_count > MAX_MATRIX_ELEMENT_COUNT) {
+			error(call, "Product of the array lengths exceed the maximum matrix element count, got %d, expected a maximum of %d", cast(int)max_count, MAX_MATRIX_ELEMENT_COUNT);
+			return false;
+		}
+		
+		operand->mode = Addressing_Value;
+		operand->type = alloc_type_matrix(xt->Array.elem, xt->Array.count, yt->Array.count);	
+		operand->type = check_matrix_type_hint(operand->type, type_hint);
+		break;
+	}
+	
 
 	case BuiltinProc_simd_vector: {
 		Operand x = {};

+ 2 - 0
src/checker_builtin_procs.hpp

@@ -36,6 +36,7 @@ enum BuiltinProcId {
 	BuiltinProc_soa_unzip,
 	
 	BuiltinProc_transpose,
+	BuiltinProc_outer_product,
 
 	BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures
 
@@ -278,6 +279,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 	{STR_LIT("soa_unzip"),        1, false, Expr_Expr, BuiltinProcPkg_builtin},
 	
 	{STR_LIT("transpose"),        1, false, Expr_Expr, BuiltinProcPkg_builtin},
+	{STR_LIT("outer_product"),    2, false, Expr_Expr, BuiltinProcPkg_builtin},
 
 	{STR_LIT(""),                 0, true,  Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE
 

+ 32 - 0
src/llvm_backend_expr.cpp

@@ -522,9 +522,41 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
 		}
 	}
 	return lb_addr_load(p, res);
+}
+
+
+lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
+	Type *mt = base_type(type);
+	Type *at = base_type(a.type);
+	Type *bt = base_type(b.type);
+	GB_ASSERT(mt->kind == Type_Matrix);
+	GB_ASSERT(at->kind == Type_Array);
+	GB_ASSERT(bt->kind == Type_Array);
+	
+	
+	i64 row_count = mt->Matrix.row_count;
+	i64 column_count = mt->Matrix.column_count;
+	
+	GB_ASSERT(row_count == at->Array.count);
+	GB_ASSERT(column_count == bt->Array.count);
+	
+	
+	lbAddr res = lb_add_local_generated(p, type, true);
+	
+	for (i64 j = 0; j < column_count; j++) {
+		for (i64 i = 0; i < row_count; i++) {
+			lbValue x = lb_emit_struct_ev(p, a, cast(i32)i);
+			lbValue y = lb_emit_struct_ev(p, b, cast(i32)j);
+			lbValue src = lb_emit_arith(p, Token_Mul, x, y, mt->Matrix.elem);
+			lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+			lb_emit_store(p, dst, src);
+		}
+	}
+	return lb_addr_load(p, res);
 
 }
 
+
 lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
 	Type *xt = base_type(lhs.type);
 	Type *yt = base_type(rhs.type);

+ 8 - 0
src/llvm_backend_proc.cpp

@@ -1263,6 +1263,14 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
 			lbValue m = lb_build_expr(p, ce->args[0]);
 			return lb_emit_matrix_tranpose(p, m, tv.type);
 		}
+		
+	case BuiltinProc_outer_product:
+		{
+			lbValue a = lb_build_expr(p, ce->args[0]);
+			lbValue b = lb_build_expr(p, ce->args[1]);
+			return lb_emit_outer_product(p, a, b, tv.type);
+		}
+
 
 	// "Intrinsics"