Răsfoiți Sursa

Support vector constant in TryEvalIntrinsic. (#902)

Xiang Li 7 ani în urmă
părinte
comite
c011ec3d55

+ 63 - 0
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -1339,6 +1339,32 @@ void CodeGenFunction::EmitStoreOfScalar(llvm::Value *value, LValue lvalue,
                     lvalue.getTBAAOffset());
 }
 
+// HLSL Change Begin - find immediate value for literal.
+/// Finds the single value stored to \p Ptr, so that a load from \p Ptr can be
+/// replaced by that value directly.
+///
+/// Walks every user of \p Ptr. Returns the value operand of the only
+/// StoreInst that writes to \p Ptr. Returns nullptr when there is no such
+/// store, when there is more than one, or when \p Ptr itself is stored
+/// somewhere as a value (its address is captured, so forwarding is skipped).
+static llvm::Value *GetStoredValue(llvm::Value *Ptr) {
+  llvm::Value *V = nullptr;
+  for (llvm::User *U : Ptr->users()) {
+    llvm::StoreInst *ST = dyn_cast<llvm::StoreInst>(U);
+    if (!ST)
+      continue;
+    // A store user may reference Ptr as the value being stored rather than as
+    // the destination. That is not a definition of the variable; treating it
+    // as one would hand back a pointer where a literal value is expected.
+    if (ST->getPointerOperand() != Ptr)
+      return nullptr;
+    // More than one store: the stored value is ambiguous, give up.
+    if (V)
+      return nullptr;
+    V = ST->getValueOperand();
+  }
+  return V;
+}
+/// Returns true when \p QT is the builtin literal float (LitFloat) or literal
+/// int (LitInt) type, and false for every other type.
+static bool IsLiteralType(QualType QT) {
+  const BuiltinType *BTy = QT->getAs<BuiltinType>();
+  if (!BTy)
+    return false;
+  switch (BTy->getKind()) {
+  case BuiltinType::LitFloat:
+  case BuiltinType::LitInt:
+    return true;
+  default:
+    return false;
+  }
+}
+// HLSL Change End.
+
 /// EmitLoadOfLValue - Given an expression that represents a value lvalue, this
 /// method emits the address of the lvalue, then loads the result as an rvalue,
 /// returning the rvalue.
