///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// ComputeViewIdState.h                                                      //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source     //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
// Computes output registers dependent on ViewID.                            //
// Computes sets of input registers on which output registers depend.        //
// Computes which input/output shapes are dynamically indexed.               //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#pragma once

#include "llvm/Pass.h"
#include "dxc/HLSL/ControlDependence.h"
#include "llvm/Support/GenericDomTree.h"

#include <bitset>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

  23. namespace llvm {
  24. class Module;
  25. class Function;
  26. class BasicBlock;
  27. class Instruction;
  28. class ReturnInst;
  29. class Value;
  30. class PHINode;
  31. class AnalysisUsage;
  32. class CallGraph;
  33. class CallGraphNode;
  34. class ModulePass;
  35. class raw_ostream;
  36. }
  37. namespace hlsl {
  38. class DxilModule;
  39. class DxilSignature;
  40. class DxilSignatureElement;
  41. class DxilViewIdState {
  42. static const unsigned kNumComps = 4;
  43. static const unsigned kMaxSigScalars = 32*4;
  44. public:
  45. using OutputsDependentOnViewIdType = std::bitset<kMaxSigScalars>;
  46. using InputsContributingToOutputType = std::map<unsigned, std::set<unsigned>>;
  47. DxilViewIdState(DxilModule *pDxilModule);
  48. unsigned getNumInputSigScalars() const;
  49. unsigned getNumOutputSigScalars(unsigned StreamId) const;
  50. unsigned getNumPCSigScalars() const;
  51. const OutputsDependentOnViewIdType &getOutputsDependentOnViewId(unsigned StreamId) const;
  52. const OutputsDependentOnViewIdType &getPCOutputsDependentOnViewId() const;
  53. const InputsContributingToOutputType &getInputsContributingToOutputs(unsigned StreamId) const;
  54. const InputsContributingToOutputType &getInputsContributingToPCOutputs() const;
  55. const InputsContributingToOutputType &getPCInputsContributingToOutputs() const;
  56. void Compute();
  57. void Serialize();
  58. const std::vector<unsigned> &GetSerialized();
  59. const std::vector<unsigned> &GetSerialized() const; // returns previously serialized data
  60. void Deserialize(const unsigned *pData, unsigned DataSizeInUINTs);
  61. void PrintSets(llvm::raw_ostream &OS);
  62. private:
  63. static const unsigned kNumStreams = 4;
  64. DxilModule *m_pModule;
  65. bool m_bUsesViewId = false;
  66. unsigned m_NumInputSigScalars = 0;
  67. unsigned m_NumOutputSigScalars[kNumStreams] = {0,0,0,0};
  68. unsigned m_NumPCSigScalars = 0;
  69. // Dynamically indexed components of signature elements.
  70. using DynamicallyIndexedElemsType = std::unordered_map<unsigned, unsigned>;
  71. DynamicallyIndexedElemsType m_InpSigDynIdxElems;
  72. DynamicallyIndexedElemsType m_OutSigDynIdxElems;
  73. DynamicallyIndexedElemsType m_PCSigDynIdxElems;
  74. // Set of scalar outputs dependent on ViewID.
  75. OutputsDependentOnViewIdType m_OutputsDependentOnViewId[kNumStreams];
  76. OutputsDependentOnViewIdType m_PCOutputsDependentOnViewId;
  77. // Set of scalar inputs contributing to computation of scalar outputs.
  78. InputsContributingToOutputType m_InputsContributingToOutputs[kNumStreams];
  79. InputsContributingToOutputType m_InputsContributingToPCOutputs; // HS PC only.
  80. InputsContributingToOutputType m_PCInputsContributingToOutputs; // DS only.
  81. // Information per entry point.
  82. using FunctionSetType = std::unordered_set<llvm::Function *>;
  83. using InstructionSetType = std::unordered_set<llvm::Instruction *>;
  84. struct EntryInfo {
  85. llvm::Function *pEntryFunc = nullptr;
  86. // Sets of functions that may be reachable from an entry.
  87. FunctionSetType Functions;
  88. // Outputs to analyze.
  89. InstructionSetType Outputs;
  90. // Contributing instructions per output.
  91. std::unordered_map<unsigned, InstructionSetType> ContributingInstructions[kNumStreams];
  92. void Clear();
  93. };
  94. EntryInfo m_Entry;
  95. EntryInfo m_PCEntry;
  96. // Information per function.
  97. using FunctionReturnSet = std::unordered_set<llvm::ReturnInst *>;
  98. struct FuncInfo {
  99. FunctionReturnSet Returns;
  100. ControlDependence CtrlDep;
  101. std::unique_ptr<llvm::DominatorTreeBase<llvm::BasicBlock> > pDomTree;
  102. void Clear();
  103. };
  104. std::unordered_map<llvm::Function *, std::unique_ptr<FuncInfo>> m_FuncInfo;
  105. // Cache of decls (global/alloca) reaching a pointer value.
  106. using ValueSetType = std::unordered_set<llvm::Value *>;
  107. std::unordered_map<llvm::Value *, ValueSetType> m_ReachingDeclsCache;
  108. // Cache of stores for each decl.
  109. std::unordered_map<llvm::Value *, ValueSetType> m_StoresPerDeclCache;
  110. // Serialized form.
  111. std::vector<unsigned> m_SerializedState;
  112. void Clear();
  113. void DetermineMaxPackedLocation(DxilSignature &DxilSig, unsigned *pMaxSigLoc, unsigned NumStreams);
  114. void ComputeReachableFunctionsRec(llvm::CallGraph &CG, llvm::CallGraphNode *pNode, FunctionSetType &FuncSet);
  115. void AnalyzeFunctions(EntryInfo &Entry);
  116. void CollectValuesContributingToOutputs(EntryInfo &Entry);
  117. void CollectValuesContributingToOutputRec(EntryInfo &Entry,
  118. llvm::Value *pContributingValue,
  119. InstructionSetType &ContributingInstructions);
  120. void CollectPhiCFValuesContributingToOutputRec(llvm::PHINode *pPhi,
  121. EntryInfo &Entry,
  122. InstructionSetType &ContributingInstructions);
  123. const ValueSetType &CollectReachingDecls(llvm::Value *pValue);
  124. void CollectReachingDeclsRec(llvm::Value *pValue, ValueSetType &ReachingDecls, ValueSetType &Visited);
  125. const ValueSetType &CollectStores(llvm::Value *pValue);
  126. void CollectStoresRec(llvm::Value *pValue, ValueSetType &Stores, ValueSetType &Visited);
  127. void UpdateDynamicIndexUsageState() const;
  128. void CreateViewIdSets(const std::unordered_map<unsigned, InstructionSetType> &ContributingInstructions,
  129. OutputsDependentOnViewIdType &OutputsDependentOnViewId,
  130. InputsContributingToOutputType &InputsContributingToOutputs, bool bPC);
  131. void UpdateDynamicIndexUsageStateForSig(DxilSignature &Sig, const DynamicallyIndexedElemsType &DynIdxElems) const;
  132. void SerializeOutputsDependentOnViewId(unsigned NumOutputs,
  133. const OutputsDependentOnViewIdType &OutputsDependentOnViewId,
  134. unsigned *&pData);
  135. void SerializeInputsContributingToOutput(unsigned NumInputs, unsigned NumOutputs,
  136. const InputsContributingToOutputType &InputsContributingToOutputs,
  137. unsigned *&pData);
  138. unsigned DeserializeOutputsDependentOnViewId(unsigned NumOutputs,
  139. OutputsDependentOnViewIdType &OutputsDependentOnViewId,
  140. const unsigned *pData, unsigned DataSize);
  141. unsigned DeserializeInputsContributingToOutput(unsigned NumInputs, unsigned NumOutputs,
  142. InputsContributingToOutputType &InputsContributingToOutputs,
  143. const unsigned *pData, unsigned DataSize);
  144. unsigned GetLinearIndex(DxilSignatureElement &SigElem, int row, unsigned col) const;
  145. void PrintOutputsDependentOnViewId(llvm::raw_ostream &OS,
  146. llvm::StringRef SetName, unsigned NumOutputs,
  147. const OutputsDependentOnViewIdType &OutputsDependentOnViewId);
  148. void PrintInputsContributingToOutputs(llvm::raw_ostream &OS,
  149. llvm::StringRef InputSetName, llvm::StringRef OutputSetName,
  150. const InputsContributingToOutputType &InputsContributingToOutputs);
  151. };
  152. } // end of hlsl namespace
  153. namespace llvm {
  154. class ComputeViewIdState : public ModulePass {
  155. public:
  156. static char ID; // Pass ID, replacement for typeid
  157. ComputeViewIdState();
  158. bool runOnModule(Module &M) override;
  159. void getAnalysisUsage(AnalysisUsage &AU) const override;
  160. };
  161. void initializeComputeViewIdStatePass(llvm::PassRegistry &);
  162. llvm::ModulePass *createComputeViewIdStatePass();
  163. } // end of llvm namespace