DxilPrecisePropagatePass.cpp 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPrecisePropagatePass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/DXIL/DxilModule.h"
  10. #include "dxc/HLSL/DxilGenerationPass.h"
  11. #include "dxc/HLSL/HLModule.h"
  12. #include "dxc/HLSL/HLOperations.h"
  13. #include "llvm/Pass.h"
  14. #include "llvm/IR/Function.h"
  15. #include "llvm/IR/Instruction.h"
  16. #include "llvm/IR/Instructions.h"
  17. #include "llvm/IR/Operator.h"
  18. #include "llvm/IR/Module.h"
  19. #include "llvm/Support/Casting.h"
  20. #include <unordered_set>
  21. #include <vector>
  22. using namespace llvm;
  23. using namespace hlsl;
  24. namespace {
  25. class DxilPrecisePropagatePass : public ModulePass {
  26. public:
  27. static char ID; // Pass identification, replacement for typeid
  28. explicit DxilPrecisePropagatePass() : ModulePass(ID) {}
  29. const char *getPassName() const override { return "DXIL Precise Propagate"; }
  30. bool runOnModule(Module &M) override {
  31. DxilModule &dxilModule = M.GetOrCreateDxilModule();
  32. DxilTypeSystem &typeSys = dxilModule.GetTypeSystem();
  33. std::unordered_set<Instruction*> processedSet;
  34. std::vector<Function*> deadList;
  35. for (Function &F : M.functions()) {
  36. if (HLModule::HasPreciseAttribute(&F)) {
  37. PropagatePreciseOnFunctionUser(F, typeSys, processedSet);
  38. deadList.emplace_back(&F);
  39. }
  40. }
  41. for (Function *F : deadList)
  42. F->eraseFromParent();
  43. return true;
  44. }
  45. private:
  46. void PropagatePreciseOnFunctionUser(
  47. Function &F, DxilTypeSystem &typeSys,
  48. std::unordered_set<Instruction *> &processedSet);
  49. };
  50. char DxilPrecisePropagatePass::ID = 0;
  51. }
  52. static void PropagatePreciseAttribute(Instruction *I, DxilTypeSystem &typeSys,
  53. std::unordered_set<Instruction *> &processedSet);
  54. static void PropagatePreciseAttributeOnOperand(
  55. Value *V, DxilTypeSystem &typeSys, LLVMContext &Context,
  56. std::unordered_set<Instruction *> &processedSet) {
  57. Instruction *I = dyn_cast<Instruction>(V);
  58. // Skip none inst.
  59. if (!I)
  60. return;
  61. FPMathOperator *FPMath = dyn_cast<FPMathOperator>(I);
  62. // Skip none FPMath
  63. if (!FPMath)
  64. return;
  65. // Skip inst already marked.
  66. if (processedSet.count(I) > 0)
  67. return;
  68. // TODO: skip precise on integer type, sample instruction...
  69. processedSet.insert(I);
  70. // Set precise fast math on those instructions that support it.
  71. if (DxilModule::PreservesFastMathFlags(I))
  72. DxilModule::SetPreciseFastMathFlags(I);
  73. // Fast math not work on call, use metadata.
  74. if (CallInst *CI = dyn_cast<CallInst>(I))
  75. HLModule::MarkPreciseAttributeWithMetadata(CI);
  76. PropagatePreciseAttribute(I, typeSys, processedSet);
  77. }
  78. static void PropagatePreciseAttributeOnPointer(
  79. Value *Ptr, DxilTypeSystem &typeSys, LLVMContext &Context,
  80. std::unordered_set<Instruction *> &processedSet) {
  81. // Find all store and propagate on the val operand of store.
  82. // For CallInst, if Ptr is used as out parameter, mark it.
  83. for (User *U : Ptr->users()) {
  84. Instruction *user = cast<Instruction>(U);
  85. if (StoreInst *stInst = dyn_cast<StoreInst>(user)) {
  86. Value *val = stInst->getValueOperand();
  87. PropagatePreciseAttributeOnOperand(val, typeSys, Context, processedSet);
  88. } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
  89. bool bReadOnly = true;
  90. Function *F = CI->getCalledFunction();
  91. const DxilFunctionAnnotation *funcAnnotation =
  92. typeSys.GetFunctionAnnotation(F);
  93. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  94. if (Ptr != CI->getArgOperand(i))
  95. continue;
  96. const DxilParameterAnnotation &paramAnnotation =
  97. funcAnnotation->GetParameterAnnotation(i);
  98. // OutputPatch and OutputStream will be checked after scalar repl.
  99. // Here only check out/inout
  100. if (paramAnnotation.GetParamInputQual() == DxilParamInputQual::Out ||
  101. paramAnnotation.GetParamInputQual() == DxilParamInputQual::Inout) {
  102. bReadOnly = false;
  103. break;
  104. }
  105. }
  106. if (!bReadOnly)
  107. PropagatePreciseAttributeOnOperand(CI, typeSys, Context, processedSet);
  108. }
  109. }
  110. }
  111. static void
  112. PropagatePreciseAttribute(Instruction *I, DxilTypeSystem &typeSys,
  113. std::unordered_set<Instruction *> &processedSet) {
  114. LLVMContext &Context = I->getContext();
  115. if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
  116. PropagatePreciseAttributeOnPointer(AI, typeSys, Context, processedSet);
  117. } else if (dyn_cast<CallInst>(I)) {
  118. // Propagate every argument.
  119. // TODO: only propagate precise argument.
  120. for (Value *src : I->operands())
  121. PropagatePreciseAttributeOnOperand(src, typeSys, Context, processedSet);
  122. } else if (dyn_cast<FPMathOperator>(I)) {
  123. // TODO: only propagate precise argument.
  124. for (Value *src : I->operands())
  125. PropagatePreciseAttributeOnOperand(src, typeSys, Context, processedSet);
  126. } else if (LoadInst *ldInst = dyn_cast<LoadInst>(I)) {
  127. Value *Ptr = ldInst->getPointerOperand();
  128. PropagatePreciseAttributeOnPointer(Ptr, typeSys, Context, processedSet);
  129. } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I))
  130. PropagatePreciseAttributeOnPointer(GEP, typeSys, Context, processedSet);
  131. // TODO: support more case which need
  132. }
  133. void DxilPrecisePropagatePass::PropagatePreciseOnFunctionUser(
  134. Function &F, DxilTypeSystem &typeSys,
  135. std::unordered_set<Instruction *> &processedSet) {
  136. LLVMContext &Context = F.getContext();
  137. for (auto U = F.user_begin(), E = F.user_end(); U != E;) {
  138. CallInst *CI = cast<CallInst>(*(U++));
  139. Value *V = CI->getArgOperand(0);
  140. PropagatePreciseAttributeOnOperand(V, typeSys, Context, processedSet);
  141. CI->eraseFromParent();
  142. }
  143. }
  144. ModulePass *llvm::createDxilPrecisePropagatePass() {
  145. return new DxilPrecisePropagatePass();
  146. }
  147. INITIALIZE_PASS(DxilPrecisePropagatePass, "hlsl-dxil-precise", "DXIL precise attribute propagate", false, false)