DxilPreserveAllOutputs.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// DxilPreserveAllOutputs.cpp                                                //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source     //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
// Ensure we store to all elements in the output signature.                  //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/DXIL/DxilSignatureElement.h"
  14. #include "dxc/DXIL/DxilModule.h"
  15. #include "dxc/Support/Global.h"
  16. #include "dxc/DXIL/DxilInstructions.h"
  17. #include "llvm/IR/Module.h"
  18. #include "llvm/IR/InstIterator.h"
  19. #include "llvm/Pass.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include <llvm/ADT/DenseSet.h>
  22. using namespace llvm;
  23. using namespace hlsl;
  24. namespace {
  25. class OutputWrite {
  26. public:
  27. explicit OutputWrite(CallInst *call)
  28. : m_Call(call)
  29. {
  30. assert(DxilInst_StoreOutput(call) ||
  31. DxilInst_StoreVertexOutput(call) ||
  32. DxilInst_StorePrimitiveOutput(call) ||
  33. DxilInst_StorePatchConstant(call));
  34. }
  35. unsigned GetSignatureID() const {
  36. Value *id = m_Call->getOperand(SignatureIndex);
  37. return cast<ConstantInt>(id)->getLimitedValue();
  38. }
  39. DxilSignatureElement &GetSignatureElement(DxilModule &DM) const {
  40. if (DxilInst_StorePatchConstant(m_Call) || DxilInst_StorePrimitiveOutput(m_Call))
  41. return DM.GetPatchConstOrPrimSignature().GetElement(GetSignatureID());
  42. else
  43. return DM.GetOutputSignature().GetElement(GetSignatureID());
  44. }
  45. CallInst *GetStore() const {
  46. return m_Call;
  47. }
  48. Value *GetValue() const {
  49. return m_Call->getOperand(ValueIndex);
  50. }
  51. Value *GetRow() const {
  52. return m_Call->getOperand(RowIndex);
  53. }
  54. Value *GetColumn() const {
  55. return m_Call->getOperand(ColumnIndex);
  56. }
  57. void DeleteStore() {
  58. m_Call->eraseFromParent();
  59. m_Call = nullptr;
  60. }
  61. private:
  62. CallInst *m_Call;
  63. enum OperandIndex {
  64. SignatureIndex = 1,
  65. RowIndex = 2,
  66. ColumnIndex = 3,
  67. ValueIndex = 4,
  68. };
  69. };
  70. class OutputElement {
  71. public:
  72. explicit OutputElement(const DxilSignatureElement &outputElement)
  73. : m_OutputElement(outputElement)
  74. , m_Rows(outputElement.GetRows())
  75. , m_Columns(outputElement.GetCols())
  76. {
  77. }
  78. void CreateAlloca(IRBuilder<> &allocaBuilder) {
  79. LLVMContext &context = allocaBuilder.getContext();
  80. Type *elementType = m_OutputElement.GetCompType().GetLLVMType(context);
  81. Type *allocaType = nullptr;
  82. if (IsSingleElement())
  83. allocaType = elementType;
  84. else
  85. allocaType = ArrayType::get(elementType, NumElements());
  86. m_Alloca = allocaBuilder.CreateAlloca(allocaType, nullptr, m_OutputElement.GetName());
  87. }
  88. void StoreTemp(IRBuilder<> &builder, Value *row, Value *col, Value *value) const {
  89. Value *addr = GetTempAddr(builder, row, col);
  90. builder.CreateStore(value, addr);
  91. }
  92. void StoreOutput(IRBuilder<> &builder, DxilModule &DM) const {
  93. for (unsigned row = 0; row < m_Rows; ++row)
  94. for (unsigned col = 0; col < m_Columns; ++col) {
  95. StoreOutput(builder, DM, row, col);
  96. }
  97. }
  98. unsigned NumElements() const {
  99. return m_Rows * m_Columns;
  100. }
  101. private:
  102. const DxilSignatureElement &m_OutputElement;
  103. unsigned m_Rows;
  104. unsigned m_Columns;
  105. AllocaInst *m_Alloca;
  106. bool IsSingleElement() const {
  107. return m_Rows == 1 && m_Columns == 1;
  108. }
  109. Value *GetAsI32(IRBuilder<> &builder, Value *col) const {
  110. assert(col->getType()->isIntegerTy());
  111. Type *i32Ty = builder.getInt32Ty();
  112. if (col->getType() != i32Ty) {
  113. if (col->getType()->getScalarSizeInBits() > i32Ty->getScalarSizeInBits())
  114. col = builder.CreateTrunc(col, i32Ty);
  115. else
  116. col = builder.CreateZExt(col, i32Ty);
  117. }
  118. return col;
  119. }
  120. Value *GetTempAddr(IRBuilder<> &builder, Value *row, Value *col) const {
  121. // Load directly from alloca for non-array output.
  122. if (IsSingleElement())
  123. return m_Alloca;
  124. else
  125. return CreateGEP(builder, row, col);
  126. }
  127. Value *CreateGEP(IRBuilder<> &builder, Value *row, Value *col) const {
  128. assert(m_Alloca);
  129. Constant *rowStride = ConstantInt::get(row->getType(), m_Columns);
  130. Value *rowOffset = builder.CreateMul(row, rowStride);
  131. Value *index = builder.CreateAdd(rowOffset, GetAsI32(builder, col));
  132. return builder.CreateInBoundsGEP(m_Alloca, {builder.getInt32(0), index});
  133. }
  134. Value *LoadTemp(IRBuilder<> &builder, Value *row, Value *col) const {
  135. Value *addr = GetTempAddr(builder, row, col);
  136. return builder.CreateLoad(addr);
  137. }
  138. void StoreOutput(IRBuilder<> &builder, DxilModule &DM, unsigned row, unsigned col) const {
  139. Value *opcodeV = builder.getInt32(static_cast<unsigned>(GetOutputOpCode()));
  140. Value *sigID = builder.getInt32(m_OutputElement.GetID());
  141. Value *rowV = builder.getInt32(row);
  142. Value *colV = builder.getInt8(col);
  143. Value *val = LoadTemp(builder, rowV, colV);
  144. Value *args[] = { opcodeV, sigID, rowV, colV, val };
  145. Function *Store = GetOutputFunction(DM);
  146. builder.CreateCall(Store, args);
  147. }
  148. DXIL::OpCode GetOutputOpCode() const {
  149. if (m_OutputElement.IsPatchConstOrPrim()) {
  150. if (m_OutputElement.GetSigPointKind() == DXIL::SigPointKind::PCOut)
  151. return DXIL::OpCode::StorePatchConstant;
  152. else {
  153. assert(m_OutputElement.GetSigPointKind() == DXIL::SigPointKind::MSPOut);
  154. return DXIL::OpCode::StorePrimitiveOutput;
  155. }
  156. }
  157. else
  158. return DXIL::OpCode::StoreOutput;
  159. }
  160. Function *GetOutputFunction(DxilModule &DM) const {
  161. hlsl::OP *opInfo = DM.GetOP();
  162. return opInfo->GetOpFunc(GetOutputOpCode(), m_OutputElement.GetCompType().GetLLVMBaseType(DM.GetCtx()));
  163. }
  164. };
  165. class DxilPreserveAllOutputs : public FunctionPass {
  166. private:
  167. public:
  168. static char ID; // Pass identification, replacement for typeid
  169. DxilPreserveAllOutputs() : FunctionPass(ID) {}
  170. const char *getPassName() const override {
  171. return "DXIL preserve all outputs";
  172. }
  173. bool runOnFunction(Function &F) override;
  174. private:
  175. typedef std::vector<OutputWrite> OutputVec;
  176. typedef std::map<unsigned, OutputElement> OutputMap;
  177. OutputVec collectOutputStores(Function &F);
  178. OutputMap generateOutputMap(const OutputVec &calls, DxilModule &DM);
  179. void createTempAllocas(OutputMap &map, IRBuilder<> &builder);
  180. void insertTempOutputStores(const OutputVec &calls, const OutputMap &map, IRBuilder<> &builder);
  181. void insertFinalOutputStores(Function &F, const OutputMap &outputMap, IRBuilder<> &builder, DxilModule &DM);
  182. void removeOriginalOutputStores(OutputVec &outputStores);
  183. };
  184. bool DxilPreserveAllOutputs::runOnFunction(Function &F) {
  185. DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
  186. OutputVec outputStores = collectOutputStores(F);
  187. if (outputStores.empty())
  188. return false;
  189. IRBuilder<> builder(F.getEntryBlock().getFirstInsertionPt());
  190. OutputMap outputMap = generateOutputMap(outputStores, DM);
  191. createTempAllocas(outputMap, builder);
  192. insertTempOutputStores(outputStores, outputMap, builder);
  193. insertFinalOutputStores(F,outputMap, builder, DM);
  194. removeOriginalOutputStores(outputStores);
  195. return false;
  196. }
  197. DxilPreserveAllOutputs::OutputVec DxilPreserveAllOutputs::collectOutputStores(Function &F) {
  198. OutputVec calls;
  199. for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
  200. Instruction *inst = &*I;
  201. DxilInst_StoreOutput storeOutput(inst);
  202. DxilInst_StoreVertexOutput storeVertexOutput(inst);
  203. DxilInst_StorePrimitiveOutput storePrimitiveOutput(inst);
  204. DxilInst_StorePatchConstant storePatch(inst);
  205. if (storeOutput || storeVertexOutput || storePrimitiveOutput || storePatch)
  206. calls.emplace_back(cast<CallInst>(inst));
  207. }
  208. return calls;
  209. }
  210. DxilPreserveAllOutputs::OutputMap DxilPreserveAllOutputs::generateOutputMap(const OutputVec &calls, DxilModule &DM) {
  211. OutputMap map;
  212. for (const OutputWrite &output : calls) {
  213. unsigned sigID = output.GetSignatureID();
  214. if (map.count(sigID))
  215. continue;
  216. map.insert(std::make_pair(sigID, OutputElement(output.GetSignatureElement(DM))));
  217. }
  218. return map;
  219. }
  220. void DxilPreserveAllOutputs::createTempAllocas(OutputMap &outputMap, IRBuilder<> &allocaBuilder)
  221. {
  222. for (auto &iter: outputMap) {
  223. OutputElement &output = iter.second;
  224. output.CreateAlloca(allocaBuilder);
  225. }
  226. }
  227. void DxilPreserveAllOutputs::insertTempOutputStores(const OutputVec &writes, const OutputMap &map, IRBuilder<>& builder)
  228. {
  229. for (const OutputWrite& outputWrite : writes) {
  230. OutputMap::const_iterator iter = map.find(outputWrite.GetSignatureID());
  231. assert(iter != map.end());
  232. const OutputElement &output = iter->second;
  233. builder.SetInsertPoint(outputWrite.GetStore());
  234. output.StoreTemp(builder, outputWrite.GetRow(), outputWrite.GetColumn(), outputWrite.GetValue());
  235. }
  236. }
  237. void DxilPreserveAllOutputs::insertFinalOutputStores(Function &F, const OutputMap & outputMap, IRBuilder<>& builder, DxilModule & DM)
  238. {
  239. // Find all return instructions.
  240. SmallVector<ReturnInst *, 4> returns;
  241. for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
  242. Instruction *inst = &*I;
  243. if (ReturnInst *ret = dyn_cast<ReturnInst>(inst))
  244. returns.push_back(ret);
  245. }
  246. // Write all outputs before each return.
  247. for (ReturnInst *ret : returns) {
  248. for (const auto &iter : outputMap) {
  249. const OutputElement &output = iter.second;
  250. builder.SetInsertPoint(ret);
  251. output.StoreOutput(builder, DM);
  252. }
  253. }
  254. }
  255. void DxilPreserveAllOutputs::removeOriginalOutputStores(OutputVec & outputStores)
  256. {
  257. for (OutputWrite &write : outputStores) {
  258. write.DeleteStore();
  259. }
  260. }
  261. }
  262. char DxilPreserveAllOutputs::ID = 0;
  263. FunctionPass *llvm::createDxilPreserveAllOutputsPass() {
  264. return new DxilPreserveAllOutputs();
  265. }
  266. INITIALIZE_PASS(DxilPreserveAllOutputs,
  267. "hlsl-dxil-preserve-all-outputs",
  268. "DXIL preserve all outputs", false, false)