DxilConvergent.cpp 8.0 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilConvergent.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Mark convergent for hlsl. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/GenericDomTree.h"
#include "llvm/Support/raw_os_ostream.h"
#include "dxc/HLSL/DxilConstants.h"
#include "dxc/HLSL/DxilGenerationPass.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HlslIntrinsicOp.h"
  24. using namespace llvm;
  25. using namespace hlsl;
  26. namespace {
  27. const StringRef kConvergentFunctionPrefix = "dxil.convergent.marker.";
  28. }
  29. ///////////////////////////////////////////////////////////////////////////////
  30. // DxilConvergent.
  31. // Mark convergent to avoid sample coordinate calculation sinking into control flow.
  32. //
  33. namespace {
  34. class DxilConvergentMark : public ModulePass {
  35. public:
  36. static char ID; // Pass identification, replacement for typeid
  37. explicit DxilConvergentMark() : ModulePass(ID) {}
  38. const char *getPassName() const override {
  39. return "DxilConvergentMark";
  40. }
  41. bool runOnModule(Module &M) override {
  42. if (M.HasHLModule()) {
  43. if (!M.GetHLModule().GetShaderModel()->IsPS())
  44. return false;
  45. }
  46. bool bUpdated = false;
  47. for (Function &F : M.functions()) {
  48. if (F.isDeclaration())
  49. continue;
  50. // Compute postdominator relation.
  51. DominatorTreeBase<BasicBlock> PDR(true);
  52. PDR.recalculate(F);
  53. for (BasicBlock &bb : F.getBasicBlockList()) {
  54. for (auto it = bb.begin(); it != bb.end();) {
  55. Instruction *I = (it++);
  56. if (Value *V = FindConvergentOperand(I)) {
  57. if (PropagateConvergent(V, &F, PDR)) {
  58. // TODO: emit warning here.
  59. }
  60. bUpdated = true;
  61. }
  62. }
  63. }
  64. }
  65. return bUpdated;
  66. }
  67. private:
  68. void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
  69. Value *FindConvergentOperand(Instruction *I);
  70. bool PropagateConvergent(Value *V, Function *F,
  71. DominatorTreeBase<BasicBlock> &PostDom);
  72. };
  73. char DxilConvergentMark::ID = 0;
  74. void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
  75. Module &M) {
  76. Type *Ty = V->getType()->getScalarType();
  77. // Only work on vector/scalar types.
  78. if (Ty->isAggregateType() ||
  79. Ty->isPointerTy())
  80. return;
  81. FunctionType *FT = FunctionType::get(Ty, Ty, false);
  82. std::string str = kConvergentFunctionPrefix;
  83. raw_string_ostream os(str);
  84. Ty->print(os);
  85. os.flush();
  86. Function *ConvF = cast<Function>(M.getOrInsertFunction(str, FT));
  87. ConvF->addFnAttr(Attribute::AttrKind::Convergent);
  88. if (VectorType *VT = dyn_cast<VectorType>(V->getType())) {
  89. Value *ConvV = UndefValue::get(V->getType());
  90. std::vector<ExtractElementInst *> extractList(VT->getNumElements());
  91. for (unsigned i = 0; i < VT->getNumElements(); i++) {
  92. ExtractElementInst *EltV =
  93. cast<ExtractElementInst>(Builder.CreateExtractElement(V, i));
  94. extractList[i] = EltV;
  95. Value *EltC = Builder.CreateCall(ConvF, {EltV});
  96. ConvV = Builder.CreateInsertElement(ConvV, EltC, i);
  97. }
  98. V->replaceAllUsesWith(ConvV);
  99. for (ExtractElementInst *E : extractList) {
  100. E->setOperand(0, V);
  101. }
  102. } else {
  103. CallInst *ConvV = Builder.CreateCall(ConvF, {V});
  104. V->replaceAllUsesWith(ConvV);
  105. ConvV->setOperand(0, V);
  106. }
  107. }
  108. bool DxilConvergentMark::PropagateConvergent(
  109. Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
  110. // Skip constant.
  111. if (isa<Constant>(V))
  112. return false;
  113. // Skip phi which cannot sink.
  114. if (isa<PHINode>(V))
  115. return false;
  116. if (Instruction *I = dyn_cast<Instruction>(V)) {
  117. BasicBlock *BB = I->getParent();
  118. if (PostDom.dominates(BB, &F->getEntryBlock())) {
  119. IRBuilder<> Builder(I->getNextNode());
  120. MarkConvergent(I, Builder, *F->getParent());
  121. return false;
  122. } else {
  123. // Propagete to each operand of I.
  124. for (Use &U : I->operands()) {
  125. PropagateConvergent(U.get(), F, PostDom);
  126. }
  127. // return true for report warning.
  128. // TODO: static indexing cbuffer is fine.
  129. return true;
  130. }
  131. } else {
  132. IRBuilder<> EntryBuilder(F->getEntryBlock().getFirstInsertionPt());
  133. MarkConvergent(V, EntryBuilder, *F->getParent());
  134. return false;
  135. }
  136. }
  137. Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
  138. if (CallInst *CI = dyn_cast<CallInst>(I)) {
  139. if (hlsl::GetHLOpcodeGroup(CI->getCalledFunction()) ==
  140. HLOpcodeGroup::HLIntrinsic) {
  141. IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
  142. switch (IOP) {
  143. case IntrinsicOp::IOP_ddx:
  144. case IntrinsicOp::IOP_ddx_fine:
  145. case IntrinsicOp::IOP_ddx_coarse:
  146. case IntrinsicOp::IOP_ddy:
  147. case IntrinsicOp::IOP_ddy_fine:
  148. case IntrinsicOp::IOP_ddy_coarse:
  149. return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  150. case IntrinsicOp::MOP_Sample:
  151. case IntrinsicOp::MOP_SampleBias:
  152. case IntrinsicOp::MOP_SampleCmp:
  153. case IntrinsicOp::MOP_SampleCmpLevelZero:
  154. case IntrinsicOp::MOP_CalculateLevelOfDetail:
  155. case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
  156. return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
  157. case IntrinsicOp::MOP_Gather:
  158. case IntrinsicOp::MOP_GatherAlpha:
  159. case IntrinsicOp::MOP_GatherBlue:
  160. case IntrinsicOp::MOP_GatherCmp:
  161. case IntrinsicOp::MOP_GatherCmpAlpha:
  162. case IntrinsicOp::MOP_GatherCmpBlue:
  163. case IntrinsicOp::MOP_GatherCmpGreen:
  164. case IntrinsicOp::MOP_GatherCmpRed:
  165. case IntrinsicOp::MOP_GatherGreen:
  166. case IntrinsicOp::MOP_GatherRed:
  167. return CI->getArgOperand(HLOperandIndex::kGatherCoordArgIndex);
  168. default:
  169. // No other ops have convergent operands.
  170. break;
  171. }
  172. }
  173. }
  174. return nullptr;
  175. }
  176. } // namespace
  177. INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
  178. "Mark convergent", false, false)
  179. ModulePass *llvm::createDxilConvergentMarkPass() {
  180. return new DxilConvergentMark();
  181. }
  182. namespace {
  183. class DxilConvergentClear : public ModulePass {
  184. public:
  185. static char ID; // Pass identification, replacement for typeid
  186. explicit DxilConvergentClear() : ModulePass(ID) {}
  187. const char *getPassName() const override {
  188. return "DxilConvergentClear";
  189. }
  190. bool runOnModule(Module &M) override {
  191. std::vector<Function *> convergentList;
  192. for (Function &F : M.functions()) {
  193. if (F.getName().startswith(kConvergentFunctionPrefix)) {
  194. convergentList.emplace_back(&F);
  195. }
  196. }
  197. for (Function *F : convergentList) {
  198. ClearConvergent(F);
  199. }
  200. return convergentList.size();
  201. }
  202. private:
  203. void ClearConvergent(Function *F);
  204. };
  205. char DxilConvergentClear::ID = 0;
  206. void DxilConvergentClear::ClearConvergent(Function *F) {
  207. // Replace all users with arg.
  208. for (auto it = F->user_begin(); it != F->user_end();) {
  209. CallInst *CI = cast<CallInst>(*(it++));
  210. Value *arg = CI->getArgOperand(0);
  211. CI->replaceAllUsesWith(arg);
  212. CI->eraseFromParent();
  213. }
  214. F->eraseFromParent();
  215. }
  216. } // namespace
  217. INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
  218. "Clear convergent before dxil emit", false, false)
  219. ModulePass *llvm::createDxilConvergentClearPass() {
  220. return new DxilConvergentClear();
  221. }