DxilEliminateOutputDynamicIndexing.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilEliminateOutputDynamicIndexing.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Eliminate dynamic indexing on output. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/DXIL/DxilSignatureElement.h"
  14. #include "dxc/DXIL/DxilModule.h"
  15. #include "dxc/DXIL/DxilUtil.h"
  16. #include "dxc/Support/Global.h"
  17. #include "dxc/DXIL/DxilInstructions.h"
  18. #include "llvm/IR/Module.h"
  19. #include "llvm/Pass.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/ADT/MapVector.h"
  22. using namespace llvm;
  23. using namespace hlsl;
  24. namespace {
  25. class DxilEliminateOutputDynamicIndexing : public ModulePass {
  26. private:
  27. public:
  28. static char ID; // Pass identification, replacement for typeid
  29. explicit DxilEliminateOutputDynamicIndexing() : ModulePass(ID) {}
  30. const char *getPassName() const override {
  31. return "DXIL eliminate output dynamic indexing";
  32. }
  33. bool runOnModule(Module &M) override {
  34. DxilModule &DM = M.GetOrCreateDxilModule();
  35. bool bUpdated = false;
  36. if (DM.GetShaderModel()->IsHS()) {
  37. // HS write outputs into share memory, dynamic indexing is OK.
  38. return bUpdated;
  39. }
  40. // Skip pass thru entry.
  41. if (!DM.GetEntryFunction())
  42. return bUpdated;
  43. hlsl::OP *hlslOP = DM.GetOP();
  44. bUpdated |=
  45. EliminateDynamicOutput(hlslOP, DXIL::OpCode::StoreOutput,
  46. DM.GetOutputSignature(), DM.GetEntryFunction());
  47. return bUpdated;
  48. }
  49. private:
  50. bool EliminateDynamicOutput(hlsl::OP *hlslOP, DXIL::OpCode opcode, DxilSignature &outputSig, Function *Entry);
  51. void ReplaceDynamicOutput(ArrayRef<Value *> tmpSigElts, Value * sigID, Value *zero, Function *F);
  52. void StoreTmpSigToOutput(ArrayRef<Value *> tmpSigElts, unsigned row,
  53. Value *opcode, Value *sigID, Function *StoreOutput,
  54. Function *Entry);
  55. };
  56. // Wrapper for StoreOutput and StorePachConstant which has same signature.
  57. // void (opcode, sigId, rowIndex, colIndex, value);
  58. class DxilOutputStore {
  59. public:
  60. const llvm::CallInst *Instr;
  61. // Construction and identification
  62. DxilOutputStore(llvm::CallInst *pInstr) : Instr(pInstr) {}
  63. // Validation support
  64. bool isAllowed() const { return true; }
  65. bool isArgumentListValid() const {
  66. if (5 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
  67. return false;
  68. return true;
  69. }
  70. // Accessors
  71. llvm::Value *get_outputSigId() const {
  72. return Instr->getOperand(DXIL::OperandIndex::kStoreOutputIDOpIdx);
  73. }
  74. llvm::Value *get_rowIndex() const {
  75. return Instr->getOperand(DXIL::OperandIndex::kStoreOutputRowOpIdx);
  76. }
  77. uint64_t get_colIndex() const {
  78. Value *col = Instr->getOperand(DXIL::OperandIndex::kStoreOutputColOpIdx);
  79. return cast<ConstantInt>(col)->getLimitedValue();
  80. }
  81. llvm::Value *get_value() const {
  82. return Instr->getOperand(DXIL::OperandIndex::kStoreOutputValOpIdx);
  83. }
  84. };
  85. bool DxilEliminateOutputDynamicIndexing::EliminateDynamicOutput(
  86. hlsl::OP *hlslOP, DXIL::OpCode opcode, DxilSignature &outputSig,
  87. Function *Entry) {
  88. auto &storeOutputs =
  89. hlslOP->GetOpFuncList(opcode);
  90. MapVector<Value *, Type *> dynamicSigSet;
  91. for (auto it : storeOutputs) {
  92. Function *F = it.second;
  93. // Skip overload not used.
  94. if (!F)
  95. continue;
  96. for (User *U : F->users()) {
  97. CallInst *CI = cast<CallInst>(U);
  98. DxilOutputStore store(CI);
  99. // Save dynamic indeed sigID.
  100. if (!isa<ConstantInt>(store.get_rowIndex())) {
  101. Value *sigID = store.get_outputSigId();
  102. dynamicSigSet[sigID] = store.get_value()->getType();
  103. }
  104. }
  105. }
  106. if (dynamicSigSet.empty())
  107. return false;
  108. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Entry));
  109. Value *opcodeV = AllocaBuilder.getInt32(static_cast<unsigned>(opcode));
  110. Value *zero = AllocaBuilder.getInt32(0);
  111. for (auto sig : dynamicSigSet) {
  112. Value *sigID = sig.first;
  113. Type *EltTy = sig.second;
  114. unsigned ID = cast<ConstantInt>(sigID)->getLimitedValue();
  115. DxilSignatureElement &sigElt = outputSig.GetElement(ID);
  116. unsigned row = sigElt.GetRows();
  117. unsigned col = sigElt.GetCols();
  118. Type *AT = ArrayType::get(EltTy, row);
  119. std::vector<Value *> tmpSigElts(col);
  120. for (unsigned c = 0; c < col; c++) {
  121. Value *newCol = AllocaBuilder.CreateAlloca(AT);
  122. tmpSigElts[c] = newCol;
  123. }
  124. Function *F = hlslOP->GetOpFunc(opcode, EltTy);
  125. // Change store output to store tmpSigElts.
  126. ReplaceDynamicOutput(tmpSigElts, sigID, zero, F);
  127. // Store tmpSigElts to Output before return.
  128. StoreTmpSigToOutput(tmpSigElts, row, opcodeV, sigID, F, Entry);
  129. }
  130. return true;
  131. }
  132. void DxilEliminateOutputDynamicIndexing::ReplaceDynamicOutput(
  133. ArrayRef<Value *> tmpSigElts, Value *sigID, Value *zero, Function *F) {
  134. for (auto it = F->user_begin(); it != F->user_end();) {
  135. CallInst *CI = cast<CallInst>(*(it++));
  136. DxilOutputStore store(CI);
  137. if (sigID == store.get_outputSigId()) {
  138. uint64_t col = store.get_colIndex();
  139. Value *tmpSigElt = tmpSigElts[col];
  140. IRBuilder<> Builder(CI);
  141. Value *r = store.get_rowIndex();
  142. // Store to tmpSigElt.
  143. Value *GEP = Builder.CreateInBoundsGEP(tmpSigElt, {zero, r});
  144. Builder.CreateStore(store.get_value(), GEP);
  145. // Remove store output.
  146. CI->eraseFromParent();
  147. }
  148. }
  149. }
  150. void DxilEliminateOutputDynamicIndexing::StoreTmpSigToOutput(
  151. ArrayRef<Value *> tmpSigElts, unsigned row, Value *opcode, Value *sigID,
  152. Function *StoreOutput, Function *Entry) {
  153. Value *args[] = {opcode, sigID, /*row*/ nullptr, /*col*/ nullptr,
  154. /*val*/ nullptr};
  155. // Store the tmpSigElts to Output before every return.
  156. for (auto &BB : Entry->getBasicBlockList()) {
  157. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  158. IRBuilder<> Builder(RI);
  159. Value *zero = Builder.getInt32(0);
  160. for (unsigned c = 0; c<tmpSigElts.size(); c++) {
  161. Value *col = tmpSigElts[c];
  162. args[DXIL::OperandIndex::kStoreOutputColOpIdx] = Builder.getInt8(c);
  163. for (unsigned r = 0; r < row; r++) {
  164. Value *GEP =
  165. Builder.CreateInBoundsGEP(col, {zero, Builder.getInt32(r)});
  166. Value *V = Builder.CreateLoad(GEP);
  167. args[DXIL::OperandIndex::kStoreOutputRowOpIdx] = Builder.getInt32(r);
  168. args[DXIL::OperandIndex::kStoreOutputValOpIdx] = V;
  169. Builder.CreateCall(StoreOutput, args);
  170. }
  171. }
  172. }
  173. }
  174. }
  175. }
  176. char DxilEliminateOutputDynamicIndexing::ID = 0;
  177. ModulePass *llvm::createDxilEliminateOutputDynamicIndexingPass() {
  178. return new DxilEliminateOutputDynamicIndexing();
  179. }
  180. INITIALIZE_PASS(DxilEliminateOutputDynamicIndexing,
  181. "hlsl-dxil-eliminate-output-dynamic",
  182. "DXIL eliminate output dynamic indexing", false, false)