Browse Source

Add support for expanding trig intrinsics (#325)

We can now expand the following intrinsics:

    Acos
    Asin
    Atan
    Hcos
    Hsin
    Htan

The expansion uses the same approximation algorithms used by the d3d compiler.
David Peixotto 8 years ago
parent
commit
489147d88c

+ 2 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -43,6 +43,7 @@ ModulePass *createDxilGenerationPass(bool NotOptimized, hlsl::HLSLExtensionsCode
 ModulePass *createHLEmitMetadataPass();
 ModulePass *createHLEnsureMetadataPass();
 ModulePass *createDxilEmitMetadataPass();
+FunctionPass *createDxilExpandTrigIntrinsicsPass();
 ModulePass *createDxilLoadMetadataPass();
 ModulePass *createDxilPrecisePropagatePass();
 FunctionPass *createDxilLegalizeResourceUsePass();
@@ -57,6 +58,7 @@ void initializeDxilGenerationPassPass(llvm::PassRegistry&);
 void initializeHLEnsureMetadataPass(llvm::PassRegistry&);
 void initializeHLEmitMetadataPass(llvm::PassRegistry&);
 void initializeDxilEmitMetadataPass(llvm::PassRegistry&);
+void initializeDxilExpandTrigIntrinsicsPass(llvm::PassRegistry&);
 void initializeDxilLoadMetadataPass(llvm::PassRegistry&);
 void initializeDxilPrecisePropagatePassPass(llvm::PassRegistry&);
 void initializeDxilLegalizeResourceUsePassPass(llvm::PassRegistry&);

+ 1 - 0
lib/HLSL/CMakeLists.txt

@@ -10,6 +10,7 @@ add_llvm_library(LLVMHLSL
   DxilContainerAssembler.cpp
   DxilContainerReflection.cpp
   DxilEliminateOutputDynamicIndexing.cpp
+  DxilExpandTrigIntrinsics.cpp
   DxilGenerationPass.cpp
   DxilInterpolationMode.cpp
   DxilLegalizeSampleOffsetPass.cpp

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -85,6 +85,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDxilCondenseResourcesPass(Registry);
     initializeDxilEliminateOutputDynamicIndexingPass(Registry);
     initializeDxilEmitMetadataPass(Registry);
+    initializeDxilExpandTrigIntrinsicsPass(Registry);
     initializeDxilGenerationPassPass(Registry);
     initializeDxilLegalizeEvalOperationsPass(Registry);
     initializeDxilLegalizeResourceUsePassPass(Registry);

+ 519 - 0
lib/HLSL/DxilExpandTrigIntrinsics.cpp

@@ -0,0 +1,519 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// DxilExpandTrigIntrinsics.cpp                                              //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Expand trigonmetric intrinsics to a sequence of dxil instructions.        //
+// ========================================================================= //
+//
+// We provide expansions to approximate several trigonmetric functions that
+// typically do not have native instructions in hardware. The details of each
+// expansion is given below, but typically the exansion occurs in three steps
+// 
+//     1. Perform range reduction (if necessary) to reduce input range
+//        to a value that works with the approximation.
+//     2. Compute an approximation to the function (typically by evaluating 
+//        a polynomial).
+//     3. Perform range expansion (if necessary) to map the result back to
+//        the original range.
+// 
+// For example, say we are expanding f(x) using an approximation to f, call it
+// f*(x). And assume that f* only works for positive inputs, but we know that
+// f(-x) = -f(x).Then the expansion would be
+// 
+//     1. a = abs(x)
+//     2. v = f*(a)
+//     3. e = x < 0 ? -v : v
+// 
+// where e contains the final expanded result.
+// 
+// References
+// ---------------------------------------------------------------------------
+// [HMF] Handbook of Mathematical Formulas by Abramowitz and Stegun, 1964
+// [ADC] Approximations for Digital Computers by Hastings, 1955
+// [WIK] Wikipedia, 2017
+// 
+// The approximation functions mostly come from [ADC]. The approximations
+// are also referenced in [HMF], but they give original credit to [ADC].
+// 
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/DxilGenerationPass.h"
+#include "dxc/HLSL/DxilOperations.h"
+#include "dxc/HLSL/DxilSignatureElement.h"
+#include "dxc/HLSL/DxilModule.h"
+#include "dxc/Support/Global.h"
+#include "dxc/HLSL/DxilInstructions.h"
+
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/ADT/MapVector.h"
+
+#include <cmath>
+#include <utility>
+
+using namespace llvm;
+using namespace hlsl;
+
+namespace {
+class DxilExpandTrigIntrinsics : public FunctionPass {
+private:
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilExpandTrigIntrinsics() : FunctionPass(ID) {}
+
+  const char *getPassName() const override {
+    return "DXIL expand trig intrinsics";
+  }
+  
+  bool runOnFunction(Function &F) override;
+  
+
+private:
+  typedef std::vector<CallInst *> IntrinsicList;
+  IntrinsicList findTrigFunctionsToExpand(Function &F);
+  CallInst *isExpandableTrigIntrinsicCall(Instruction *I);
+  bool expandTrigIntrinsics(DxilModule &DM, const IntrinsicList &worklist);
+  FastMathFlags getFastMathFlagsForIntrinsic(CallInst *intrinsic);
+  void prepareBuilderToExpandIntrinsic(IRBuilder<> &builder, CallInst *intrinsic);
+
+  // Expansion implementations.
+  Value *expandACos(IRBuilder<> &builder, DxilInst_Acos acos, DxilModule &DM);
+  Value *expandASin(IRBuilder<> &builder, DxilInst_Asin asin, DxilModule &DM);
+  Value *expandATan(IRBuilder<> &builder, DxilInst_Atan atan, DxilModule &DM);
+  Value *expandHCos(IRBuilder<> &builder, DxilInst_Hcos hcos, DxilModule &DM);
+  Value *expandHSin(IRBuilder<> &builder, DxilInst_Hsin hsin, DxilModule &DM);
+  Value *expandHTan(IRBuilder<> &builder, DxilInst_Htan htan, DxilModule &DM);
+};
+
+// Math constants.
+// Values taken from https://msdn.microsoft.com/en-us/library/4hwaceh6.aspx.
+// Replicated here because they are not part of standard C++.
+namespace math {
+  constexpr double PI    = 3.14159265358979323846;
+  constexpr double PI_2  = 1.57079632679489661923;
+  constexpr double LOG2E = 1.44269504088896340736;
+}
+
+}
+
+
+bool DxilExpandTrigIntrinsics::runOnFunction(Function &F) {
+  DxilModule &DM = F.getParent()->GetOrCreateDxilModule(); 
+  IntrinsicList intrinsics = findTrigFunctionsToExpand(F);
+  const bool changed = expandTrigIntrinsics(DM, intrinsics);
+  return changed;
+}
+
+CallInst *DxilExpandTrigIntrinsics::isExpandableTrigIntrinsicCall(Instruction *I) {
+    if (OP::IsDxilOpFuncCallInst(I)) {
+      switch (OP::GetDxilOpFuncCallInst(I)) {
+      case OP::OpCode::Acos:
+      case OP::OpCode::Asin:
+      case OP::OpCode::Atan:
+      case OP::OpCode::Hcos:
+      case OP::OpCode::Hsin:
+      case OP::OpCode::Htan:
+        return cast<CallInst>(I);
+      default: break;
+      }
+    }
+    return nullptr;
+}
+
+DxilExpandTrigIntrinsics::IntrinsicList DxilExpandTrigIntrinsics::findTrigFunctionsToExpand(Function &F) {
+  IntrinsicList worklist;
+  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
+    if (CallInst *call = isExpandableTrigIntrinsicCall(&*I))
+      worklist.push_back(call);
+
+  return worklist;
+}
+
+static bool isPreciseBuilder(IRBuilder<> &builder) {
+  return !builder.getFastMathFlags().any();
+}
+
+static void setPreciseBuilder(IRBuilder<> &builder, bool precise) {
+  FastMathFlags flags;
+  if (precise)
+    flags.clear();
+  else
+    flags.setUnsafeAlgebra();
+  builder.SetFastMathFlags(flags);
+}
+
+void DxilExpandTrigIntrinsics::prepareBuilderToExpandIntrinsic(IRBuilder<> &builder, CallInst *intrinsic) {
+  DxilModule &DM = intrinsic->getModule()->GetOrCreateDxilModule();
+  builder.SetInsertPoint(intrinsic);
+  setPreciseBuilder(builder, DM.IsPrecise(intrinsic));
+}
+  
+bool DxilExpandTrigIntrinsics::expandTrigIntrinsics(DxilModule &DM, const IntrinsicList &worklist) {
+  IRBuilder<> builder(DM.GetCtx());
+  for (CallInst *intrinsic: worklist) {
+    Value *expansion = nullptr;
+    prepareBuilderToExpandIntrinsic(builder, intrinsic);
+    
+    OP::OpCode opcode = OP::GetDxilOpFuncCallInst(intrinsic);
+    switch (opcode) {
+    case OP::OpCode::Acos: expansion = expandACos(builder, intrinsic, DM); break;
+    case OP::OpCode::Asin: expansion = expandASin(builder, intrinsic, DM); break;
+    case OP::OpCode::Atan: expansion = expandATan(builder, intrinsic, DM); break;
+    case OP::OpCode::Hcos: expansion = expandHCos(builder, intrinsic, DM); break;
+    case OP::OpCode::Hsin: expansion = expandHSin(builder, intrinsic, DM); break;
+    case OP::OpCode::Htan: expansion = expandHTan(builder, intrinsic, DM); break;
+    default:
+      assert(false && "unexpected intrinsic");
+      break;
+    }
+
+    assert(expansion);
+    intrinsic->replaceAllUsesWith(expansion);
+    intrinsic->eraseFromParent();
+  }
+
+  return !worklist.empty();
+}
+
+// Helper
+// return dx.op.UnaryFloat(X)
+//
+static Value *emitUnaryFloat(IRBuilder<> &builder, Value *X, OP *dxOp, OP::OpCode opcode, StringRef name) {
+  Function *F = dxOp->GetOpFunc(opcode, X->getType());
+  Value *Args[] = { dxOp->GetI32Const(static_cast<int>(opcode)), X };
+  CallInst *Call = builder.CreateCall(F, Args, name);
+
+  if (isPreciseBuilder(builder))
+    DxilMDHelper::MarkPrecise(Call);
+  return Call;
+}
+
+// Helper
+// return dx.op.Fabs(X)
+//
+static Value *emitFAbs(IRBuilder<> &builder, Value *X, OP *dxOp, StringRef name) {
+  return emitUnaryFloat(builder, X, dxOp, OP::OpCode::FAbs, name);
+}
+
+// Helper
+// return dx.op.Sqrt(X)
+//
+static Value *emitSqrt(IRBuilder<> &builder, Value *X, OP *dxOp, StringRef name) {
+  return emitUnaryFloat(builder, X, dxOp, OP::OpCode::Sqrt, name);
+}
+
+// Helper
+// return sqrt(1 - X) * psi*(X)
+//
+// We compute the polynomial using Horners method to evaluate it efficently.
+//
+// psi*(X) = a0 + a1x + a2x^2 + a3x^3
+//         = a0 + x(a1 + a2x + a3x^2)
+//         = a0 + x(a1 + x(a2 + a3x))
+//
+static Value *emitSqrt1mXtimesPsiX(IRBuilder<> &builder, Value *X, OP *dxOp, StringRef name) {
+  Value *One = ConstantFP::get(X->getType(), 1.0);
+  Value *a0 = ConstantFP::get(X->getType(),  1.5707288);
+  Value *a1 = ConstantFP::get(X->getType(), -0.2121144);
+  Value *a2 = ConstantFP::get(X->getType(),  0.0742610);
+  Value *a3 = ConstantFP::get(X->getType(), -0.0187293);
+
+
+  // sqrt(1-x)
+  Value *r1 = builder.CreateFSub(One, X, name);
+  Value *r2 = emitSqrt(builder, r1, dxOp, name);
+
+  // psi*(x)
+  Value *r3 = builder.CreateFMul(X,  a3, name);
+         r3 = builder.CreateFAdd(r3, a2, name);
+         r3 = builder.CreateFMul(X,  r3, name);
+         r3 = builder.CreateFAdd(r3, a1, name);
+         r3 = builder.CreateFMul(X,  r3, name);
+         r3 = builder.CreateFAdd(r3, a0, name);
+
+  // sqrt(1-x) * psi*(x)
+  Value *r4 = builder.CreateFMul(r2, r3,  name);
+  return r4;
+}
+
+// Helper
+// return e^x, e^-x
+//
+// We can use the dxil Exp function to compute the exponential. The only slight
+// wrinkle is that in dxil Exp(x) = 2^x and we need e^x. Luckily we can easily
+// change the base of the exponent using the following identity [HFM(p69)]
+//
+//  e^x = 2^{x * log_2(e)}
+//
+static std::pair<Value *, Value *> emitExEmx(IRBuilder<> &builder, Value *X, OP *dxOp, StringRef name) {
+  Value *Zero  = ConstantFP::get(X->getType(), 0.0);
+  Value *Log2e = ConstantFP::get(X->getType(), math::LOG2E);
+
+  Value *r0 = builder.CreateFMul(X, Log2e, name);
+  Value *r1 = emitUnaryFloat(builder, r0, dxOp, OP::OpCode::Exp, name);
+  Value *r2 = builder.CreateFSub(Zero, r0, name);
+  Value *r3 = emitUnaryFloat(builder, r2, dxOp, OP::OpCode::Exp, name);
+
+  return std::make_pair(r1, r3);
+}
+
+// Asin
+// ----------------------------------------------------------------------------
+// Function
+//    arcsin X = pi/2  - sqrt(1 - X) * psi(X)
+//
+// Range
+//    0 <= X <= 1
+//
+// Approximation
+//    Psi*(X) = a0 + a1x + a2x^2 + a3x^3
+//      a0 =  1.5707288
+//      a1 = -0.2121144
+//      a2 =  0.0742610
+//      a3 = -0.0187293
+// 
+// The domain of the approximation is 0 <=x <= 1, but the domain of asin is
+// -1 <= x <= 1. So we need to perform a range reduction to [0,1] before
+// computing the approximation. 
+// 
+// We use the following identity from [HMF(p80),WIK] for range reduction
+// 
+// 	asin(-x) = -asin(x)
+// 
+// We take the absolute value of x, compute asin(x) using the approximation
+// and then negate the value if x < 0.
+//
+// In [HMF] the authors claim an error, e, of |e| <= 5e-5, but the error graph
+// in [ADC] looks like the error can be larger that that for some inputs.
+// 
+Value *DxilExpandTrigIntrinsics::expandASin(IRBuilder<> &builder, DxilInst_Asin asin, DxilModule &DM) {
+  assert(asin);
+  StringRef name = "asin.x";
+  Value *X = asin.get_value();
+  Value *PI_2 = ConstantFP::get(X->getType(), math::PI_2);
+  Value *Zero = ConstantFP::get(X->getType(), 0.0);
+  
+  // Range reduction to [0, 1]
+  Value *absX = emitFAbs(builder, X, DM.GetOP(), name);
+
+  // Approximation
+  Value *psiX = emitSqrt1mXtimesPsiX(builder, absX, DM.GetOP(), name);
+  Value *asinX = builder.CreateFSub(PI_2, psiX, name);
+  Value *asinmX = builder.CreateFSub(Zero, asinX, name);
+
+  // Range expansion to [-1, 1]
+  Value *lt0 = builder.CreateFCmp(CmpInst::FCMP_ULT, X, Zero, name);
+  Value *r = builder.CreateSelect(lt0, asinmX, asinX, name);
+
+  return r;
+}
+
+
+// Acos
+// ----------------------------------------------------------------------------
+// The acos expansion uses the following identity [WIK]. So that we can use the
+// same approximation psi*(x) that we use for asin.
+// 
+// 	acos(x) = pi/2 - asin(x)
+// 
+// Substituting the equation for asin(x) we get
+// 
+// 	acos(x) = pi/2 - asin(x)
+// 	        = pi/2 - (pi/2 - sqrt(1-x)*psi(x))
+// 	        = sqrt(1-x)*psi(x)
+// 
+// We use the following identity from [HMF(p80),WIK] for range reduction
+// 
+// 	acos(-x) = pi - acos(x)
+//               = pi - sqrt(1-x)*psi(x)
+//
+// We take the absolute value of x, compute acos(x) using the approximation
+// and then subtract from pi if x < 0.
+//
+Value *DxilExpandTrigIntrinsics::expandACos(IRBuilder<> &builder, DxilInst_Acos acos, DxilModule &DM) {
+  assert(acos);
+  StringRef name = "acos.x";
+  Value *X = acos.get_value();
+  Value *PI = ConstantFP::get(X->getType(), math::PI);
+  Value *Zero = ConstantFP::get(X->getType(), 0.0);
+  
+  // Range reduction to [0, 1]
+  Value *absX = emitFAbs(builder, X, DM.GetOP(), name);
+
+  // Approximation
+  Value *acosX = emitSqrt1mXtimesPsiX(builder, absX, DM.GetOP(), name);
+  Value *acosmX = builder.CreateFSub(PI, acosX, name);
+
+  // Range expansion to [-1, 1]
+  Value *lt0 = builder.CreateFCmp(CmpInst::FCMP_ULT, X, Zero, name);
+  Value *r = builder.CreateSelect(lt0, acosmX, acosX, name);
+
+  return r;
+}
+
+// Atan
+// ----------------------------------------------------------------------------
+// Function
+//    arctan X
+//
+// Range
+//    -1 <= X <= 1
+//
+// Approximation
+//    arctan*(x) = c1x + c3x^3 + c5x^5 + c7x^7 + c9x^9
+//      c1 =  0.9998660
+//      c3 = -0.3302995
+//      c5 =  0.1801410
+//      c7 = -0.0851330
+//      c9 =  0.0208351
+// 	
+// The polynomial is evaluated using Horner's method to efficiently compute the
+// value
+// 
+// 	  c1x + c3x^3 + c5x^5 + c7x^7 + c9x^9 
+// 	= x(c1 + c3x^2 + c5x^4 + c7x^6 + c9x^8)
+// 	= x(c1 + x^2(c3 + c5x^2 + c7x^4 + c9x^6))
+// 	= x(c1 + x^2(c3 + x^2(c5 + c7x^2 + c9x^4)))
+// 	= x(c1 + x^2(c3 + x^2(c5 + x^2(c7 + c9x^2))))
+// 	
+// The range reduction is a little more compilicated for atan because the
+// domain of atan is [-inf, inf], but the domain of the approximation is only
+// [-1, 1]. We use the following identities for range reduction from
+// [HMF(p80),WIK]
+// 	
+// 	arctan(-x) = -arctan(x)
+//      arctan(x)   = pi/2 - arctan(1/x) if x > 0
+// 
+// The first identity allows us to only work with positive numbers. The second
+// identity allows us to reduce the range to [0,1]. We first convert the value
+// to positive by taking abs(x). Then if x > 1 we compute arctan(1/x).
+// 
+// To expand the range we check if x > 1 then subtracted the computed value from
+// pi/2 and if x is negative then negate the final value.
+//
+Value *DxilExpandTrigIntrinsics::expandATan(IRBuilder<> &builder, DxilInst_Atan atan, DxilModule &DM) {
+  assert(atan);
+  StringRef name  = "atan.x";
+  Value *X = atan.get_value();
+  Value *PI_2 = ConstantFP::get(X->getType(), math::PI_2);
+  Value *One  = ConstantFP::get(X->getType(), 1.0);
+  Value *Zero = ConstantFP::get(X->getType(), 0.0);
+  Value *c1 = ConstantFP::get(X->getType(),  0.9998660);
+  Value *c3 = ConstantFP::get(X->getType(), -0.3302995);
+  Value *c5 = ConstantFP::get(X->getType(),  0.1801410);
+  Value *c7 = ConstantFP::get(X->getType(), -0.0851330);
+  Value *c9 = ConstantFP::get(X->getType(),  0.0208351);
+
+  // Range reduction to [0, inf]
+  Value *absX = emitFAbs(builder, X, DM.GetOP(), name);
+
+  // Range reduction to [0, 1]
+  Value *gt1 = builder.CreateFCmp(CmpInst::FCMP_UGT, absX, One, name);
+  Value *r1 = builder.CreateFDiv(One, absX, name);
+  Value *r2 = builder.CreateSelect(gt1, r1, absX, name);
+
+  // Approximate
+  Value *r3 = builder.CreateFMul(r2, r2, name);
+  Value *r4 = builder.CreateFMul(r3, c9, name);
+         r4 = builder.CreateFAdd(r4, c7, name);
+         r4 = builder.CreateFMul(r4, r3, name);
+         r4 = builder.CreateFAdd(r4, c5, name);
+         r4 = builder.CreateFMul(r4, r3, name);
+         r4 = builder.CreateFAdd(r4, c3, name);
+         r4 = builder.CreateFMul(r4, r3, name);
+         r4 = builder.CreateFAdd(r4, c1, name);
+         r4 = builder.CreateFMul(r2, r4, name);
+
+  // Range Expansion to [0, inf]
+  Value *r5 = builder.CreateFSub(PI_2, r4, name);
+  Value *r6 = builder.CreateSelect(gt1, r5, r4, name);
+
+  // Range Expansion to [-inf, inf]
+  Value *r7 = builder.CreateFSub(Zero, r6, name);
+  Value *lt0 = builder.CreateFCmp(CmpInst::FCMP_ULT, X, Zero, name);
+  Value *r = builder.CreateSelect(lt0, r7, r6, name);
+
+  return r;
+}
+
+// Hcos
+// ----------------------------------------------------------------------------
+// We use the following identity for computing hcos(x) from [HMF(p83)]
+// 	
+//    cosh(x) = (e^x + e^-x) / 2
+// 
+// No range reduction is needed.
+//
+Value *DxilExpandTrigIntrinsics::expandHCos(IRBuilder<> &builder, DxilInst_Hcos hcos, DxilModule &DM) {
+  assert(hcos);
+  StringRef name = "hcos.x";
+  Value *eX, *emX;
+  Value *X = hcos.get_value();
+  Value *Two = ConstantFP::get(X->getType(), 2.0);
+
+  std::tie(eX, emX) = emitExEmx(builder, X, DM.GetOP(), name);
+  Value *r4 = builder.CreateFAdd(eX, emX, name);
+  Value *r  = builder.CreateFDiv(r4, Two, name);
+
+  return r;
+}
+
+// Hsin
+// ----------------------------------------------------------------------------
+// We use the following identity for computing hsin(x) from[HMF(p83)]
+//
+//    sinh(x) = (e^x - e^-x) / 2
+//
+// No range reduction is needed.
+//
+Value *DxilExpandTrigIntrinsics::expandHSin(IRBuilder<> &builder, DxilInst_Hsin hsin, DxilModule &DM) {
+  assert(hsin);
+  StringRef name = "hsin.x";
+  Value *eX, *emX;
+  Value *X = hsin.get_value();
+  Value *Two = ConstantFP::get(X->getType(), 2.0);
+
+  std::tie(eX, emX) = emitExEmx(builder, X, DM.GetOP(), name);
+  Value *r4 = builder.CreateFSub(eX, emX, name);
+  Value *r  = builder.CreateFDiv(r4, Two, name);
+
+  return r;
+}
+
+// Htan
+// ----------------------------------------------------------------------------
+// We use the following identity for computing hsin(x) from[HMF(p83)]
+//
+//    tanh(x) = (e^x - e^-x) / (e^x + e^-x)
+//
+// No range reduction is needed.
+//
+Value *DxilExpandTrigIntrinsics::expandHTan(IRBuilder<> &builder, DxilInst_Htan htan, DxilModule &DM) {
+  assert(htan);
+  StringRef name = "htan.x";
+  Value *eX, *emX;
+  Value *X = htan.get_value();
+
+  std::tie(eX, emX) = emitExEmx(builder, X, DM.GetOP(), name);
+  Value *r4 = builder.CreateFSub(eX, emX, name);
+  Value *r5 = builder.CreateFAdd(eX, emX, name);
+  Value *r  = builder.CreateFDiv(r4, r5, name);
+
+  return r;
+}
+
+char DxilExpandTrigIntrinsics::ID = 0;
+
+FunctionPass *llvm::createDxilExpandTrigIntrinsicsPass() {
+  return new DxilExpandTrigIntrinsics();
+}
+
+INITIALIZE_PASS(DxilExpandTrigIntrinsics,
+                "hlsl-dxil-expand-trig-intrinsics",
+                "DXIL expand trig intrinsics", false, false)

+ 27 - 0
tools/clang/test/HLSL/expand_trig/acos.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// CHECK: [[X:%.*]]   = call float @dx.op.loadInput.f32(i32 4
+// CHECK: [[r0:%.*]]  = call float @dx.op.unary.f32(i32 6, float [[X]]
+
+// CHECK: [[r1:%.*]]  = fsub fast float 1.000000e+00, [[r0]]
+// CHECK: [[r2:%.*]]  = call float @dx.op.unary.f32(i32 24, float [[r1]]
+
+// CHECK: [[r3a:%.*]] = fmul fast float [[r0]], 0xBF932DC600000000
+// CHECK: [[r3b:%.*]] = fadd fast float [[r3a]], 0x3FB302C4E0000000
+// CHECK: [[r3c:%.*]] = fmul fast float [[r0]], [[r3b]]
+// CHECK: [[r3d:%.*]] = fadd fast float [[r3c]], 0xBFCB269080000000
+// CHECK: [[r3e:%.*]] = fmul fast float [[r0]], [[r3d]]
+// CHECK: [[r3f:%.*]] = fadd fast float [[r3e]], 0x3FF921B480000000
+// CHECK: [[r4:%.*]]  = fmul fast float [[r2]], [[r3f]]
+
+// CHECK: [[r5:%.*]]  = fsub fast float 0x400921FB60000000, [[r4]]
+
+// CHECK: [[b0:%.*]]  = fcmp fast ult float [[X]], 0.000000e+00
+// CHECK: select i1 [[b0]], float [[r5]], float [[r4]]
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 15
+
+[RootSignature("")]
+float main(float x : A) : SV_Target {
+    return acos(x);
+}

+ 12 - 0
tools/clang/test/HLSL/expand_trig/acos_h.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure the expansion works for half.
+// Only checking for for minimal expansion here, full check is done for float case.
+
+// CHECK: fmul fast half %{{.*}}, 0xHA4CB
+
+
+[RootSignature("")]
+min16float main(min16float x : A) : SV_Target {
+    return acos(x);
+}

+ 28 - 0
tools/clang/test/HLSL/expand_trig/asin.hlsl

@@ -0,0 +1,28 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// CHECK: [[X:%.*]]   = call float @dx.op.loadInput.f32(i32 4
+// CHECK: [[r0:%.*]]  = call float @dx.op.unary.f32(i32 6, float [[X]]
+
+// CHECK: [[r1:%.*]]  = fsub fast float 1.000000e+00, [[r0]]
+// CHECK: [[r2:%.*]]  = call float @dx.op.unary.f32(i32 24, float [[r1]]
+
+// CHECK: [[r3a:%.*]] = fmul fast float [[r0]], 0xBF932DC600000000
+// CHECK: [[r3b:%.*]] = fadd fast float [[r3a]], 0x3FB302C4E0000000
+// CHECK: [[r3c:%.*]] = fmul fast float [[r0]], [[r3b]]
+// CHECK: [[r3d:%.*]] = fadd fast float [[r3c]], 0xBFCB269080000000
+// CHECK: [[r3e:%.*]] = fmul fast float [[r0]], [[r3d]]
+// CHECK: [[r3f:%.*]] = fadd fast float [[r3e]], 0x3FF921B480000000
+// CHECK: [[r4:%.*]]  = fmul fast float [[r2]], [[r3f]]
+
+// CHECK: [[r5:%.*]]  = fsub fast float 0x3FF921FB60000000, [[r4]]
+// CHECK: [[r6:%.*]]  = fsub fast float 0.000000e+00, [[r5]]
+
+// CHECK: [[b0:%.*]]  = fcmp fast ult float [[X]], 0.000000e+00
+// CHECK: select i1 [[b0]], float [[r6]], float [[r5]]
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 16
+
+[RootSignature("")]
+float main(float x : A) : SV_Target {
+    return asin(x);
+}

+ 12 - 0
tools/clang/test/HLSL/expand_trig/asin_h.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure the expansion works for half.
+// Only checking for for minimal expansion here, full check is done for float case.
+
+// CHECK: fmul fast half %{{.*}}, 0xHA4CB
+
+
+[RootSignature("")]
+min16float main(min16float x : A) : SV_Target {
+    return asin(x);
+}

+ 35 - 0
tools/clang/test/HLSL/expand_trig/atan.hlsl

@@ -0,0 +1,35 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// CHECK: [[X:%.*]]   = call float @dx.op.loadInput.f32(i32 4
+// CHECK: [[r0:%.*]]  = call float @dx.op.unary.f32(i32 6, float [[X]]
+
+// CHECK: [[b0:%.*]]  = fcmp fast ugt float [[r0]], 1.000000e+00
+// CHECK: [[r1:%.*]]  = fdiv fast float 1.000000e+00, [[r0]]
+// CHECK: [[r2:%.*]]  = select i1 [[b0]], float [[r1]], float [[r0]]
+
+// CHECK: [[r3:%.*]]  = fmul fast float [[r2]],  [[r2]]
+// CHECK: [[r4a:%.*]] = fmul fast float [[r3]],  0x3F9555CBE0000000
+// CHECK: [[r4b:%.*]] = fadd fast float [[r4a]], 0xBFB5CB46C0000000 
+// CHECK: [[r4c:%.*]] = fmul fast float [[r4b]], [[r3]]
+// CHECK: [[r4d:%.*]] = fadd fast float [[r4c]], 0x3FC70EDC40000000
+// CHECK: [[r4e:%.*]] = fmul fast float [[r4d]], [[r3]]
+// CHECK: [[r4f:%.*]] = fadd fast float [[r4e]], 0xBFD523A080000000
+// CHECK: [[r4g:%.*]] = fmul fast float [[r4f]], [[r3]]
+// CHECK: [[r4h:%.*]] = fadd fast float [[r4g]], 0x3FEFFEE700000000
+// CHECK: [[r4:%.*]]  = fmul fast float [[r2]],  [[r4h]]
+
+// CHECK: [[r5:%.*]]  = fsub fast float 0x3FF921FB60000000, [[r4]]
+// CHECK: [[r6:%.*]]  = select i1 [[b0]], float [[r5]], float [[r4]]
+
+// CHECK: [[r7:%.*]]  = fsub fast float 0.000000e+00, [[r6]]
+
+// CHECK: [[b1:%.*]]  = fcmp fast ult float [[X]], 0.000000e+00
+// CHECK: select i1 [[b1]], float [[r7]], float [[r6]]
+
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 17
+
+[RootSignature("")]
+float main(float x : A) : SV_Target {
+    return atan(x);
+}

+ 12 - 0
tools/clang/test/HLSL/expand_trig/atan_h.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure the expansion works for half.
+// Only checking for for minimal expansion here, full check is done for float case.
+
+// CHECK: fmul fast half %{{.*}}, 0xH2555
+
+
+[RootSignature("")]
+min16float main(min16float x : A) : SV_Target {
+    return atan(x);
+}

+ 16 - 0
tools/clang/test/HLSL/expand_trig/hcos.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// CHECK: [[X:%.*]]   = call float @dx.op.loadInput.f32(i32 4
+// CHECK: [[r0:%.*]]  = fmul fast float [[X]], 0x3FF7154760000000
+// CHECK: [[r1:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r0]]
+// CHECK: [[r2:%.*]]  = fsub fast float 0.000000e+00, [[r0]]
+// CHECK: [[r3:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r2]]
+// CHECK: [[r4:%.*]]  = fadd fast float [[r1]], [[r3]]
+// CHECK: fdiv fast float [[r4]], 2.000000e+00
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 18
+
+[RootSignature("")]
+float main(float x : A) : SV_Target {
+    return cosh(x);
+}

+ 12 - 0
tools/clang/test/HLSL/expand_trig/hcos_h.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure the expansion works for half.
+// Only checking for for minimal expansion here, full check is done for float case.
+
+// CHECK: fmul fast half %{{.*}}, 0xH3DC5
+
+
+[RootSignature("")]
+min16float main(min16float x : A) : SV_Target {
+    return cosh(x);
+}

+ 16 - 0
tools/clang/test/HLSL/expand_trig/hsin.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// CHECK: [[X:%.*]]   = call float @dx.op.loadInput.f32(i32 4
+// CHECK: [[r0:%.*]]  = fmul fast float [[X]], 0x3FF7154760000000
+// CHECK: [[r1:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r0]]
+// CHECK: [[r2:%.*]]  = fsub fast float 0.000000e+00, [[r0]]
+// CHECK: [[r3:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r2]]
+// CHECK: [[r4:%.*]]  = fsub fast float [[r1]], [[r3]]
+// CHECK: fdiv fast float [[r4]], 2.000000e+00
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 18
+
+[RootSignature("")]
+float main(float x : A) : SV_Target {
+    return sinh(x);
+}

+ 12 - 0
tools/clang/test/HLSL/expand_trig/hsin_h.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure the expansion works for half.
+// Only checking for for minimal expansion here, full check is done for float case.
+
+// CHECK: fmul fast half %{{.*}}, 0xH3DC5
+
+
+[RootSignature("")]
+min16float main(min16float x : A) : SV_Target {
+    return sinh(x);
+}

+ 17 - 0
tools/clang/test/HLSL/expand_trig/htan.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// CHECK: [[X:%.*]]   = call float @dx.op.loadInput.f32(i32 4
+// CHECK: [[r0:%.*]]  = fmul fast float [[X]], 0x3FF7154760000000
+// CHECK: [[r1:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r0]]
+// CHECK: [[r2:%.*]]  = fsub fast float 0.000000e+00, [[r0]]
+// CHECK: [[r3:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r2]]
+// CHECK: [[r4:%.*]]  = fsub fast float [[r1]], [[r3]]
+// CHECK: [[r5:%.*]]  = fadd fast float [[r1]], [[r3]]
+// CHECK: fdiv fast float [[r4]], [[r5]]
+
+// CHECK-NOT: call float @dx.op.unary.f32(i32 18
+
+[RootSignature("")]
+float main(float x : A) : SV_Target {
+    return tanh(x);
+}

+ 12 - 0
tools/clang/test/HLSL/expand_trig/htan_h.hlsl

@@ -0,0 +1,12 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure the expansion works for half.
+// Only checking for for minimal expansion here, full check is done for float case.
+
+// CHECK: fmul fast half %{{.*}}, 0xH3DC5
+
+
+[RootSignature("")]
+min16float main(min16float x : A) : SV_Target {
+    return tanh(x);
+}

+ 19 - 0
tools/clang/test/HLSL/expand_trig/keep_precise.0.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure that when the call is precise we do not use fast math flags
+// on the floating point instructions and add precise metadata to the
+// generated dxil calls.
+
+// CHECK: [[X:%.*]]   = call float @dx.op.loadInput.f32(i32 4
+// CHECK: [[r0:%.*]]  = fmul float [[X]], 0x3FF7154760000000
+// CHECK: [[r1:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r0]]), !dx.precise
+// CHECK: [[r2:%.*]]  = fsub float 0.000000e+00, [[r0]]
+// CHECK: [[r3:%.*]]  = call float @dx.op.unary.f32(i32 21, float [[r2]]), !dx.precise
+// CHECK: [[r4:%.*]]  = fsub float [[r1]], [[r3]]
+// CHECK: [[r5:%.*]]  = fadd float [[r1]], [[r3]]
+// CHECK: fdiv float [[r4]], [[r5]]
+
+[RootSignature("")]
+precise float main(float x : A) : SV_Target {
+    return tanh(x);
+}

+ 30 - 0
tools/clang/test/HLSL/expand_trig/keep_precise.1.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -Emain -Tps_6_0 %s | %opt -S -hlsl-dxil-expand-trig-intrinsics | %FileCheck %s
+
+// Make sure precise->non-precise->precise transition is handled properly.
+
+// A
+// CHECK: fmul float {{.*}}, 0x3FF7154760000000
+// CHECK: call float @dx.op.unary.f32(i32 21, float {{.*}}), !dx.precise
+// CHECK: call float @dx.op.unary.f32(i32 21, float {{.*}}), !dx.precise
+
+// B
+// CHECK: fmul fast float {{.*}}, 0x3FF7154760000000
+// CHECK: call float @dx.op.unary.f32(i32 21, float {{.*}})
+// CHECK-NOT: !dx.precise
+// CHECK: call float @dx.op.unary.f32(i32 21, float {{.*}})
+// CHECK-NOT: !dx.precise
+
+// C
+// CHECK: fmul float {{.*}}, 0x3FF7154760000000
+// CHECK: call float @dx.op.unary.f32(i32 21, float {{.*}}), !dx.precise
+// CHECK: call float @dx.op.unary.f32(i32 21, float {{.*}}), !dx.precise
+
+// CHECK: ret
+
+[RootSignature("")]
+float main(float x : A, float y : B, float z : C) : SV_Target {
+    precise float a = tanh(x);
+            float b = tanh(y);
+    precise float c = tanh(z);
+    return a + b + c;
+}

+ 18 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -441,6 +441,7 @@ public:
   TEST_METHOD(CodeGenEvalMatMember)
   TEST_METHOD(CodeGenEvalPos)
   TEST_METHOD(CodeGenExternRes)
+  TEST_METHOD(CodeGenExpandTrig)
   TEST_METHOD(CodeGenFloatCast)
   TEST_METHOD(CodeGenFloatToBool)
   TEST_METHOD(CodeGenFirstbitHi)
@@ -2518,6 +2519,23 @@ TEST_F(CompilerTest, CodeGenExternRes) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\extern_res.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenExpandTrig) {
+  CodeGenTestCheck(L"expand_trig\\acos.hlsl");
+  CodeGenTestCheck(L"expand_trig\\acos_h.hlsl");
+  CodeGenTestCheck(L"expand_trig\\asin.hlsl");
+  CodeGenTestCheck(L"expand_trig\\asin_h.hlsl");
+  CodeGenTestCheck(L"expand_trig\\atan.hlsl");
+  CodeGenTestCheck(L"expand_trig\\atan_h.hlsl");
+  CodeGenTestCheck(L"expand_trig\\hcos.hlsl");
+  CodeGenTestCheck(L"expand_trig\\hcos_h.hlsl");
+  CodeGenTestCheck(L"expand_trig\\hsin.hlsl");
+  CodeGenTestCheck(L"expand_trig\\hsin_h.hlsl");
+  CodeGenTestCheck(L"expand_trig\\htan.hlsl");
+  CodeGenTestCheck(L"expand_trig\\htan_h.hlsl");
+  CodeGenTestCheck(L"expand_trig\\keep_precise.0.hlsl");
+  CodeGenTestCheck(L"expand_trig\\keep_precise.1.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenFloatCast) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\float_cast.hlsl");
 }

+ 1 - 0
utils/hct/hctdb.py

@@ -1264,6 +1264,7 @@ class db_dxil(object):
         add_pass('hlsl-dxil-eliminate-output-dynamic', 'DxilEliminateOutputDynamicIndexing', 'DXIL eliminate ouptut dynamic indexing', [])
         add_pass('hlsl-dxilemit', 'DxilEmitMetadata', 'HLSL DXIL Metadata Emit', [])
         add_pass('hlsl-dxilload', 'DxilLoadMetadata', 'HLSL DXIL Metadata Load', [])
+        add_pass('hlsl-dxil-expand-trig', 'DxilExpandTrigIntrinsics', 'DXIL expand trig intrinsics', [])
         add_pass('hlsl-hca', 'HoistConstantArray', 'HLSL constant array hoisting', [])
         add_pass('ipsccp', 'IPSCCP', 'Interprocedural Sparse Conditional Constant Propagation', [])
         add_pass('globalopt', 'GlobalOpt', 'Global Variable Optimizer', [])