DxilModule.h 15 KB

  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilModule.h //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // The main class to work with DXIL, similar to LLVM module. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #pragma once
  #include "dxc/HLSL/DxilMetadataHelper.h"
  #include "dxc/HLSL/DxilCBuffer.h"
  #include "dxc/HLSL/DxilResource.h"
  #include "dxc/HLSL/DxilSampler.h"
  #include "dxc/HLSL/DxilSignature.h"
  #include "dxc/HLSL/DxilConstants.h"
  #include "dxc/HLSL/DxilTypeSystem.h"
  #include "dxc/HLSL/ComputeViewIdState.h"
  #include <cstdint>
  #include <memory>
  #include <string>
  #include <vector>
  // Forward declarations of the LLVM types referenced by this header, so the
  // declarations below do not depend on what the included headers happen to
  // pull in transitively.
  namespace llvm {
  class LLVMContext;
  class Module;
  class Function;
  class GlobalVariable;
  class MDNode;
  class MDTuple;
  class MDOperand;
  class NamedMDNode;
  class DebugInfoFinder;
  } // namespace llvm
  31. namespace hlsl {
  32. class ShaderModel;
  33. class OP;
  34. class RootSignatureHandle;
  35. /// Use this class to manipulate DXIL of a shader.
  36. class DxilModule {
  37. public:
  38. DxilModule(llvm::Module *pModule);
  39. ~DxilModule();
  40. // Subsystems.
  41. llvm::LLVMContext &GetCtx() const;
  42. llvm::Module *GetModule() const;
  43. OP *GetOP() const;
  44. void SetShaderModel(const ShaderModel *pSM);
  45. const ShaderModel *GetShaderModel() const;
  46. void GetDxilVersion(unsigned &DxilMajor, unsigned &DxilMinor) const;
  47. void SetValidatorVersion(unsigned ValMajor, unsigned ValMinor);
  48. bool UpgradeValidatorVersion(unsigned ValMajor, unsigned ValMinor);
  49. void GetValidatorVersion(unsigned &ValMajor, unsigned &ValMinor) const;
  50. // Return true on success, requires valid shader model and CollectShaderFlags to have been set
  51. bool GetMinValidatorVersion(unsigned &ValMajor, unsigned &ValMinor) const;
  52. // Update validator version to minimum if higher than current (ex: after CollectShaderFlags)
  53. bool UpgradeToMinValidatorVersion();
  54. // Entry functions.
  55. llvm::Function *GetEntryFunction();
  56. const llvm::Function *GetEntryFunction() const;
  57. void SetEntryFunction(llvm::Function *pEntryFunc);
  58. const std::string &GetEntryFunctionName() const;
  59. void SetEntryFunctionName(const std::string &name);
  60. llvm::Function *GetPatchConstantFunction();
  61. const llvm::Function *GetPatchConstantFunction() const;
  62. void SetPatchConstantFunction(llvm::Function *pFunc);
  63. // Flags.
  64. unsigned GetGlobalFlags() const;
  65. // TODO: move out of DxilModule as a util.
  66. void CollectShaderFlags();
  67. // Resources.
  68. unsigned AddCBuffer(std::unique_ptr<DxilCBuffer> pCB);
  69. DxilCBuffer &GetCBuffer(unsigned idx);
  70. const DxilCBuffer &GetCBuffer(unsigned idx) const;
  71. const std::vector<std::unique_ptr<DxilCBuffer> > &GetCBuffers() const;
  72. unsigned AddSampler(std::unique_ptr<DxilSampler> pSampler);
  73. DxilSampler &GetSampler(unsigned idx);
  74. const DxilSampler &GetSampler(unsigned idx) const;
  75. const std::vector<std::unique_ptr<DxilSampler> > &GetSamplers() const;
  76. unsigned AddSRV(std::unique_ptr<DxilResource> pSRV);
  77. DxilResource &GetSRV(unsigned idx);
  78. const DxilResource &GetSRV(unsigned idx) const;
  79. const std::vector<std::unique_ptr<DxilResource> > &GetSRVs() const;
  80. unsigned AddUAV(std::unique_ptr<DxilResource> pUAV);
  81. DxilResource &GetUAV(unsigned idx);
  82. const DxilResource &GetUAV(unsigned idx) const;
  83. const std::vector<std::unique_ptr<DxilResource> > &GetUAVs() const;
  84. void LoadDxilResourceBaseFromMDNode(llvm::MDNode *MD, DxilResourceBase &R);
  85. void LoadDxilResourceFromMDNode(llvm::MDNode *MD, DxilResource &R);
  86. void LoadDxilSamplerFromMDNode(llvm::MDNode *MD, DxilSampler &S);
  87. void RemoveUnusedResources();
  88. void RemoveFunction(llvm::Function *F);
  89. // Signatures.
  90. DxilSignature &GetInputSignature();
  91. const DxilSignature &GetInputSignature() const;
  92. DxilSignature &GetOutputSignature();
  93. const DxilSignature &GetOutputSignature() const;
  94. DxilSignature &GetPatchConstantSignature();
  95. const DxilSignature &GetPatchConstantSignature() const;
  96. const RootSignatureHandle &GetRootSignature() const;
  97. // Remove Root Signature from module metadata
  98. void StripRootSignatureFromMetadata();
  99. // Update validator version metadata to current setting
  100. void UpdateValidatorVersionMetadata();
  101. // DXIL type system.
  102. DxilTypeSystem &GetTypeSystem();
  103. /// Emit llvm.used array to make sure that optimizations do not remove unreferenced globals.
  104. void EmitLLVMUsed();
  105. std::vector<llvm::GlobalVariable* > &GetLLVMUsed();
  106. // ViewId state.
  107. DxilViewIdState &GetViewIdState();
  108. const DxilViewIdState &GetViewIdState() const;
  109. // DXIL metadata manipulation.
  110. /// Serialize DXIL in-memory form to metadata form.
  111. void EmitDxilMetadata();
  112. /// Deserialize DXIL metadata form into in-memory form.
  113. void LoadDxilMetadata();
  114. /// Check if a Named meta data node is known by dxil module.
  115. static bool IsKnownNamedMetaData(llvm::NamedMDNode &Node);
  116. // Reset functions used to transfer ownership.
  117. void ResetInputSignature(DxilSignature *pValue);
  118. void ResetOutputSignature(DxilSignature *pValue);
  119. void ResetPatchConstantSignature(DxilSignature *pValue);
  120. void ResetRootSignature(RootSignatureHandle *pValue);
  121. void ResetTypeSystem(DxilTypeSystem *pValue);
  122. void ResetOP(hlsl::OP *hlslOP);
  123. void StripDebugRelatedCode();
  124. llvm::DebugInfoFinder &GetOrCreateDebugInfoFinder();
  125. static DxilModule *TryGetDxilModule(llvm::Module *pModule);
  126. public:
  127. // Shader properties.
  128. class ShaderFlags {
  129. public:
  130. ShaderFlags();
  131. unsigned GetGlobalFlags() const;
  132. void SetDisableOptimizations(bool flag) { m_bDisableOptimizations = flag; }
  133. bool GetDisableOptimizations() const { return m_bDisableOptimizations; }
  134. void SetDisableMathRefactoring(bool flag) { m_bDisableMathRefactoring = flag; }
  135. bool GetDisableMathRefactoring() const { return m_bDisableMathRefactoring; }
  136. void SetEnableDoublePrecision(bool flag) { m_bEnableDoublePrecision = flag; }
  137. bool GetEnableDoublePrecision() const { return m_bEnableDoublePrecision; }
  138. void SetForceEarlyDepthStencil(bool flag) { m_bForceEarlyDepthStencil = flag; }
  139. bool GetForceEarlyDepthStencil() const { return m_bForceEarlyDepthStencil; }
  140. void SetEnableRawAndStructuredBuffers(bool flag) { m_bEnableRawAndStructuredBuffers = flag; }
  141. bool GetEnableRawAndStructuredBuffers() const { return m_bEnableRawAndStructuredBuffers; }
  142. void SetEnableMinPrecision(bool flag) { m_bEnableMinPrecision = flag; }
  143. bool GetEnableMinPrecision() const { return m_bEnableMinPrecision; }
  144. void SetEnableDoubleExtensions(bool flag) { m_bEnableDoubleExtensions = flag; }
  145. bool GetEnableDoubleExtensions() const { return m_bEnableDoubleExtensions; }
  146. void SetEnableMSAD(bool flag) { m_bEnableMSAD = flag; }
  147. bool GetEnableMSAD() const { return m_bEnableMSAD; }
  148. void SetAllResourcesBound(bool flag) { m_bAllResourcesBound = flag; }
  149. bool GetAllResourcesBound() const { return m_bAllResourcesBound; }
  150. uint64_t GetFeatureInfo() const;
  151. void SetCSRawAndStructuredViaShader4X(bool flag) { m_bCSRawAndStructuredViaShader4X = flag; }
  152. bool GetCSRawAndStructuredViaShader4X() const { return m_bCSRawAndStructuredViaShader4X; }
  153. void SetROVs(bool flag) { m_bROVS = flag; }
  154. bool GetROVs() const { return m_bROVS; }
  155. void SetWaveOps(bool flag) { m_bWaveOps = flag; }
  156. bool GetWaveOps() const { return m_bWaveOps; }
  157. void SetInt64Ops(bool flag) { m_bInt64Ops = flag; }
  158. bool GetInt64Ops() const { return m_bInt64Ops; }
  159. void SetTiledResources(bool flag) { m_bTiledResources = flag; }
  160. bool GetTiledResources() const { return m_bTiledResources; }
  161. void SetStencilRef(bool flag) { m_bStencilRef = flag; }
  162. bool GetStencilRef() const { return m_bStencilRef; }
  163. void SetInnerCoverage(bool flag) { m_bInnerCoverage = flag; }
  164. bool GetInnerCoverage() const { return m_bInnerCoverage; }
  165. void SetViewportAndRTArrayIndex(bool flag) { m_bViewportAndRTArrayIndex = flag; }
  166. bool GetViewportAndRTArrayIndex() const { return m_bViewportAndRTArrayIndex; }
  167. void SetUAVLoadAdditionalFormats(bool flag) { m_bUAVLoadAdditionalFormats = flag; }
  168. bool GetUAVLoadAdditionalFormats() const { return m_bUAVLoadAdditionalFormats; }
  169. void SetLevel9ComparisonFiltering(bool flag) { m_bLevel9ComparisonFiltering = flag; }
  170. bool GetLevel9ComparisonFiltering() const { return m_bLevel9ComparisonFiltering; }
  171. void Set64UAVs(bool flag) { m_b64UAVs = flag; }
  172. bool Get64UAVs() const { return m_b64UAVs; }
  173. void SetUAVsAtEveryStage(bool flag) { m_UAVsAtEveryStage = flag; }
  174. bool GetUAVsAtEveryStage() const { return m_UAVsAtEveryStage; }
  175. void SetViewID(bool flag) { m_bViewID = flag; }
  176. bool GetViewID() const { return m_bViewID; }
  177. void SetBarycentrics(bool flag) { m_bBarycentrics = flag; }
  178. bool GetBarycentrics() const { return m_bBarycentrics; }
  179. static uint64_t GetShaderFlagsRawForCollection(); // some flags are collected (eg use 64-bit), some provided (eg allow refactoring)
  180. uint64_t GetShaderFlagsRaw() const;
  181. void SetShaderFlagsRaw(uint64_t data);
  182. private:
  183. unsigned m_bDisableOptimizations :1; // D3D11_1_SB_GLOBAL_FLAG_SKIP_OPTIMIZATION
  184. unsigned m_bDisableMathRefactoring :1; //~D3D10_SB_GLOBAL_FLAG_REFACTORING_ALLOWED
  185. unsigned m_bEnableDoublePrecision :1; // D3D11_SB_GLOBAL_FLAG_ENABLE_DOUBLE_PRECISION_FLOAT_OPS
  186. unsigned m_bForceEarlyDepthStencil :1; // D3D11_SB_GLOBAL_FLAG_FORCE_EARLY_DEPTH_STENCIL
  187. unsigned m_bEnableRawAndStructuredBuffers :1; // D3D11_SB_GLOBAL_FLAG_ENABLE_RAW_AND_STRUCTURED_BUFFERS
  188. unsigned m_bEnableMinPrecision :1; // D3D11_1_SB_GLOBAL_FLAG_ENABLE_MINIMUM_PRECISION
  189. unsigned m_bEnableDoubleExtensions :1; // D3D11_1_SB_GLOBAL_FLAG_ENABLE_DOUBLE_EXTENSIONS
  190. unsigned m_bEnableMSAD :1; // D3D11_1_SB_GLOBAL_FLAG_ENABLE_SHADER_EXTENSIONS
  191. unsigned m_bAllResourcesBound :1; // D3D12_SB_GLOBAL_FLAG_ALL_RESOURCES_BOUND
  192. unsigned m_bViewportAndRTArrayIndex :1; // SHADER_FEATURE_VIEWPORT_AND_RT_ARRAY_INDEX_FROM_ANY_SHADER_FEEDING_RASTERIZER
  193. unsigned m_bInnerCoverage :1; // SHADER_FEATURE_INNER_COVERAGE
  194. unsigned m_bStencilRef :1; // SHADER_FEATURE_STENCIL_REF
  195. unsigned m_bTiledResources :1; // SHADER_FEATURE_TILED_RESOURCES
  196. unsigned m_bUAVLoadAdditionalFormats :1; // SHADER_FEATURE_TYPED_UAV_LOAD_ADDITIONAL_FORMATS
  197. unsigned m_bLevel9ComparisonFiltering :1; // SHADER_FEATURE_LEVEL_9_COMPARISON_FILTERING
  198. // SHADER_FEATURE_11_1_SHADER_EXTENSIONS shared with EnableMSAD
  199. unsigned m_b64UAVs :1; // SHADER_FEATURE_64_UAVS
  200. unsigned m_UAVsAtEveryStage :1; // SHADER_FEATURE_UAVS_AT_EVERY_STAGE
  201. unsigned m_bCSRawAndStructuredViaShader4X : 1; // SHADER_FEATURE_COMPUTE_SHADERS_PLUS_RAW_AND_STRUCTURED_BUFFERS_VIA_SHADER_4_X
  202. // SHADER_FEATURE_COMPUTE_SHADERS_PLUS_RAW_AND_STRUCTURED_BUFFERS_VIA_SHADER_4_X is specifically
  203. // about shader model 4.x.
  204. unsigned m_bROVS :1; // SHADER_FEATURE_ROVS
  205. unsigned m_bWaveOps :1; // SHADER_FEATURE_WAVE_OPS
  206. unsigned m_bInt64Ops :1; // SHADER_FEATURE_INT64_OPS
  207. unsigned m_bViewID : 1; // SHADER_FEATURE_VIEWID
  208. unsigned m_bBarycentrics : 1; // SHADER_FEATURE_BARYCENTRICS
  209. unsigned m_align0 : 9; // align to 32 bit.
  210. uint32_t m_align1; // align to 64 bit.
  211. };
  212. ShaderFlags m_ShaderFlags;
  213. void CollectShaderFlags(ShaderFlags &Flags);
  214. // Check if DxilModule contains multi component UAV Loads.
  215. // This funciton must be called after unused resources are removed from DxilModule
  216. bool ModuleHasMulticomponentUAVLoads();
  217. // Compute shader.
  218. unsigned m_NumThreads[3];
  219. // Geometry shader.
  220. DXIL::InputPrimitive GetInputPrimitive() const;
  221. void SetInputPrimitive(DXIL::InputPrimitive IP);
  222. unsigned GetMaxVertexCount() const;
  223. void SetMaxVertexCount(unsigned Count);
  224. DXIL::PrimitiveTopology GetStreamPrimitiveTopology() const;
  225. void SetStreamPrimitiveTopology(DXIL::PrimitiveTopology Topology);
  226. bool HasMultipleOutputStreams() const;
  227. unsigned GetOutputStream() const;
  228. unsigned GetGSInstanceCount() const;
  229. void SetGSInstanceCount(unsigned Count);
  230. bool IsStreamActive(unsigned Stream) const;
  231. void SetStreamActive(unsigned Stream, bool bActive);
  232. void SetActiveStreamMask(unsigned Mask);
  233. unsigned GetActiveStreamMask() const;
  234. // Hull and Domain shaders.
  235. unsigned GetInputControlPointCount() const;
  236. void SetInputControlPointCount(unsigned NumICPs);
  237. DXIL::TessellatorDomain GetTessellatorDomain() const;
  238. void SetTessellatorDomain(DXIL::TessellatorDomain TessDomain);
  239. // Hull shader.
  240. unsigned GetOutputControlPointCount() const;
  241. void SetOutputControlPointCount(unsigned NumOCPs);
  242. DXIL::TessellatorPartitioning GetTessellatorPartitioning() const;
  243. void SetTessellatorPartitioning(DXIL::TessellatorPartitioning TessPartitioning);
  244. DXIL::TessellatorOutputPrimitive GetTessellatorOutputPrimitive() const;
  245. void SetTessellatorOutputPrimitive(DXIL::TessellatorOutputPrimitive TessOutputPrimitive);
  246. float GetMaxTessellationFactor() const;
  247. void SetMaxTessellationFactor(float MaxTessellationFactor);
  248. private:
  249. // Signatures.
  250. std::unique_ptr<DxilSignature> m_InputSignature;
  251. std::unique_ptr<DxilSignature> m_OutputSignature;
  252. std::unique_ptr<DxilSignature> m_PatchConstantSignature;
  253. std::unique_ptr<RootSignatureHandle> m_RootSignature;
  254. // Shader resources.
  255. std::vector<std::unique_ptr<DxilResource> > m_SRVs;
  256. std::vector<std::unique_ptr<DxilResource> > m_UAVs;
  257. std::vector<std::unique_ptr<DxilCBuffer> > m_CBuffers;
  258. std::vector<std::unique_ptr<DxilSampler> > m_Samplers;
  259. // Geometry shader.
  260. DXIL::InputPrimitive m_InputPrimitive;
  261. unsigned m_MaxVertexCount;
  262. DXIL::PrimitiveTopology m_StreamPrimitiveTopology;
  263. unsigned m_ActiveStreamMask;
  264. unsigned m_NumGSInstances;
  265. // Hull and Domain shaders.
  266. unsigned m_InputControlPointCount;
  267. DXIL::TessellatorDomain m_TessellatorDomain;
  268. // Hull shader.
  269. unsigned m_OutputControlPointCount;
  270. DXIL::TessellatorPartitioning m_TessellatorPartitioning;
  271. DXIL::TessellatorOutputPrimitive m_TessellatorOutputPrimitive;
  272. float m_MaxTessellationFactor;
  273. private:
  274. llvm::LLVMContext &m_Ctx;
  275. llvm::Module *m_pModule;
  276. llvm::Function *m_pEntryFunc;
  277. llvm::Function *m_pPatchConstantFunc;
  278. std::string m_EntryName;
  279. std::unique_ptr<DxilMDHelper> m_pMDHelper;
  280. std::unique_ptr<llvm::DebugInfoFinder> m_pDebugInfoFinder;
  281. const ShaderModel *m_pSM;
  282. unsigned m_DxilMajor;
  283. unsigned m_DxilMinor;
  284. unsigned m_ValMajor;
  285. unsigned m_ValMinor;
  286. std::unique_ptr<OP> m_pOP;
  287. size_t m_pUnused;
  288. // LLVM used.
  289. std::vector<llvm::GlobalVariable*> m_LLVMUsed;
  290. // Type annotations.
  291. std::unique_ptr<DxilTypeSystem> m_pTypeSystem;
  292. // ViewId state.
  293. std::unique_ptr<DxilViewIdState> m_pViewIdState;
  294. // DXIL metadata serialization/deserialization.
  295. llvm::MDTuple *EmitDxilResources();
  296. void LoadDxilResources(const llvm::MDOperand &MDO);
  297. llvm::MDTuple *EmitDxilShaderProperties();
  298. void LoadDxilShaderProperties(const llvm::MDOperand &MDO);
  299. // Helpers.
  300. template<typename T> unsigned AddResource(std::vector<std::unique_ptr<T> > &Vec, std::unique_ptr<T> pRes);
  301. void LoadDxilSignature(const llvm::MDTuple *pSigTuple, DxilSignature &Sig, bool bInput);
  302. };
  303. } // namespace hlsl