瀏覽代碼

Add constant evaluation for clamp() (#3581)

(cherry picked from commit 2bda44fc0602a6f6c97ca7e8f40c26133dc8d1a3)
Tex Riddell 4 年之前
父節點
當前提交
08e2587dca

+ 92 - 0
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -1460,6 +1460,10 @@ typedef APInt(__cdecl *IntBinaryEvalFuncType)(const APInt &, const APInt &);
 typedef float(__cdecl *FloatBinaryEvalFuncType)(float, float);
 typedef double(__cdecl *DoubleBinaryEvalFuncType)(double, double);
 
+typedef APInt(__cdecl *IntTernaryEvalFuncType)(const APInt &, const APInt &, const APInt &);
+typedef float(__cdecl *FloatTernaryEvalFuncType)(float, float, float);
+typedef double(__cdecl *DoubleTernaryEvalFuncType)(double, double, double);
+
 Value *EvalUnaryIntrinsic(ConstantFP *fpV, FloatUnaryEvalFuncType floatEvalFunc,
                           DoubleUnaryEvalFuncType doubleEvalFunc) {
   llvm::Type *Ty = fpV->getType();
@@ -1510,6 +1514,45 @@ Value *EvalBinaryIntrinsic(Constant *cV0, Constant *cV1,
   return Result;
 }
 
+Value *EvalTernaryIntrinsic(Constant *cV0, Constant *cV1, Constant *cV2,
+                             FloatTernaryEvalFuncType floatEvalFunc,
+                             DoubleTernaryEvalFuncType doubleEvalFunc,
+                             IntTernaryEvalFuncType intEvalFunc) {
+  llvm::Type *Ty = cV0->getType();
+  Value *Result = nullptr;
+  if (Ty->isDoubleTy()) {
+    ConstantFP *fpV0 = cast<ConstantFP>(cV0);
+    ConstantFP *fpV1 = cast<ConstantFP>(cV1);
+    ConstantFP *fpV2 = cast<ConstantFP>(cV2);
+    double dV0 = fpV0->getValueAPF().convertToDouble();
+    double dV1 = fpV1->getValueAPF().convertToDouble();
+    double dV2 = fpV2->getValueAPF().convertToDouble();
+    Value *dResult = ConstantFP::get(Ty, doubleEvalFunc(dV0, dV1, dV2));
+    Result = dResult;
+  } else if (Ty->isFloatTy()) {
+    ConstantFP *fpV0 = cast<ConstantFP>(cV0);
+    ConstantFP *fpV1 = cast<ConstantFP>(cV1);
+    ConstantFP *fpV2 = cast<ConstantFP>(cV2);
+    float fV0 = fpV0->getValueAPF().convertToFloat();
+    float fV1 = fpV1->getValueAPF().convertToFloat();
+    float fV2 = fpV2->getValueAPF().convertToFloat();
+    Value *dResult = ConstantFP::get(Ty, floatEvalFunc(fV0, fV1, fV2));
+    Result = dResult;
+  } else {
+    DXASSERT_NOMSG(Ty->isIntegerTy());
+    DXASSERT_NOMSG(intEvalFunc);
+    ConstantInt *ciV0 = cast<ConstantInt>(cV0);
+    ConstantInt *ciV1 = cast<ConstantInt>(cV1);
+    ConstantInt *ciV2 = cast<ConstantInt>(cV2);
+    const APInt &iV0 = ciV0->getValue();
+    const APInt &iV1 = ciV1->getValue();
+    const APInt &iV2 = ciV2->getValue();
+    Value *dResult = ConstantInt::get(Ty, intEvalFunc(iV0, iV1, iV2));
+    Result = dResult;
+  }
+  return Result;
+}
+
 Value *EvalUnaryIntrinsic(CallInst *CI, FloatUnaryEvalFuncType floatEvalFunc,
                           DoubleUnaryEvalFuncType doubleEvalFunc) {
   Value *V = CI->getArgOperand(0);
@@ -1566,6 +1609,43 @@ Value *EvalBinaryIntrinsic(CallInst *CI, FloatBinaryEvalFuncType floatEvalFunc,
   return Result;
 }
 
+Value *EvalTernaryIntrinsic(CallInst *CI, FloatTernaryEvalFuncType floatEvalFunc,
+                             DoubleTernaryEvalFuncType doubleEvalFunc,
+                             IntTernaryEvalFuncType intEvalFunc = nullptr) {
+  Value *V0 = CI->getArgOperand(0);
+  Value *V1 = CI->getArgOperand(1);
+  Value *V2 = CI->getArgOperand(2);
+  llvm::Type *Ty = CI->getType();
+  Value *Result = nullptr;
+  if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
+    Result = UndefValue::get(Ty);
+    Constant *CV0 = cast<Constant>(V0);
+    Constant *CV1 = cast<Constant>(V1);
+    Constant *CV2 = cast<Constant>(V2);
+    IRBuilder<> Builder(CI);
+    for (unsigned i = 0; i < VT->getNumElements(); i++) {
+      Constant *cV0 = cast<Constant>(CV0->getAggregateElement(i));
+      Constant *cV1 = cast<Constant>(CV1->getAggregateElement(i));
+      Constant *cV2 = cast<Constant>(CV2->getAggregateElement(i));
+      Value *EltResult = EvalTernaryIntrinsic(cV0, cV1, cV2, floatEvalFunc,
+                                             doubleEvalFunc, intEvalFunc);
+      Result = Builder.CreateInsertElement(Result, EltResult, i);
+    }
+  } else {
+    Constant *cV0 = cast<Constant>(V0);
+    Constant *cV1 = cast<Constant>(V1);
+    Constant *cV2 = cast<Constant>(V2);
+    Result = EvalTernaryIntrinsic(cV0, cV1, cV2, floatEvalFunc, doubleEvalFunc,
+                                 intEvalFunc);
+  }
+  CI->replaceAllUsesWith(Result);
+  CI->eraseFromParent();
+  return Result;
+
+  CI->eraseFromParent();
+  return Result;
+}
+
 void SimpleTransformForHLDXIRInst(Instruction *I, SmallInstSet &deadInsts) {
 
   unsigned opcode = I->getOpcode();
@@ -1789,6 +1869,18 @@ Value *TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp,
     CI->eraseFromParent();
     return cNan;
   } break;
+  case IntrinsicOp::IOP_clamp: {
+    auto clampF = [](float a, float b, float c) {
+      return a < b ? b : a > c ? c : a;
+    };
+    auto clampD = [](double a, double b, double c) {
+      return a < b ? b : a > c ? c : a;
+    };
+    auto clampI = [](const APInt &a, const APInt &b, const APInt &c) -> APInt {
+      return a.slt(b) ? b : a.sgt(c) ? c : a;
+    };
+    return EvalTernaryIntrinsic(CI, clampF, clampD, clampI);
+  } break;
   default:
     return nullptr;
   }

+ 14 - 0
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/basic/clamp_const_prop.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxc -T ps_6_0 %s -E main | %FileCheck %s
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 1.000000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float -1.250000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float 3.000000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float 2.000000e+00)
+
+[RootSignature("")]
+float4 main() : SV_Target {
+  return float4(
+    clamp(10, 0, 1),
+    clamp(-1.0f, -2.5f, -1.25f),
+    clamp((double)3, (double)-2, (double)5),
+    clamp(-5LL, 2LL, 5LL));
+}