Bläddra i källkod

Add `intrinsics.simd_reduce_*`

gingerBill 3 år sedan
förälder
incheckning
10e4de3c01
4 ändrade filer med 167 tillägg och 9 borttagningar
  1. 8 0
      core/simd/simd.odin
  2. 50 0
      src/check_builtin.cpp
  3. 16 0
      src/checker_builtin_procs.hpp
  4. 93 9
      src/llvm_backend_proc.cpp

+ 8 - 0
core/simd/simd.odin

@@ -34,6 +34,14 @@ ge      :: intrinsics.simd_ge
 extract :: intrinsics.simd_extract
 replace :: intrinsics.simd_replace
 
+reduce_add_ordered :: intrinsics.simd_reduce_add_ordered
+reduce_mul_ordered :: intrinsics.simd_reduce_mul_ordered
+reduce_min         :: intrinsics.simd_reduce_min
+reduce_max         :: intrinsics.simd_reduce_max
+reduce_and         :: intrinsics.simd_reduce_and
+reduce_or          :: intrinsics.simd_reduce_or
+reduce_xor         :: intrinsics.simd_reduce_xor
+
 splat :: #force_inline proc "contextless" ($T: typeid/#simd[$LANES]$E, value: E) -> T { // Returns a T whose lanes 0..<LANES all hold `value`.
 	return T{0..<LANES = value} // Ranged compound literal: every index in 0..<LANES is assigned `value`.
 }

+ 50 - 0
src/check_builtin.cpp

@@ -695,6 +695,56 @@ bool check_builtin_simd_operation(CheckerContext *c, Operand *operand, Ast *call
 		}
 		break;
 
+	case BuiltinProc_simd_reduce_add_ordered: // add/mul/min/max reductions share one check: a single simd operand with integer or float elements.
+	case BuiltinProc_simd_reduce_mul_ordered:
+	case BuiltinProc_simd_reduce_min:
+	case BuiltinProc_simd_reduce_max:
+		{
+			Operand x = {};
+			check_expr(c, &x, ce->args[0]); if (x.mode == Addressing_Invalid) { return false; } // Stop here if checking the only argument left it invalid.
+
+			if (!is_type_simd_vector(x.type)) { // Reject anything that is not a simd vector.
+				error(x.expr, "'%.*s' expected a simd vector type", LIT(builtin_name));
+				return false;
+			}
+			Type *elem = base_array_type(x.type);
+			if (!is_type_integer(elem) && !is_type_float(elem)) { // These four reductions accept integer or floating-point elements only.
+				gbString xs = type_to_string(x.type);
+				error(x.expr, "'%.*s' expected a #simd type with an integer or floating-point element, got '%s'", LIT(builtin_name), xs);
+				gb_string_free(xs); // Release the temporary type string built for the message.
+				return false;
+			}
+
+			operand->mode = Addressing_Value;
+			operand->type = base_array_type(x.type); // Result has the same type as `elem` validated above, not the vector type.
+			return true;
+		}
+
+	case BuiltinProc_simd_reduce_and: // and/or/xor reductions share one check: a single simd operand with integer elements.
+	case BuiltinProc_simd_reduce_or:
+	case BuiltinProc_simd_reduce_xor:
+		{
+			Operand x = {};
+			check_expr(c, &x, ce->args[0]); if (x.mode == Addressing_Invalid) { return false; } // Stop here if checking the only argument left it invalid.
+
+			if (!is_type_simd_vector(x.type)) { // Reject anything that is not a simd vector.
+				error(x.expr, "'%.*s' expected a simd vector type", LIT(builtin_name));
+				return false;
+			}
+			Type *elem = base_array_type(x.type);
+			if (!is_type_integer(elem)) { // Unlike the add/mul/min/max group, floating-point elements are rejected here.
+				gbString xs = type_to_string(x.type);
+				error(x.expr, "'%.*s' expected a #simd type with an integer element, got '%s'", LIT(builtin_name), xs);
+				gb_string_free(xs); // Release the temporary type string built for the message.
+				return false;
+			}
+
+			operand->mode = Addressing_Value;
+			operand->type = base_array_type(x.type); // Result has the same type as `elem` validated above, not the vector type.
+			return true;
+		}
+
+
 	// case BuiltinProc_simd_rotate_left:
 	// 	{
 	// 		Operand x = {};

+ 16 - 0
src/checker_builtin_procs.hpp

@@ -149,6 +149,14 @@ BuiltinProc__simd_begin,
 
 	BuiltinProc_simd_extract,
 	BuiltinProc_simd_replace,
