DxilConvergent.cpp 8.2 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilConvergent.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Mark convergent for hlsl. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/IR/BasicBlock.h"
  12. #include "llvm/IR/Dominators.h"
  13. #include "llvm/IR/Function.h"
  14. #include "llvm/IR/IRBuilder.h"
  15. #include "llvm/IR/Intrinsics.h"
  16. #include "llvm/IR/Module.h"
  17. #include "llvm/Support/GenericDomTree.h"
  18. #include "llvm/Support/raw_os_ostream.h"
  19. #include "dxc/DXIL/DxilConstants.h"
  20. #include "dxc/HLSL/DxilGenerationPass.h"
  21. #include "dxc/HLSL/HLOperations.h"
  22. #include "dxc/HLSL/HLModule.h"
  23. #include "dxc/HlslIntrinsicOp.h"
  24. #include "dxc/HLSL/DxilConvergentName.h"
  25. using namespace llvm;
  26. using namespace hlsl;
  27. ///////////////////////////////////////////////////////////////////////////////
  28. // DxilConvergent.
  29. // Mark convergent to prevent sample coordinate calculations from sinking into control flow.
  30. //
  31. namespace {
  32. class DxilConvergentMark : public ModulePass {
  33. public:
  34. static char ID; // Pass identification, replacement for typeid
  35. explicit DxilConvergentMark() : ModulePass(ID) {}
  36. const char *getPassName() const override {
  37. return "DxilConvergentMark";
  38. }
  39. bool runOnModule(Module &M) override {
  40. if (M.HasHLModule()) {
  41. const ShaderModel *SM = M.GetHLModule().GetShaderModel();
  42. if (!SM->IsPS() && !SM->IsLib() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS())))
  43. return false;
  44. }
  45. bool bUpdated = false;
  46. for (Function &F : M.functions()) {
  47. if (F.isDeclaration())
  48. continue;
  49. // Compute postdominator relation.
  50. DominatorTreeBase<BasicBlock> PDR(true);
  51. PDR.recalculate(F);
  52. for (BasicBlock &bb : F.getBasicBlockList()) {
  53. for (auto it = bb.begin(); it != bb.end();) {
  54. Instruction *I = (it++);
  55. if (Value *V = FindConvergentOperand(I)) {
  56. if (PropagateConvergent(V, &F, PDR)) {
  57. // TODO: emit warning here.
  58. }
  59. bUpdated = true;
  60. }
  61. }
  62. }
  63. }
  64. return bUpdated;
  65. }
  66. private:
  67. void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
  68. Value *FindConvergentOperand(Instruction *I);
  69. bool PropagateConvergent(Value *V, Function *F,
  70. DominatorTreeBase<BasicBlock> &PostDom);
  71. bool PropagateConvergentImpl(Value *V, Function *F,
  72. DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited);
  73. };
  74. char DxilConvergentMark::ID = 0;
  75. void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
  76. Module &M) {
  77. Type *Ty = V->getType()->getScalarType();
  78. // Only work on vector/scalar types.
  79. if (Ty->isAggregateType() ||
  80. Ty->isPointerTy())
  81. return;
  82. FunctionType *FT = FunctionType::get(Ty, Ty, false);
  83. std::string str = kConvergentFunctionPrefix;
  84. raw_string_ostream os(str);
  85. Ty->print(os);
  86. os.flush();
  87. Function *ConvF = cast<Function>(M.getOrInsertFunction(str, FT));
  88. ConvF->addFnAttr(Attribute::AttrKind::Convergent);
  89. if (VectorType *VT = dyn_cast<VectorType>(V->getType())) {
  90. Value *ConvV = UndefValue::get(V->getType());
  91. std::vector<ExtractElementInst *> extractList(VT->getNumElements());
  92. for (unsigned i = 0; i < VT->getNumElements(); i++) {
  93. ExtractElementInst *EltV =
  94. cast<ExtractElementInst>(Builder.CreateExtractElement(V, i));
  95. extractList[i] = EltV;
  96. Value *EltC = Builder.CreateCall(ConvF, {EltV});
  97. ConvV = Builder.CreateInsertElement(ConvV, EltC, i);
  98. }
  99. V->replaceAllUsesWith(ConvV);
  100. for (ExtractElementInst *E : extractList) {
  101. E->setOperand(0, V);
  102. }
  103. } else {
  104. CallInst *ConvV = Builder.CreateCall(ConvF, {V});
  105. V->replaceAllUsesWith(ConvV);
  106. ConvV->setOperand(0, V);
  107. }
  108. }
  109. bool DxilConvergentMark::PropagateConvergent(
  110. Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
  111. std::set<Value *> visited;
  112. return PropagateConvergentImpl(V, F, PostDom, visited);
  113. }
  114. bool DxilConvergentMark::PropagateConvergentImpl(Value *V, Function *F,
  115. DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited) {
  116. // Don't go through already visted nodes
  117. if (visited.find(V) != visited.end())
  118. return false;
  119. // Mark as visited
  120. visited.insert(V);
  121. // Skip constant.
  122. if (isa<Constant>(V))
  123. return false;
  124. // Skip phi which cannot sink.
  125. if (isa<PHINode>(V))
  126. return false;
  127. if (Instruction *I = dyn_cast<Instruction>(V)) {
  128. BasicBlock *BB = I->getParent();
  129. if (PostDom.dominates(BB, &F->getEntryBlock())) {
  130. IRBuilder<> Builder(I->getNextNode());
  131. MarkConvergent(I, Builder, *F->getParent());
  132. return false;
  133. } else {
  134. // Propagete to each operand of I.
  135. for (Use &U : I->operands()) {
  136. PropagateConvergentImpl(U.get(), F, PostDom, visited);
  137. }
  138. // return true for report warning.
  139. // TODO: static indexing cbuffer is fine.
  140. return true;
  141. }
  142. } else {
  143. IRBuilder<> EntryBuilder(F->getEntryBlock().getFirstInsertionPt());
  144. MarkConvergent(V, EntryBuilder, *F->getParent());
  145. return false;
  146. }
  147. }
  148. Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
  149. if (CallInst *CI = dyn_cast<CallInst>(I)) {
  150. if (hlsl::GetHLOpcodeGroup(CI->getCalledFunction()) ==
  151. HLOpcodeGroup::HLIntrinsic) {
  152. IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
  153. switch (IOP) {
  154. case IntrinsicOp::IOP_ddx:
  155. case IntrinsicOp::IOP_ddx_fine:
  156. case IntrinsicOp::IOP_ddx_coarse:
  157. case IntrinsicOp::IOP_ddy:
  158. case IntrinsicOp::IOP_ddy_fine:
  159. case IntrinsicOp::IOP_ddy_coarse:
  160. return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  161. case IntrinsicOp::MOP_Sample:
  162. case IntrinsicOp::MOP_SampleBias:
  163. case IntrinsicOp::MOP_SampleCmp:
  164. case IntrinsicOp::MOP_CalculateLevelOfDetail:
  165. case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
  166. return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
  167. case IntrinsicOp::MOP_WriteSamplerFeedback:
  168. case IntrinsicOp::MOP_WriteSamplerFeedbackBias:
  169. return CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackCoordArgIndex);
  170. default:
  171. // No other ops have convergent operands.
  172. break;
  173. }
  174. }
  175. }
  176. return nullptr;
  177. }
  178. } // namespace
  179. INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
  180. "Mark convergent", false, false)
  181. ModulePass *llvm::createDxilConvergentMarkPass() {
  182. return new DxilConvergentMark();
  183. }
  184. namespace {
  185. class DxilConvergentClear : public ModulePass {
  186. public:
  187. static char ID; // Pass identification, replacement for typeid
  188. explicit DxilConvergentClear() : ModulePass(ID) {}
  189. const char *getPassName() const override {
  190. return "DxilConvergentClear";
  191. }
  192. bool runOnModule(Module &M) override {
  193. std::vector<Function *> convergentList;
  194. for (Function &F : M.functions()) {
  195. if (F.getName().startswith(kConvergentFunctionPrefix)) {
  196. convergentList.emplace_back(&F);
  197. }
  198. }
  199. for (Function *F : convergentList) {
  200. ClearConvergent(F);
  201. }
  202. return convergentList.size();
  203. }
  204. private:
  205. void ClearConvergent(Function *F);
  206. };
  207. char DxilConvergentClear::ID = 0;
  208. void DxilConvergentClear::ClearConvergent(Function *F) {
  209. // Replace all users with arg.
  210. for (auto it = F->user_begin(); it != F->user_end();) {
  211. CallInst *CI = cast<CallInst>(*(it++));
  212. Value *arg = CI->getArgOperand(0);
  213. CI->replaceAllUsesWith(arg);
  214. CI->eraseFromParent();
  215. }
  216. F->eraseFromParent();
  217. }
  218. } // namespace
  219. INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
  220. "Clear convergent before dxil emit", false, false)
  221. ModulePass *llvm::createDxilConvergentClearPass() {
  222. return new DxilConvergentClear();
  223. }