DxilConvergent.cpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilConvergent.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Mark convergent for hlsl. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/IR/BasicBlock.h"
  12. #include "llvm/IR/Dominators.h"
  13. #include "llvm/IR/Function.h"
  14. #include "llvm/IR/IRBuilder.h"
  15. #include "llvm/IR/Intrinsics.h"
  16. #include "llvm/IR/Module.h"
  17. #include "llvm/Support/GenericDomTree.h"
  18. #include "llvm/Support/raw_os_ostream.h"
  19. #include "dxc/DXIL/DxilConstants.h"
  20. #include "dxc/HLSL/DxilGenerationPass.h"
  21. #include "dxc/HLSL/HLOperations.h"
  22. #include "dxc/HLSL/HLModule.h"
  23. #include "dxc/HlslIntrinsicOp.h"
  24. using namespace llvm;
  25. using namespace hlsl;
  26. namespace {
  27. const StringRef kConvergentFunctionPrefix = "dxil.convergent.marker.";
  28. }
  29. ///////////////////////////////////////////////////////////////////////////////
  30. // DxilConvergent.
  31. // Mark convergent to prevent sample coordinate calculations from sinking into control flow.
  32. //
  33. namespace {
  34. class DxilConvergentMark : public ModulePass {
  35. public:
  36. static char ID; // Pass identification, replacement for typeid
  37. explicit DxilConvergentMark() : ModulePass(ID) {}
  38. const char *getPassName() const override {
  39. return "DxilConvergentMark";
  40. }
  41. bool runOnModule(Module &M) override {
  42. if (M.HasHLModule()) {
  43. if (!M.GetHLModule().GetShaderModel()->IsPS())
  44. return false;
  45. }
  46. bool bUpdated = false;
  47. for (Function &F : M.functions()) {
  48. if (F.isDeclaration())
  49. continue;
  50. // Compute postdominator relation.
  51. DominatorTreeBase<BasicBlock> PDR(true);
  52. PDR.recalculate(F);
  53. for (BasicBlock &bb : F.getBasicBlockList()) {
  54. for (auto it = bb.begin(); it != bb.end();) {
  55. Instruction *I = (it++);
  56. if (Value *V = FindConvergentOperand(I)) {
  57. if (PropagateConvergent(V, &F, PDR)) {
  58. // TODO: emit warning here.
  59. }
  60. bUpdated = true;
  61. }
  62. }
  63. }
  64. }
  65. return bUpdated;
  66. }
  67. private:
  68. void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
  69. Value *FindConvergentOperand(Instruction *I);
  70. bool PropagateConvergent(Value *V, Function *F,
  71. DominatorTreeBase<BasicBlock> &PostDom);
  72. bool PropagateConvergentImpl(Value *V, Function *F,
  73. DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited);
  74. };
  75. char DxilConvergentMark::ID = 0;
  76. void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
  77. Module &M) {
  78. Type *Ty = V->getType()->getScalarType();
  79. // Only work on vector/scalar types.
  80. if (Ty->isAggregateType() ||
  81. Ty->isPointerTy())
  82. return;
  83. FunctionType *FT = FunctionType::get(Ty, Ty, false);
  84. std::string str = kConvergentFunctionPrefix;
  85. raw_string_ostream os(str);
  86. Ty->print(os);
  87. os.flush();
  88. Function *ConvF = cast<Function>(M.getOrInsertFunction(str, FT));
  89. ConvF->addFnAttr(Attribute::AttrKind::Convergent);
  90. if (VectorType *VT = dyn_cast<VectorType>(V->getType())) {
  91. Value *ConvV = UndefValue::get(V->getType());
  92. std::vector<ExtractElementInst *> extractList(VT->getNumElements());
  93. for (unsigned i = 0; i < VT->getNumElements(); i++) {
  94. ExtractElementInst *EltV =
  95. cast<ExtractElementInst>(Builder.CreateExtractElement(V, i));
  96. extractList[i] = EltV;
  97. Value *EltC = Builder.CreateCall(ConvF, {EltV});
  98. ConvV = Builder.CreateInsertElement(ConvV, EltC, i);
  99. }
  100. V->replaceAllUsesWith(ConvV);
  101. for (ExtractElementInst *E : extractList) {
  102. E->setOperand(0, V);
  103. }
  104. } else {
  105. CallInst *ConvV = Builder.CreateCall(ConvF, {V});
  106. V->replaceAllUsesWith(ConvV);
  107. ConvV->setOperand(0, V);
  108. }
  109. }
  110. bool DxilConvergentMark::PropagateConvergent(
  111. Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
  112. std::set<Value *> visited;
  113. return PropagateConvergentImpl(V, F, PostDom, visited);
  114. }
  115. bool DxilConvergentMark::PropagateConvergentImpl(Value *V, Function *F,
  116. DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited) {
  117. // Don't go through already visted nodes
  118. if (visited.find(V) != visited.end())
  119. return false;
  120. // Mark as visited
  121. visited.insert(V);
  122. // Skip constant.
  123. if (isa<Constant>(V))
  124. return false;
  125. // Skip phi which cannot sink.
  126. if (isa<PHINode>(V))
  127. return false;
  128. if (Instruction *I = dyn_cast<Instruction>(V)) {
  129. BasicBlock *BB = I->getParent();
  130. if (PostDom.dominates(BB, &F->getEntryBlock())) {
  131. IRBuilder<> Builder(I->getNextNode());
  132. MarkConvergent(I, Builder, *F->getParent());
  133. return false;
  134. } else {
  135. // Propagete to each operand of I.
  136. for (Use &U : I->operands()) {
  137. PropagateConvergentImpl(U.get(), F, PostDom, visited);
  138. }
  139. // return true for report warning.
  140. // TODO: static indexing cbuffer is fine.
  141. return true;
  142. }
  143. } else {
  144. IRBuilder<> EntryBuilder(F->getEntryBlock().getFirstInsertionPt());
  145. MarkConvergent(V, EntryBuilder, *F->getParent());
  146. return false;
  147. }
  148. }
  149. Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
  150. if (CallInst *CI = dyn_cast<CallInst>(I)) {
  151. if (hlsl::GetHLOpcodeGroup(CI->getCalledFunction()) ==
  152. HLOpcodeGroup::HLIntrinsic) {
  153. IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
  154. switch (IOP) {
  155. case IntrinsicOp::IOP_ddx:
  156. case IntrinsicOp::IOP_ddx_fine:
  157. case IntrinsicOp::IOP_ddx_coarse:
  158. case IntrinsicOp::IOP_ddy:
  159. case IntrinsicOp::IOP_ddy_fine:
  160. case IntrinsicOp::IOP_ddy_coarse:
  161. return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  162. case IntrinsicOp::MOP_Sample:
  163. case IntrinsicOp::MOP_SampleBias:
  164. case IntrinsicOp::MOP_SampleCmp:
  165. case IntrinsicOp::MOP_SampleCmpLevelZero:
  166. case IntrinsicOp::MOP_CalculateLevelOfDetail:
  167. case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
  168. return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
  169. case IntrinsicOp::MOP_Gather:
  170. case IntrinsicOp::MOP_GatherAlpha:
  171. case IntrinsicOp::MOP_GatherBlue:
  172. case IntrinsicOp::MOP_GatherCmp:
  173. case IntrinsicOp::MOP_GatherCmpAlpha:
  174. case IntrinsicOp::MOP_GatherCmpBlue:
  175. case IntrinsicOp::MOP_GatherCmpGreen:
  176. case IntrinsicOp::MOP_GatherCmpRed:
  177. case IntrinsicOp::MOP_GatherGreen:
  178. case IntrinsicOp::MOP_GatherRed:
  179. return CI->getArgOperand(HLOperandIndex::kGatherCoordArgIndex);
  180. default:
  181. // No other ops have convergent operands.
  182. break;
  183. }
  184. }
  185. }
  186. return nullptr;
  187. }
  188. } // namespace
  189. INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
  190. "Mark convergent", false, false)
  191. ModulePass *llvm::createDxilConvergentMarkPass() {
  192. return new DxilConvergentMark();
  193. }
  194. namespace {
  195. class DxilConvergentClear : public ModulePass {
  196. public:
  197. static char ID; // Pass identification, replacement for typeid
  198. explicit DxilConvergentClear() : ModulePass(ID) {}
  199. const char *getPassName() const override {
  200. return "DxilConvergentClear";
  201. }
  202. bool runOnModule(Module &M) override {
  203. std::vector<Function *> convergentList;
  204. for (Function &F : M.functions()) {
  205. if (F.getName().startswith(kConvergentFunctionPrefix)) {
  206. convergentList.emplace_back(&F);
  207. }
  208. }
  209. for (Function *F : convergentList) {
  210. ClearConvergent(F);
  211. }
  212. return convergentList.size();
  213. }
  214. private:
  215. void ClearConvergent(Function *F);
  216. };
  217. char DxilConvergentClear::ID = 0;
  218. void DxilConvergentClear::ClearConvergent(Function *F) {
  219. // Replace all users with arg.
  220. for (auto it = F->user_begin(); it != F->user_end();) {
  221. CallInst *CI = cast<CallInst>(*(it++));
  222. Value *arg = CI->getArgOperand(0);
  223. CI->replaceAllUsesWith(arg);
  224. CI->eraseFromParent();
  225. }
  226. F->eraseFromParent();
  227. }
  228. } // namespace
  229. INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
  230. "Clear convergent before dxil emit", false, false)
  231. ModulePass *llvm::createDxilConvergentClearPass() {
  232. return new DxilConvergentClear();
  233. }