DxilEliminateOutputDynamicIndexing.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilEliminateOutputDynamicIndexing.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Eliminate dynamic indexing on output. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/HLSL/DxilOperations.h"
  13. #include "dxc/HLSL/DxilSignatureElement.h"
  14. #include "dxc/HLSL/DxilModule.h"
  15. #include "dxc/HLSL/DxilUtil.h"
  16. #include "dxc/Support/Global.h"
  17. #include "dxc/HLSL/DxilInstructions.h"
  18. #include "llvm/IR/Module.h"
  19. #include "llvm/Pass.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/ADT/MapVector.h"
  22. using namespace llvm;
  23. using namespace hlsl;
  24. namespace {
  25. class DxilEliminateOutputDynamicIndexing : public ModulePass {
  26. private:
  27. public:
  28. static char ID; // Pass identification, replacement for typeid
  29. explicit DxilEliminateOutputDynamicIndexing() : ModulePass(ID) {}
  30. const char *getPassName() const override {
  31. return "DXIL eliminate ouptut dynamic indexing";
  32. }
  33. bool runOnModule(Module &M) override {
  34. DxilModule &DM = M.GetOrCreateDxilModule();
  35. bool bUpdated = false;
  36. if (DM.GetShaderModel()->IsHS()) {
  37. hlsl::OP *hlslOP = DM.GetOP();
  38. bUpdated = EliminateDynamicOutput(
  39. hlslOP, DXIL::OpCode::StorePatchConstant,
  40. DM.GetPatchConstantSignature(), DM.GetPatchConstantFunction());
  41. }
  42. // Skip pass thru entry.
  43. if (!DM.GetEntryFunction())
  44. return bUpdated;
  45. hlsl::OP *hlslOP = DM.GetOP();
  46. bUpdated |=
  47. EliminateDynamicOutput(hlslOP, DXIL::OpCode::StoreOutput,
  48. DM.GetOutputSignature(), DM.GetEntryFunction());
  49. return bUpdated;
  50. }
  51. private:
  52. bool EliminateDynamicOutput(hlsl::OP *hlslOP, DXIL::OpCode opcode, DxilSignature &outputSig, Function *Entry);
  53. void ReplaceDynamicOutput(ArrayRef<Value *> tmpSigElts, Value * sigID, Value *zero, Function *F);
  54. void StoreTmpSigToOutput(ArrayRef<Value *> tmpSigElts, unsigned row,
  55. Value *opcode, Value *sigID, Function *StoreOutput,
  56. Function *Entry);
  57. };
  58. // Wrapper for StoreOutput and StorePachConstant which has same signature.
  59. // void (opcode, sigId, rowIndex, colIndex, value);
  60. class DxilOutputStore {
  61. public:
  62. const llvm::CallInst *Instr;
  63. // Construction and identification
  64. DxilOutputStore(llvm::CallInst *pInstr) : Instr(pInstr) {}
  65. // Validation support
  66. bool isAllowed() const { return true; }
  67. bool isArgumentListValid() const {
  68. if (5 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
  69. return false;
  70. return true;
  71. }
  72. // Accessors
  73. llvm::Value *get_outputSigId() const {
  74. return Instr->getOperand(DXIL::OperandIndex::kStoreOutputIDOpIdx);
  75. }
  76. llvm::Value *get_rowIndex() const {
  77. return Instr->getOperand(DXIL::OperandIndex::kStoreOutputRowOpIdx);
  78. }
  79. uint64_t get_colIndex() const {
  80. Value *col = Instr->getOperand(DXIL::OperandIndex::kStoreOutputColOpIdx);
  81. return cast<ConstantInt>(col)->getLimitedValue();
  82. }
  83. llvm::Value *get_value() const {
  84. return Instr->getOperand(DXIL::OperandIndex::kStoreOutputValOpIdx);
  85. }
  86. };
  87. bool DxilEliminateOutputDynamicIndexing::EliminateDynamicOutput(
  88. hlsl::OP *hlslOP, DXIL::OpCode opcode, DxilSignature &outputSig,
  89. Function *Entry) {
  90. auto &storeOutputs =
  91. hlslOP->GetOpFuncList(opcode);
  92. MapVector<Value *, Type *> dynamicSigSet;
  93. for (auto it : storeOutputs) {
  94. Function *F = it.second;
  95. // Skip overload not used.
  96. if (!F)
  97. continue;
  98. for (User *U : F->users()) {
  99. CallInst *CI = cast<CallInst>(U);
  100. DxilOutputStore store(CI);
  101. // Save dynamic indeed sigID.
  102. if (!isa<ConstantInt>(store.get_rowIndex())) {
  103. Value *sigID = store.get_outputSigId();
  104. dynamicSigSet[sigID] = store.get_value()->getType();
  105. }
  106. }
  107. }
  108. if (dynamicSigSet.empty())
  109. return false;
  110. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Entry));
  111. Value *opcodeV = AllocaBuilder.getInt32(static_cast<unsigned>(opcode));
  112. Value *zero = AllocaBuilder.getInt32(0);
  113. for (auto sig : dynamicSigSet) {
  114. Value *sigID = sig.first;
  115. Type *EltTy = sig.second;
  116. unsigned ID = cast<ConstantInt>(sigID)->getLimitedValue();
  117. DxilSignatureElement &sigElt = outputSig.GetElement(ID);
  118. unsigned row = sigElt.GetRows();
  119. unsigned col = sigElt.GetCols();
  120. Type *AT = ArrayType::get(EltTy, row);
  121. std::vector<Value *> tmpSigElts(col);
  122. for (unsigned c = 0; c < col; c++) {
  123. Value *newCol = AllocaBuilder.CreateAlloca(AT);
  124. tmpSigElts[c] = newCol;
  125. }
  126. Function *F = hlslOP->GetOpFunc(opcode, EltTy);
  127. // Change store output to store tmpSigElts.
  128. ReplaceDynamicOutput(tmpSigElts, sigID, zero, F);
  129. // Store tmpSigElts to Output before return.
  130. StoreTmpSigToOutput(tmpSigElts, row, opcodeV, sigID, F, Entry);
  131. }
  132. return true;
  133. }
  134. void DxilEliminateOutputDynamicIndexing::ReplaceDynamicOutput(
  135. ArrayRef<Value *> tmpSigElts, Value *sigID, Value *zero, Function *F) {
  136. for (auto it = F->user_begin(); it != F->user_end();) {
  137. CallInst *CI = cast<CallInst>(*(it++));
  138. DxilOutputStore store(CI);
  139. if (sigID == store.get_outputSigId()) {
  140. uint64_t col = store.get_colIndex();
  141. Value *tmpSigElt = tmpSigElts[col];
  142. IRBuilder<> Builder(CI);
  143. Value *r = store.get_rowIndex();
  144. // Store to tmpSigElt.
  145. Value *GEP = Builder.CreateInBoundsGEP(tmpSigElt, {zero, r});
  146. Builder.CreateStore(store.get_value(), GEP);
  147. // Remove store output.
  148. CI->eraseFromParent();
  149. }
  150. }
  151. }
  152. void DxilEliminateOutputDynamicIndexing::StoreTmpSigToOutput(
  153. ArrayRef<Value *> tmpSigElts, unsigned row, Value *opcode, Value *sigID,
  154. Function *StoreOutput, Function *Entry) {
  155. Value *args[] = {opcode, sigID, /*row*/ nullptr, /*col*/ nullptr,
  156. /*val*/ nullptr};
  157. // Store the tmpSigElts to Output before every return.
  158. for (auto &BB : Entry->getBasicBlockList()) {
  159. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  160. IRBuilder<> Builder(RI);
  161. Value *zero = Builder.getInt32(0);
  162. for (unsigned c = 0; c<tmpSigElts.size(); c++) {
  163. Value *col = tmpSigElts[c];
  164. args[DXIL::OperandIndex::kStoreOutputColOpIdx] = Builder.getInt8(c);
  165. for (unsigned r = 0; r < row; r++) {
  166. Value *GEP =
  167. Builder.CreateInBoundsGEP(col, {zero, Builder.getInt32(r)});
  168. Value *V = Builder.CreateLoad(GEP);
  169. args[DXIL::OperandIndex::kStoreOutputRowOpIdx] = Builder.getInt32(r);
  170. args[DXIL::OperandIndex::kStoreOutputValOpIdx] = V;
  171. Builder.CreateCall(StoreOutput, args);
  172. }
  173. }
  174. }
  175. }
  176. }
  177. }
  178. char DxilEliminateOutputDynamicIndexing::ID = 0;
  179. ModulePass *llvm::createDxilEliminateOutputDynamicIndexingPass() {
  180. return new DxilEliminateOutputDynamicIndexing();
  181. }
  182. INITIALIZE_PASS(DxilEliminateOutputDynamicIndexing,
  183. "hlsl-dxil-eliminate-output-dynamic",
  184. "DXIL eliminate ouptut dynamic indexing", false, false)