DxilConvergent.cpp 8.4 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilConvergent.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Mark convergent for hlsl. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/IR/BasicBlock.h"
  12. #include "llvm/IR/Dominators.h"
  13. #include "llvm/IR/Function.h"
  14. #include "llvm/IR/IRBuilder.h"
  15. #include "llvm/IR/Intrinsics.h"
  16. #include "llvm/IR/Module.h"
  17. #include "llvm/Support/GenericDomTree.h"
  18. #include "llvm/Support/raw_os_ostream.h"
  19. #include "dxc/DXIL/DxilConstants.h"
  20. #include "dxc/HLSL/DxilGenerationPass.h"
  21. #include "dxc/HLSL/HLOperations.h"
  22. #include "dxc/HLSL/HLModule.h"
  23. #include "dxc/HLSL/DxilConvergent.h"
  24. #include "dxc/HlslIntrinsicOp.h"
  25. #include "dxc/HLSL/DxilConvergentName.h"
  26. using namespace llvm;
  27. using namespace hlsl;
  28. bool hlsl::IsConvergentMarker(Value *V) {
  29. CallInst *CI = dyn_cast<CallInst>(V);
  30. if (!CI)
  31. return false;
  32. Function *F = CI->getCalledFunction();
  33. return F->getName().startswith(kConvergentFunctionPrefix);
  34. }
  35. Value *hlsl::GetConvergentSource(Value *V) {
  36. return cast<CallInst>(V)->getOperand(0);
  37. }
  38. ///////////////////////////////////////////////////////////////////////////////
  39. // DxilConvergent.
  40. // Mark convergent to avoid sample coordinate calculation sinking into control flow.
  41. //
  42. namespace {
  43. class DxilConvergentMark : public ModulePass {
  44. public:
  45. static char ID; // Pass identification, replacement for typeid
  46. explicit DxilConvergentMark() : ModulePass(ID) {}
  47. const char *getPassName() const override {
  48. return "DxilConvergentMark";
  49. }
  50. bool runOnModule(Module &M) override {
  51. if (M.HasHLModule()) {
  52. if (!M.GetHLModule().GetShaderModel()->IsPS())
  53. return false;
  54. }
  55. bool bUpdated = false;
  56. for (Function &F : M.functions()) {
  57. if (F.isDeclaration())
  58. continue;
  59. // Compute postdominator relation.
  60. DominatorTreeBase<BasicBlock> PDR(true);
  61. PDR.recalculate(F);
  62. for (BasicBlock &bb : F.getBasicBlockList()) {
  63. for (auto it = bb.begin(); it != bb.end();) {
  64. Instruction *I = (it++);
  65. if (Value *V = FindConvergentOperand(I)) {
  66. if (PropagateConvergent(V, &F, PDR)) {
  67. // TODO: emit warning here.
  68. }
  69. bUpdated = true;
  70. }
  71. }
  72. }
  73. }
  74. return bUpdated;
  75. }
  76. private:
  77. void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
  78. Value *FindConvergentOperand(Instruction *I);
  79. bool PropagateConvergent(Value *V, Function *F,
  80. DominatorTreeBase<BasicBlock> &PostDom);
  81. bool PropagateConvergentImpl(Value *V, Function *F,
  82. DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited);
  83. };
  84. char DxilConvergentMark::ID = 0;
  85. void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
  86. Module &M) {
  87. Type *Ty = V->getType()->getScalarType();
  88. // Only work on vector/scalar types.
  89. if (Ty->isAggregateType() ||
  90. Ty->isPointerTy())
  91. return;
  92. FunctionType *FT = FunctionType::get(Ty, Ty, false);
  93. std::string str = kConvergentFunctionPrefix;
  94. raw_string_ostream os(str);
  95. Ty->print(os);
  96. os.flush();
  97. Function *ConvF = cast<Function>(M.getOrInsertFunction(str, FT));
  98. ConvF->addFnAttr(Attribute::AttrKind::Convergent);
  99. if (VectorType *VT = dyn_cast<VectorType>(V->getType())) {
  100. Value *ConvV = UndefValue::get(V->getType());
  101. std::vector<ExtractElementInst *> extractList(VT->getNumElements());
  102. for (unsigned i = 0; i < VT->getNumElements(); i++) {
  103. ExtractElementInst *EltV =
  104. cast<ExtractElementInst>(Builder.CreateExtractElement(V, i));
  105. extractList[i] = EltV;
  106. Value *EltC = Builder.CreateCall(ConvF, {EltV});
  107. ConvV = Builder.CreateInsertElement(ConvV, EltC, i);
  108. }
  109. V->replaceAllUsesWith(ConvV);
  110. for (ExtractElementInst *E : extractList) {
  111. E->setOperand(0, V);
  112. }
  113. } else {
  114. CallInst *ConvV = Builder.CreateCall(ConvF, {V});
  115. V->replaceAllUsesWith(ConvV);
  116. ConvV->setOperand(0, V);
  117. }
  118. }
  119. bool DxilConvergentMark::PropagateConvergent(
  120. Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
  121. std::set<Value *> visited;
  122. return PropagateConvergentImpl(V, F, PostDom, visited);
  123. }
  124. bool DxilConvergentMark::PropagateConvergentImpl(Value *V, Function *F,
  125. DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited) {
  126. // Don't go through already visted nodes
  127. if (visited.find(V) != visited.end())
  128. return false;
  129. // Mark as visited
  130. visited.insert(V);
  131. // Skip constant.
  132. if (isa<Constant>(V))
  133. return false;
  134. // Skip phi which cannot sink.
  135. if (isa<PHINode>(V))
  136. return false;
  137. if (Instruction *I = dyn_cast<Instruction>(V)) {
  138. BasicBlock *BB = I->getParent();
  139. if (PostDom.dominates(BB, &F->getEntryBlock())) {
  140. IRBuilder<> Builder(I->getNextNode());
  141. MarkConvergent(I, Builder, *F->getParent());
  142. return false;
  143. } else {
  144. // Propagete to each operand of I.
  145. for (Use &U : I->operands()) {
  146. PropagateConvergentImpl(U.get(), F, PostDom, visited);
  147. }
  148. // return true for report warning.
  149. // TODO: static indexing cbuffer is fine.
  150. return true;
  151. }
  152. } else {
  153. IRBuilder<> EntryBuilder(F->getEntryBlock().getFirstInsertionPt());
  154. MarkConvergent(V, EntryBuilder, *F->getParent());
  155. return false;
  156. }
  157. }
  158. Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
  159. if (CallInst *CI = dyn_cast<CallInst>(I)) {
  160. if (hlsl::GetHLOpcodeGroup(CI->getCalledFunction()) ==
  161. HLOpcodeGroup::HLIntrinsic) {
  162. IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
  163. switch (IOP) {
  164. case IntrinsicOp::IOP_ddx:
  165. case IntrinsicOp::IOP_ddx_fine:
  166. case IntrinsicOp::IOP_ddx_coarse:
  167. case IntrinsicOp::IOP_ddy:
  168. case IntrinsicOp::IOP_ddy_fine:
  169. case IntrinsicOp::IOP_ddy_coarse:
  170. return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  171. case IntrinsicOp::MOP_Sample:
  172. case IntrinsicOp::MOP_SampleBias:
  173. case IntrinsicOp::MOP_SampleCmp:
  174. case IntrinsicOp::MOP_CalculateLevelOfDetail:
  175. case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
  176. return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
  177. case IntrinsicOp::MOP_WriteSamplerFeedback:
  178. case IntrinsicOp::MOP_WriteSamplerFeedbackBias:
  179. return CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackCoordArgIndex);
  180. default:
  181. // No other ops have convergent operands.
  182. break;
  183. }
  184. }
  185. }
  186. return nullptr;
  187. }
  188. } // namespace
  189. INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
  190. "Mark convergent", false, false)
  191. ModulePass *llvm::createDxilConvergentMarkPass() {
  192. return new DxilConvergentMark();
  193. }
  194. namespace {
  195. class DxilConvergentClear : public ModulePass {
  196. public:
  197. static char ID; // Pass identification, replacement for typeid
  198. explicit DxilConvergentClear() : ModulePass(ID) {}
  199. const char *getPassName() const override {
  200. return "DxilConvergentClear";
  201. }
  202. bool runOnModule(Module &M) override {
  203. std::vector<Function *> convergentList;
  204. for (Function &F : M.functions()) {
  205. if (F.getName().startswith(kConvergentFunctionPrefix)) {
  206. convergentList.emplace_back(&F);
  207. }
  208. }
  209. for (Function *F : convergentList) {
  210. ClearConvergent(F);
  211. }
  212. return convergentList.size();
  213. }
  214. private:
  215. void ClearConvergent(Function *F);
  216. };
  217. char DxilConvergentClear::ID = 0;
  218. void DxilConvergentClear::ClearConvergent(Function *F) {
  219. // Replace all users with arg.
  220. for (auto it = F->user_begin(); it != F->user_end();) {
  221. CallInst *CI = cast<CallInst>(*(it++));
  222. Value *arg = CI->getArgOperand(0);
  223. CI->replaceAllUsesWith(arg);
  224. CI->eraseFromParent();
  225. }
  226. F->eraseFromParent();
  227. }
  228. } // namespace
  229. INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
  230. "Clear convergent before dxil emit", false, false)
  231. ModulePass *llvm::createDxilConvergentClearPass() {
  232. return new DxilConvergentClear();
  233. }