Browse Source

Add `intrinsics.fused_mul_add`

gingerBill 3 years ago
parent
commit
421d45a7a7

+ 2 - 0
core/intrinsics/intrinsics.odin

@@ -35,6 +35,8 @@ overflow_mul :: proc(lhs, rhs: $T) -> (T, bool) #optional_ok ---
 
 sqrt :: proc(x: $T) -> T where type_is_float(T) ---
 
+fused_mul_add :: proc(a, b, c: $T) -> T where type_is_float(T) || (type_is_simd_vector(T) && type_is_float(type_elem_type(T))) ---
+
 mem_copy                 :: proc(dst, src: rawptr, len: int) ---
 mem_copy_non_overlapping :: proc(dst, src: rawptr, len: int) ---
 mem_zero                 :: proc(ptr: rawptr, len: int) ---

+ 3 - 0
core/simd/simd.odin

@@ -109,6 +109,9 @@ count_zeros          :: intrinsics.count_zeros
 count_trailing_zeros :: intrinsics.count_trailing_zeros
 count_leading_zeros  :: intrinsics.count_leading_zeros
 
+fused_mul_add :: intrinsics.fused_mul_add
+fma           :: intrinsics.fused_mul_add
+
 to_array_ptr :: #force_inline proc "contextless" (v: ^#simd[$LANES]$E) -> ^[LANES]E {
 	return (^[LANES]E)(v)
 }

+ 53 - 0
src/check_builtin.cpp

@@ -3681,6 +3681,59 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
 		}
 		break;
 
+	case BuiltinProc_fused_mul_add:
+		{
+			Operand x = {};
+			Operand y = {};
+			Operand z = {};
+			check_expr(c, &x, ce->args[0]); if (x.mode == Addressing_Invalid) return false;
+			check_expr(c, &y, ce->args[1]); if (y.mode == Addressing_Invalid) return false;
+			check_expr(c, &z, ce->args[2]); if (z.mode == Addressing_Invalid) return false;
+
+			convert_to_typed(c, &y, x.type); if (y.mode == Addressing_Invalid) return false;
+			convert_to_typed(c, &x, y.type); if (x.mode == Addressing_Invalid) return false;
+			convert_to_typed(c, &z, x.type); if (z.mode == Addressing_Invalid) return false;
+			convert_to_typed(c, &x, z.type); if (x.mode == Addressing_Invalid) return false;
+			if (is_type_untyped(x.type)) {
+				gbString xts = type_to_string(x.type);
+				error(x.expr, "Expected a typed floating point value or #simd vector for '%.*s', got %s", LIT(builtin_name), xts);
+				gb_string_free(xts);
+				return false;
+			}
+
+			Type *elem = core_array_type(x.type);
+			if (!is_type_float(x.type) && !(is_type_simd_vector(x.type) && is_type_float(elem))) {
+				gbString xts = type_to_string(x.type);
+				error(x.expr, "Expected a floating point or #simd vector value for '%.*s', got %s", LIT(builtin_name), xts);
+				gb_string_free(xts);
+				return false;
+			}
+			if (is_type_different_to_arch_endianness(elem)) {
+				GB_ASSERT(elem->kind == Type_Basic);
+				if (elem->Basic.flags & (BasicFlag_EndianLittle|BasicFlag_EndianBig)) {
+					gbString xts = type_to_string(x.type);
+					error(x.expr, "Expected a float which does not specify the explicit endianness for '%.*s', got %s", LIT(builtin_name), xts);
+					gb_string_free(xts);
+					return false;
+				}
+			}
+
+			if (!are_types_identical(x.type, y.type) || !are_types_identical(y.type, z.type)) {
+				gbString xts = type_to_string(x.type);
+				gbString yts = type_to_string(y.type);
+				gbString zts = type_to_string(z.type);
+				error(x.expr, "Mismatched types for '%.*s', got %s vs %s vs %s", LIT(builtin_name), xts, yts, zts);
+				gb_string_free(zts);
+				gb_string_free(yts);
+				gb_string_free(xts);
+				return false;
+			}
+
+			operand->mode = Addressing_Value;
+			operand->type = default_type(x.type);
+		}
+		break;
+
 	case BuiltinProc_mem_copy:
 	case BuiltinProc_mem_copy_non_overlapping:
 		{

+ 2 - 0
src/checker_builtin_procs.hpp

@@ -65,6 +65,7 @@ enum BuiltinProcId {
 	BuiltinProc_overflow_mul,
 
 	BuiltinProc_sqrt,
+	BuiltinProc_fused_mul_add,
 
 	BuiltinProc_mem_copy,
 	BuiltinProc_mem_copy_non_overlapping,
@@ -348,6 +349,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
 	{STR_LIT("overflow_mul"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 
 	{STR_LIT("sqrt"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics},
+	{STR_LIT("fused_mul_add"), 3, false, Expr_Expr, BuiltinProcPkg_intrinsics},
 
 	{STR_LIT("mem_copy"),                 3, false, Expr_Stmt, BuiltinProcPkg_intrinsics},
 	{STR_LIT("mem_copy_non_overlapping"), 3, false, Expr_Stmt, BuiltinProcPkg_intrinsics},

+ 25 - 0
src/llvm_backend_proc.cpp

@@ -2005,6 +2005,31 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
 			return res;
 		}
 
+	case BuiltinProc_fused_mul_add:
+		{
+			Type *type = tv.type;
+			lbValue x = lb_emit_conv(p, lb_build_expr(p, ce->args[0]), type);
+			lbValue y = lb_emit_conv(p, lb_build_expr(p, ce->args[1]), type);
+			lbValue z = lb_emit_conv(p, lb_build_expr(p, ce->args[2]), type);
+
+
+			char const *name = "llvm.fma";
+			LLVMTypeRef types[1] = {lb_type(p->module, type)};
+			unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+			GB_ASSERT_MSG(id != 0, "Unable to find %s.%s", name, LLVMPrintTypeToString(types[0]));
+			LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
+
+			LLVMValueRef args[3] = {};
+			args[0] = x.value;
+			args[1] = y.value;
+			args[2] = z.value;
+
+			lbValue res = {};
+			res.value = LLVMBuildCall(p->builder, ip, args, gb_count_of(args), "");
+			res.type = type;
+			return res;
+		}
+
 	case BuiltinProc_mem_copy:
 		{
 			lbValue dst = lb_build_expr(p, ce->args[0]);