DxilPreserveAllOutputs.cpp 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPreserveAllOutputs.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Ensure we store to all elements in the output signature. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/HLSL/DxilOperations.h"
  13. #include "dxc/HLSL/DxilSignatureElement.h"
  14. #include "dxc/HLSL/DxilModule.h"
  15. #include "dxc/Support/Global.h"
  16. #include "dxc/HLSL/DxilInstructions.h"
  17. #include "llvm/IR/Module.h"
  18. #include "llvm/IR/InstIterator.h"
  19. #include "llvm/Pass.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include <llvm/ADT/DenseSet.h>
  22. using namespace llvm;
  23. using namespace hlsl;
  24. namespace {
  25. class OutputWrite {
  26. public:
  27. explicit OutputWrite(CallInst *call)
  28. : m_Call(call)
  29. {
  30. assert(DxilInst_StoreOutput(call) || DxilInst_StorePatchConstant(call));
  31. }
  32. unsigned GetSignatureID() const {
  33. Value *id = m_Call->getOperand(SignatureIndex);
  34. return cast<ConstantInt>(id)->getLimitedValue();
  35. }
  36. DxilSignatureElement &GetSignatureElement(DxilModule &DM) const {
  37. if (DxilInst_StorePatchConstant(m_Call))
  38. return DM.GetPatchConstantSignature().GetElement(GetSignatureID());
  39. else
  40. return DM.GetOutputSignature().GetElement(GetSignatureID());
  41. }
  42. CallInst *GetStore() const {
  43. return m_Call;
  44. }
  45. Value *GetValue() const {
  46. return m_Call->getOperand(ValueIndex);
  47. }
  48. Value *GetRow() const {
  49. return m_Call->getOperand(RowIndex);
  50. }
  51. Value *GetColumn() const {
  52. return m_Call->getOperand(ColumnIndex);
  53. }
  54. void DeleteStore() {
  55. m_Call->eraseFromParent();
  56. m_Call = nullptr;
  57. }
  58. private:
  59. CallInst *m_Call;
  60. enum OperandIndex {
  61. SignatureIndex = 1,
  62. RowIndex = 2,
  63. ColumnIndex = 3,
  64. ValueIndex = 4,
  65. };
  66. };
  67. class OutputElement {
  68. public:
  69. explicit OutputElement(const DxilSignatureElement &outputElement)
  70. : m_OutputElement(outputElement)
  71. , m_Rows(outputElement.GetRows())
  72. , m_Columns(outputElement.GetCols())
  73. {
  74. }
  75. void CreateAlloca(IRBuilder<> &allocaBuilder) {
  76. LLVMContext &context = allocaBuilder.getContext();
  77. Type *elementType = m_OutputElement.GetCompType().GetLLVMType(context);
  78. Type *allocaType = nullptr;
  79. if (IsSingleElement())
  80. allocaType = elementType;
  81. else
  82. allocaType = ArrayType::get(elementType, NumElements());
  83. m_Alloca = allocaBuilder.CreateAlloca(allocaType, nullptr, m_OutputElement.GetName());
  84. }
  85. void StoreTemp(IRBuilder<> &builder, Value *row, Value *col, Value *value) const {
  86. Value *addr = GetTempAddr(builder, row, col);
  87. builder.CreateStore(value, addr);
  88. }
  89. void StoreOutput(IRBuilder<> &builder, DxilModule &DM) const {
  90. for (unsigned row = 0; row < m_Rows; ++row)
  91. for (unsigned col = 0; col < m_Columns; ++col) {
  92. StoreOutput(builder, DM, row, col);
  93. }
  94. }
  95. unsigned NumElements() const {
  96. return m_Rows * m_Columns;
  97. }
  98. private:
  99. const DxilSignatureElement &m_OutputElement;
  100. unsigned m_Rows;
  101. unsigned m_Columns;
  102. AllocaInst *m_Alloca;
  103. bool IsSingleElement() const {
  104. return m_Rows == 1 && m_Columns == 1;
  105. }
  106. Value *GetAsI32(IRBuilder<> &builder, Value *col) const {
  107. assert(col->getType()->isIntegerTy());
  108. Type *i32Ty = builder.getInt32Ty();
  109. if (col->getType() != i32Ty) {
  110. if (col->getType()->getScalarSizeInBits() > i32Ty->getScalarSizeInBits())
  111. col = builder.CreateTrunc(col, i32Ty);
  112. else
  113. col = builder.CreateZExt(col, i32Ty);
  114. }
  115. return col;
  116. }
  117. Value *GetTempAddr(IRBuilder<> &builder, Value *row, Value *col) const {
  118. // Load directly from alloca for non-array output.
  119. if (IsSingleElement())
  120. return m_Alloca;
  121. else
  122. return CreateGEP(builder, row, col);
  123. }
  124. Value *CreateGEP(IRBuilder<> &builder, Value *row, Value *col) const {
  125. assert(m_Alloca);
  126. Constant *rowStride = ConstantInt::get(row->getType(), m_Columns);
  127. Value *rowOffset = builder.CreateMul(row, rowStride);
  128. Value *index = builder.CreateAdd(rowOffset, GetAsI32(builder, col));
  129. return builder.CreateInBoundsGEP(m_Alloca, {builder.getInt32(0), index});
  130. }
  131. Value *LoadTemp(IRBuilder<> &builder, Value *row, Value *col) const {
  132. Value *addr = GetTempAddr(builder, row, col);
  133. return builder.CreateLoad(addr);
  134. }
  135. void StoreOutput(IRBuilder<> &builder, DxilModule &DM, unsigned row, unsigned col) const {
  136. Value *opcodeV = builder.getInt32(static_cast<unsigned>(GetOutputOpCode()));
  137. Value *sigID = builder.getInt32(m_OutputElement.GetID());
  138. Value *rowV = builder.getInt32(row);
  139. Value *colV = builder.getInt8(col);
  140. Value *val = LoadTemp(builder, rowV, colV);
  141. Value *args[] = { opcodeV, sigID, rowV, colV, val };
  142. Function *Store = GetOutputFunction(DM);
  143. builder.CreateCall(Store, args);
  144. }
  145. DXIL::OpCode GetOutputOpCode() const {
  146. if (m_OutputElement.IsPatchConstant())
  147. return DXIL::OpCode::StorePatchConstant;
  148. else
  149. return DXIL::OpCode::StoreOutput;
  150. }
  151. Function *GetOutputFunction(DxilModule &DM) const {
  152. hlsl::OP *opInfo = DM.GetOP();
  153. return opInfo->GetOpFunc(GetOutputOpCode(), m_OutputElement.GetCompType().GetLLVMBaseType(DM.GetCtx()));
  154. }
  155. };
  156. class DxilPreserveAllOutputs : public FunctionPass {
  157. private:
  158. public:
  159. static char ID; // Pass identification, replacement for typeid
  160. DxilPreserveAllOutputs() : FunctionPass(ID) {}
  161. const char *getPassName() const override {
  162. return "DXIL preserve all outputs";
  163. }
  164. bool runOnFunction(Function &F) override;
  165. private:
  166. typedef std::vector<OutputWrite> OutputVec;
  167. typedef std::unordered_map<unsigned, OutputElement> OutputMap;
  168. OutputVec collectOutputStores(Function &F);
  169. OutputMap generateOutputMap(const OutputVec &calls, DxilModule &DM);
  170. void createTempAllocas(OutputMap &map, IRBuilder<> &builder);
  171. void insertTempOutputStores(const OutputVec &calls, const OutputMap &map, IRBuilder<> &builder);
  172. void insertFinalOutputStores(Function &F, const OutputMap &outputMap, IRBuilder<> &builder, DxilModule &DM);
  173. void removeOriginalOutputStores(OutputVec &outputStores);
  174. };
  175. bool DxilPreserveAllOutputs::runOnFunction(Function &F) {
  176. DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
  177. OutputVec outputStores = collectOutputStores(F);
  178. if (outputStores.empty())
  179. return false;
  180. IRBuilder<> builder(F.getEntryBlock().getFirstInsertionPt());
  181. OutputMap outputMap = generateOutputMap(outputStores, DM);
  182. createTempAllocas(outputMap, builder);
  183. insertTempOutputStores(outputStores, outputMap, builder);
  184. insertFinalOutputStores(F,outputMap, builder, DM);
  185. removeOriginalOutputStores(outputStores);
  186. return false;
  187. }
  188. DxilPreserveAllOutputs::OutputVec DxilPreserveAllOutputs::collectOutputStores(Function &F) {
  189. OutputVec calls;
  190. for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
  191. Instruction *inst = &*I;
  192. DxilInst_StoreOutput storeOutput(inst);
  193. DxilInst_StorePatchConstant storePatch(inst);
  194. if (storeOutput || storePatch)
  195. calls.emplace_back(cast<CallInst>(inst));
  196. }
  197. return calls;
  198. }
  199. DxilPreserveAllOutputs::OutputMap DxilPreserveAllOutputs::generateOutputMap(const OutputVec &calls, DxilModule &DM) {
  200. OutputMap map;
  201. for (const OutputWrite &output : calls) {
  202. unsigned sigID = output.GetSignatureID();
  203. if (map.count(sigID))
  204. continue;
  205. map.insert(std::make_pair(sigID, OutputElement(output.GetSignatureElement(DM))));
  206. }
  207. return map;
  208. }
  209. void DxilPreserveAllOutputs::createTempAllocas(OutputMap &outputMap, IRBuilder<> &allocaBuilder)
  210. {
  211. for (auto &iter: outputMap) {
  212. OutputElement &output = iter.second;
  213. output.CreateAlloca(allocaBuilder);
  214. }
  215. }
  216. void DxilPreserveAllOutputs::insertTempOutputStores(const OutputVec &writes, const OutputMap &map, IRBuilder<>& builder)
  217. {
  218. for (const OutputWrite& outputWrite : writes) {
  219. OutputMap::const_iterator iter = map.find(outputWrite.GetSignatureID());
  220. assert(iter != map.end());
  221. const OutputElement &output = iter->second;
  222. builder.SetInsertPoint(outputWrite.GetStore());
  223. output.StoreTemp(builder, outputWrite.GetRow(), outputWrite.GetColumn(), outputWrite.GetValue());
  224. }
  225. }
  226. void DxilPreserveAllOutputs::insertFinalOutputStores(Function &F, const OutputMap & outputMap, IRBuilder<>& builder, DxilModule & DM)
  227. {
  228. // Find all return instructions.
  229. SmallVector<ReturnInst *, 4> returns;
  230. for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
  231. Instruction *inst = &*I;
  232. if (ReturnInst *ret = dyn_cast<ReturnInst>(inst))
  233. returns.push_back(ret);
  234. }
  235. // Write all outputs before each return.
  236. for (ReturnInst *ret : returns) {
  237. for (const auto &iter : outputMap) {
  238. const OutputElement &output = iter.second;
  239. builder.SetInsertPoint(ret);
  240. output.StoreOutput(builder, DM);
  241. }
  242. }
  243. }
  244. void DxilPreserveAllOutputs::removeOriginalOutputStores(OutputVec & outputStores)
  245. {
  246. for (OutputWrite &write : outputStores) {
  247. write.DeleteStore();
  248. }
  249. }
  250. }
  251. char DxilPreserveAllOutputs::ID = 0;
  252. FunctionPass *llvm::createDxilPreserveAllOutputsPass() {
  253. return new DxilPreserveAllOutputs();
  254. }
  255. INITIALIZE_PASS(DxilPreserveAllOutputs,
  256. "hlsl-dxil-preserve-all-outputs",
  257. "DXIL preserve all outputs", false, false)