@@ -1357,12 +1383,33 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
 
   if (LV.isSimple()) {
     assert(!LV.getType()->isFunctionType());
+    // HLSL Change Begin - find immediate value for literal.
+    if (IsLiteralType(LV.getType())) {
+      // The value must be stored only once.
+      // Scan all use to find it.
+      llvm::Value *Ptr = LV.getAddress();
+      if (llvm::Value *V = GetStoredValue(Ptr)) {
+        return RValue::get(V);
+      }
+    }
+    // HLSL Change End.
 
     // Everything needs a load.
     return RValue::get(EmitLoadOfScalar(LV, Loc));
   }
 
   if (LV.isVectorElt()) {
+    // HLSL Change Begin - find immediate value for literal.
+    if (IsLiteralType(LV.getType())) {
+      // The value must be stored only once.
+      // Scan all use to find it.
+      llvm::Value *Ptr = LV.getAddress();
+      if (llvm::Value *V = GetStoredValue(Ptr)) {
+        return RValue::get(Builder.CreateExtractElement(V,
+               LV.getVectorIdx(), "vecext"));
+      }
+    }
+    // HLSL Change End.
     llvm::LoadInst *Load = Builder.CreateLoad(LV.getVectorAddr(),
                                               LV.isVolatileQualified());
     Load->setAlignment(LV.getAlignment().getQuantity());
@@ -1438,6 +1485,22 @@ RValue CodeGenFunction::EmitLoadOfExtVectorElementLValue(LValue LV) {
     ExprVT =
         hlsl::ConvertHLSLVecMatTypeToExtVectorType(getContext(), LV.getType());
   // HLSL Change Ends
+
+  // HLSL Change Begin - find immediate value for literal.
+  QualType QT = LV.getType();
+  if (ExprVT) {
+    QT = ExprVT->getElementType();
+  }
+  if (IsLiteralType(QT)) {
+    // The value must be stored only once.
+    // Scan all use to find it.
+    llvm::Value *Ptr = LV.getExtVectorAddr();
+    if (llvm::Value *V = GetStoredValue(Ptr)) {
+      Vec = V;
+    }
+  }
+  // HLSL Change End.
+
   if (!ExprVT) {
     unsigned InIdx = getAccessedFieldNo(0, Elts);
     llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);

+ 63 - 24
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -3585,56 +3585,94 @@ typedef double(__cdecl *DoubleUnaryEvalFuncType)(double);
 typedef float(__cdecl *FloatBinaryEvalFuncType)(float, float);
 typedef double(__cdecl *DoubleBinaryEvalFuncType)(double, double);
 
+// Evaluates a unary floating-point intrinsic on a single constant operand.
+// When fpV has double type its value is passed through doubleEvalFunc;
+// otherwise the type is asserted to be float and floatEvalFunc is used.
+// Returns a ConstantFP of fpV's type holding the computed value. No
+// instruction is replaced or erased here.
-static Value * EvalUnaryIntrinsic(CallInst *CI,
+static Value * EvalUnaryIntrinsic(ConstantFP *fpV,
                                FloatUnaryEvalFuncType floatEvalFunc,
                                DoubleUnaryEvalFuncType doubleEvalFunc) {
-  Value *V = CI->getArgOperand(0);
-  ConstantFP *fpV = cast<ConstantFP>(V);
-  llvm::Type *Ty = CI->getType();
+  llvm::Type *Ty = fpV->getType();
   Value *Result = nullptr;
   if (Ty->isDoubleTy()) {
     double dV = fpV->getValueAPF().convertToDouble();
-    Value *dResult = ConstantFP::get(V->getType(), doubleEvalFunc(dV));
-
-    CI->replaceAllUsesWith(dResult);
+    Value *dResult = ConstantFP::get(Ty, doubleEvalFunc(dV));
     Result = dResult;
   } else {
     DXASSERT_NOMSG(Ty->isFloatTy());
     float fV = fpV->getValueAPF().convertToFloat();
-    Value *dResult = ConstantFP::get(V->getType(), floatEvalFunc(fV));
-
-    CI->replaceAllUsesWith(dResult);
+    Value *dResult = ConstantFP::get(Ty, floatEvalFunc(fV));
     Result = dResult;
   }
-
-  CI->eraseFromParent();
   return Result;
 }
 
+// Evaluates a binary floating-point intrinsic on two constant operands.
+// The result type is taken from fpV0; nothing here checks that fpV1 has the
+// same type. For double, both values go through doubleEvalFunc; otherwise
+// the type is asserted to be float and floatEvalFunc is used.
+// Returns a ConstantFP of fpV0's type holding the computed value. No
+// instruction is replaced or erased here.
-static Value * EvalBinaryIntrinsic(CallInst *CI,
+static Value * EvalBinaryIntrinsic(ConstantFP *fpV0, ConstantFP *fpV1,
                                FloatBinaryEvalFuncType floatEvalFunc,
                                DoubleBinaryEvalFuncType doubleEvalFunc) {
-  Value *V0 = CI->getArgOperand(0);
-  ConstantFP *fpV0 = cast<ConstantFP>(V0);
-  Value *V1 = CI->getArgOperand(1);
-  ConstantFP *fpV1 = cast<ConstantFP>(V1);
-  llvm::Type *Ty = CI->getType();
+  llvm::Type *Ty = fpV0->getType();
   Value *Result = nullptr;
   if (Ty->isDoubleTy()) {
     double dV0 = fpV0->getValueAPF().convertToDouble();
     double dV1 = fpV1->getValueAPF().convertToDouble();
-    Value *dResult = ConstantFP::get(V0->getType(), doubleEvalFunc(dV0, dV1));
-    CI->replaceAllUsesWith(dResult);
+    Value *dResult = ConstantFP::get(Ty, doubleEvalFunc(dV0, dV1));
     Result = dResult;
   } else {
     DXASSERT_NOMSG(Ty->isFloatTy());
     float fV0 = fpV0->getValueAPF().convertToFloat();
     float fV1 = fpV1->getValueAPF().convertToFloat();
-    Value *dResult = ConstantFP::get(V0->getType(), floatEvalFunc(fV0, fV1));
-
-    CI->replaceAllUsesWith(dResult);
+    Value *dResult = ConstantFP::get(Ty, floatEvalFunc(fV0, fV1));
     Result = dResult;
   }
+  return Result;
+}
+
+// Folds a call to a unary floating-point intrinsic whose operand is constant.
+// A scalar call is evaluated directly. For a vector-typed call each element
+// of the constant operand is evaluated on its own and the results are
+// inserted, lane by lane, into a vector that starts out undef.
+// Every use of CI is redirected to the folded value and CI is erased; the
+// folded value is returned.
+static Value * EvalUnaryIntrinsic(CallInst *CI,
+                               FloatUnaryEvalFuncType floatEvalFunc,
+                               DoubleUnaryEvalFuncType doubleEvalFunc) {
+  Value *Arg = CI->getArgOperand(0);
+  llvm::Type *RetTy = CI->getType();
+  Value *Folded = nullptr;
+  llvm::VectorType *VecTy = dyn_cast<llvm::VectorType>(RetTy);
+  if (!VecTy) {
+    // Scalar: the operand itself is the floating-point constant.
+    Folded = EvalUnaryIntrinsic(cast<ConstantFP>(Arg), floatEvalFunc,
+                                doubleEvalFunc);
+  } else {
+    Folded = UndefValue::get(RetTy);
+    Constant *ArgC = cast<Constant>(Arg);
+    IRBuilder<> Builder(CI);
+    for (unsigned Lane = 0; Lane < VecTy->getNumElements(); ++Lane) {
+      ConstantFP *LaneC = cast<ConstantFP>(ArgC->getAggregateElement(Lane));
+      Value *LaneResult =
+          EvalUnaryIntrinsic(LaneC, floatEvalFunc, doubleEvalFunc);
+      Folded = Builder.CreateInsertElement(Folded, LaneResult, Lane);
+    }
+  }
+  CI->replaceAllUsesWith(Folded);
+  CI->eraseFromParent();
+  return Folded;
+}
+
+// Folds a call to a binary floating-point intrinsic whose operands are
+// constant. A scalar call is evaluated directly. For a vector-typed call the
+// two constant operands are evaluated element by element and the results are
+// inserted into a vector that starts out undef.
+// Every use of CI is redirected to the folded value and CI is erased; the
+// folded value is returned.
+static Value * EvalBinaryIntrinsic(CallInst *CI,
+                               FloatBinaryEvalFuncType floatEvalFunc,
+                               DoubleBinaryEvalFuncType doubleEvalFunc) {
+  Value *V0 = CI->getArgOperand(0);
+  Value *V1 = CI->getArgOperand(1);
+  llvm::Type *Ty = CI->getType();
+  Value *Result = nullptr;
+  if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
+    Result = UndefValue::get(Ty);
+    Constant *CV0 = cast<Constant>(V0);
+    Constant *CV1 = cast<Constant>(V1);
+    IRBuilder<> Builder(CI);
+    for (unsigned i=0;i<VT->getNumElements();i++) {
+      ConstantFP *fpV0 = cast<ConstantFP>(CV0->getAggregateElement(i));
+      ConstantFP *fpV1 = cast<ConstantFP>(CV1->getAggregateElement(i));
+      Value *EltResult = EvalBinaryIntrinsic(fpV0, fpV1, floatEvalFunc, doubleEvalFunc);
+      Result = Builder.CreateInsertElement(Result, EltResult, i);
+    }
+  } else {
+    ConstantFP *fpV0 = cast<ConstantFP>(V0);
+    ConstantFP *fpV1 = cast<ConstantFP>(V1);
+    Result = EvalBinaryIntrinsic(fpV0, fpV1, floatEvalFunc, doubleEvalFunc);
+  }
+  CI->replaceAllUsesWith(Result);
 
   CI->eraseFromParent();
   return Result;
@@ -4302,7 +4340,8 @@ RValue CGMSHLSLRuntime::EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF,
       if (group == HLOpcodeGroup::HLIntrinsic) {
         bool allOperandImm = true;
         for (auto &operand : CI->arg_operands()) {
-          bool isImm = isa<ConstantInt>(operand) || isa<ConstantFP>(operand);
+          bool isImm = isa<ConstantInt>(operand) || isa<ConstantFP>(operand) ||
+              isa<ConstantAggregateZero>(operand) || isa<ConstantDataVector>(operand);
           if (!isImm) {
             allOperandImm = false;
             break;

+ 9 - 0
tools/clang/test/CodeGenHLSL/quick-test/vec_imm_sqrt.hlsl

@@ -0,0 +1,9 @@
+// RUN: %dxc -T ps_6_0 -E main %s | FileCheck %s
+
+// Make sure vector immediate sqrt works.
+// Both components of the float2 result must be folded to sqrt(0.25) = 0.5,
+// so check the output store for column 0 and for column 1.
+// CHECK-DAG: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 5.000000e-01)
+// CHECK-DAG: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float 5.000000e-01)
+
+float2 main() : SV_TARGET
+{
+  return sqrt(0.25.xx);
+}