+
+	BuiltinProc_simd_reduce_add_ordered,
+	BuiltinProc_simd_reduce_mul_ordered,
+	BuiltinProc_simd_reduce_min,
+	BuiltinProc_simd_reduce_max,
+	BuiltinProc_simd_reduce_and,
+	BuiltinProc_simd_reduce_or,
+	BuiltinProc_simd_reduce_xor,
 BuiltinProc__simd_end,
 	
 	// Platform specific intrinsics
@@ -401,6 +409,14 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 
 	{STR_LIT("simd_extract"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 	{STR_LIT("simd_replace"), 3, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+
+	{STR_LIT("simd_reduce_add_ordered"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_reduce_mul_ordered"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_reduce_min"),         1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_reduce_max"),         1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_reduce_and"),         1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_reduce_or"),          1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("simd_reduce_xor"),         1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 	{STR_LIT(""), 0, false, Expr_Stmt, BuiltinProcPkg_intrinsics},
 
 

+ 93 - 9
src/llvm_backend_proc.cpp

@@ -981,7 +981,7 @@ lbValue lb_emit_call(lbProcedure *p, lbValue value, Array<lbValue> const &args,
 	return result;
 }
 
-lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, BuiltinProcId id) {
+lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, BuiltinProcId builtin_id) {
 	ast_node(ce, CallExpr, expr);
 
 	lbModule *m = p->module;
@@ -1000,7 +1000,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const
 
 	LLVMOpcode op_code = cast(LLVMOpcode)0;
 
-	switch (id) {
+	switch (builtin_id) {
 	case BuiltinProc_simd_add:
 	case BuiltinProc_simd_sub:
 	case BuiltinProc_simd_mul:
@@ -1008,14 +1008,14 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const
 	case BuiltinProc_simd_rem:
 		arg1 = lb_build_expr(p, ce->args[1]);
 		if (is_float) {
-			switch (id) {
+			switch (builtin_id) {
 			case BuiltinProc_simd_add: op_code = LLVMFAdd; break;
 			case BuiltinProc_simd_sub: op_code = LLVMFSub; break;
 			case BuiltinProc_simd_mul: op_code = LLVMFMul; break;
 			case BuiltinProc_simd_div: op_code = LLVMFDiv; break;
 			}
 		} else {
-			switch (id) {
+			switch (builtin_id) {
 			case BuiltinProc_simd_add: op_code = LLVMAdd; break;
 			case BuiltinProc_simd_sub: op_code = LLVMSub; break;
 			case BuiltinProc_simd_mul: op_code = LLVMMul; break;
@@ -1053,7 +1053,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const
 			Type *elem1 = base_array_type(arg1.type);
 
 			bool is_masked = false;
-			switch (id) {
+			switch (builtin_id) {
 			case BuiltinProc_simd_shl:        op_code = LLVMShl;                         is_masked = false; break;
 			case BuiltinProc_simd_shr:        op_code = is_signed ? LLVMAShr : LLVMLShr; is_masked = false; break;
 			case BuiltinProc_simd_shl_masked: op_code = LLVMShl;                         is_masked = true;  break;
@@ -1086,7 +1086,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const
 	case BuiltinProc_simd_or:
 	case BuiltinProc_simd_xor:
 		arg1 = lb_build_expr(p, ce->args[1]);
-		switch (id) {
+		switch (builtin_id) {
 		case BuiltinProc_simd_and: op_code = LLVMAnd; break;
 		case BuiltinProc_simd_or:  op_code = LLVMOr;  break;
 		case BuiltinProc_simd_xor: op_code = LLVMXor; break;
@@ -1144,7 +1144,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const
 		arg1 = lb_build_expr(p, ce->args[1]);
 		if (is_float) {
 			LLVMRealPredicate pred = cast(LLVMRealPredicate)0;
-			switch (id) {
+			switch (builtin_id) {
 			case BuiltinProc_simd_eq: pred = LLVMRealOEQ; break;
 			case BuiltinProc_simd_ne: pred = LLVMRealONE; break;
 			case BuiltinProc_simd_lt: pred = LLVMRealOLT; break;
@@ -1159,7 +1159,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const
 			}
 		} else {
 			LLVMIntPredicate pred = cast(LLVMIntPredicate)0;
-			switch (id) {
+			switch (builtin_id) {
 			case BuiltinProc_simd_eq: pred = LLVMIntEQ; break;
 			case BuiltinProc_simd_ne: pred = LLVMIntNE; break;
 			case BuiltinProc_simd_lt: pred = is_signed ? LLVMIntSLT :LLVMIntULT; break;
@@ -1184,8 +1184,92 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const
 		arg2 = lb_build_expr(p, ce->args[2]);
 		res.value = LLVMBuildInsertElement(p->builder, arg0.value, arg2.value, arg1.value, "");
 		return res;
+
+	case BuiltinProc_simd_reduce_add_ordered: // Ordered add/mul reductions are emitted as calls to llvm.vector.reduce.* intrinsics.
+	case BuiltinProc_simd_reduce_mul_ordered:
+		{
+			LLVMTypeRef llvm_elem = lb_type(m, elem);
+			LLVMValueRef args[2] = {}; // Holds at most two call arguments: an optional start value, then the vector.
+			isize args_count = 0;
+
+			char const *name = nullptr;
+			switch (builtin_id) {
+			case BuiltinProc_simd_reduce_add_ordered:
+				if (is_float) {
+					name = "llvm.vector.reduce.fadd";
+					args[args_count++] = LLVMConstReal(llvm_elem, 0.0); // Float form passes a start value first. NOTE(review): LLVM documents -0.0 as the neutral fadd start; 0.0 would turn an all -0.0 input into +0.0 - confirm intent.
+				} else {
+					name = "llvm.vector.reduce.add"; // Integer form passes no start value.
+				}
+				break;
+			case BuiltinProc_simd_reduce_mul_ordered:
+				if (is_float) {
+					name = "llvm.vector.reduce.fmul";
+					args[args_count++] = LLVMConstReal(llvm_elem, 1.0); // Float form passes 1.0 as the start value.
+				} else {
+					name = "llvm.vector.reduce.mul"; // Integer form passes no start value.
+				}
+				break;
+			}
+			args[args_count++] = arg0.value; // The vector operand is always the last argument.
+
+
+			LLVMTypeRef types[1] = {lb_type(p->module, arg0.type)}; // The declaration is requested with arg0's LLVM type as its single type parameter.
+			unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+			GB_ASSERT_MSG(id != 0, "Unable to find %s.%s", name, LLVMPrintTypeToString(types[0])); // An id of 0 is treated as "intrinsic not found".
+			LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
+
+			lbValue res = {};
+			res.value = LLVMBuildCall(p->builder, ip, args, cast(unsigned)args_count, ""); // Pass only the slots actually filled (1 or 2).
+			res.type = tv.type; // Result type comes from the call expression's TypeAndValue.
+			return res;
+		}
+	case BuiltinProc_simd_reduce_min: // min/max/and/or/xor reductions: single-argument llvm.vector.reduce.* intrinsic calls.
+	case BuiltinProc_simd_reduce_max:
+	case BuiltinProc_simd_reduce_and:
+	case BuiltinProc_simd_reduce_or:
+	case BuiltinProc_simd_reduce_xor:
+		{
+			char const *name = nullptr;
+			switch (builtin_id) {
+			case BuiltinProc_simd_reduce_min: // Pick the float, signed or unsigned variant from is_float / is_signed.
+				if (is_float) {
+					name = "llvm.vector.reduce.fmin";
+				} else if (is_signed) {
+					name = "llvm.vector.reduce.smin";
+				} else {
+					name = "llvm.vector.reduce.umin";
+				}
+				break;
+			case BuiltinProc_simd_reduce_max: // Same three-way selection as min.
+				if (is_float) {
+					name = "llvm.vector.reduce.fmax";
+				} else if (is_signed) {
+					name = "llvm.vector.reduce.smax";
+				} else {
+					name = "llvm.vector.reduce.umax";
+				}
+				break;
+			case BuiltinProc_simd_reduce_and: name = "llvm.vector.reduce.and"; break; // Bitwise reductions use one name regardless of is_signed.
+			case BuiltinProc_simd_reduce_or:  name = "llvm.vector.reduce.or";  break;
+			case BuiltinProc_simd_reduce_xor: name = "llvm.vector.reduce.xor"; break;
+			}
+			LLVMTypeRef types[1] = {lb_type(p->module, arg0.type)}; // The declaration is requested with arg0's LLVM type as its single type parameter.
+			unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+			GB_ASSERT_MSG(id != 0, "Unable to find %s.%s", name, LLVMPrintTypeToString(types[0])); // An id of 0 is treated as "intrinsic not found".
+			LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
+
+			LLVMValueRef args[1] = {};
+			args[0] = arg0.value; // The vector is the only call argument; no start value is passed here.
+
+			lbValue res = {};
+			res.value = LLVMBuildCall(p->builder, ip, args, gb_count_of(args), "");
+			res.type = tv.type; // Result type comes from the call expression's TypeAndValue.
+			return res;
+		}
 	}
-	GB_PANIC("Unhandled simd intrinsic: '%.*s'", LIT(builtin_procs[id].name));
+	GB_PANIC("Unhandled simd intrinsic: '%.*s'", LIT(builtin_procs[builtin_id].name));
+
 	return {};
 }