DxilContainerReflection.cpp 87 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilContainerReflection.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides support for reading DXIL container structures. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/Bitcode/ReaderWriter.h"
  12. #include "llvm/IR/LLVMContext.h"
  13. #include "llvm/IR/InstIterator.h"
  14. #include "dxc/DxilContainer/DxilContainer.h"
  15. #include "dxc/DXIL/DxilModule.h"
  16. #include "dxc/DXIL/DxilShaderModel.h"
  17. #include "dxc/DXIL/DxilOperations.h"
  18. #include "dxc/DXIL/DxilInstructions.h"
  19. #include "dxc/Support/Global.h"
  20. #include "dxc/Support/Unicode.h"
  21. #include "dxc/Support/WinIncludes.h"
  22. #include "dxc/Support/microcom.h"
  23. #include "dxc/Support/FileIOHelper.h"
  24. #include "dxc/Support/dxcapi.impl.h"
  25. #include "dxc/DXIL/DxilFunctionProps.h"
  26. #include <unordered_set>
  27. #include "llvm/ADT/SetVector.h"
  28. #include "dxc/dxcapi.h"
  29. #ifdef LLVM_ON_WIN32
  30. #include "d3d12shader.h" // for compatibility
  31. #include "d3d11shader.h" // for compatibility
  32. const GUID IID_ID3D11ShaderReflection_43 = {
  33. 0x0a233719,
  34. 0x3960,
  35. 0x4578,
  36. {0x9d, 0x7c, 0x20, 0x3b, 0x8b, 0x1d, 0x9c, 0xc1}};
  37. const GUID IID_ID3D11ShaderReflection_47 = {
  38. 0x8d536ca1,
  39. 0x0cca,
  40. 0x4956,
  41. {0xa8, 0x37, 0x78, 0x69, 0x63, 0x75, 0x55, 0x84}};
  42. using namespace llvm;
  43. using namespace hlsl;
  44. using namespace hlsl::DXIL;
  45. class DxilContainerReflection : public IDxcContainerReflection {
  46. private:
  47. DXC_MICROCOM_TM_REF_FIELDS()
  48. CComPtr<IDxcBlob> m_container;
  49. const DxilContainerHeader *m_pHeader = nullptr;
  50. uint32_t m_headerLen = 0;
  51. bool IsLoaded() const { return m_pHeader != nullptr; }
  52. public:
  53. DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
  54. DXC_MICROCOM_TM_CTOR(DxilContainerReflection)
  55. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  56. return DoBasicQueryInterface<IDxcContainerReflection>(this, iid, ppvObject);
  57. }
  58. HRESULT STDMETHODCALLTYPE Load(_In_ IDxcBlob *pContainer) override;
  59. HRESULT STDMETHODCALLTYPE GetPartCount(_Out_ UINT32 *pResult) override;
  60. HRESULT STDMETHODCALLTYPE GetPartKind(UINT32 idx, _Out_ UINT32 *pResult) override;
  61. HRESULT STDMETHODCALLTYPE GetPartContent(UINT32 idx, _COM_Outptr_ IDxcBlob **ppResult) override;
  62. HRESULT STDMETHODCALLTYPE FindFirstPartKind(UINT32 kind, _Out_ UINT32 *pResult) override;
  63. HRESULT STDMETHODCALLTYPE GetPartReflection(UINT32 idx, REFIID iid, _COM_Outptr_ void **ppvObject) override;
  64. };
  65. class CShaderReflectionConstantBuffer;
  66. class CShaderReflectionType;
  // Which public reflection API flavor a shader reflection object is serving.
  // The value is derived from the IID the client asked for (see
  // DxilShaderReflection::IIDToAPI).
  enum class PublicAPI {
    D3D12 = 0,    // ID3D12ShaderReflection
    D3D11_47 = 1, // IID_ID3D11ShaderReflection_47
    D3D11_43 = 2  // IID_ID3D11ShaderReflection_43
  };
  68. class DxilModuleReflection {
  69. public:
  70. CComPtr<IDxcBlob> m_pContainer;
  71. LLVMContext Context;
  72. std::unique_ptr<Module> m_pModule; // Must come after LLVMContext, otherwise unique_ptr will over-delete.
  73. DxilModule *m_pDxilModule = nullptr;
  74. std::vector<std::unique_ptr<CShaderReflectionConstantBuffer>> m_CBs;
  75. std::vector<D3D12_SHADER_INPUT_BIND_DESC> m_Resources;
  76. std::vector<std::unique_ptr<CShaderReflectionType>> m_Types;
  77. void CreateReflectionObjects();
  78. void CreateReflectionObjectForResource(DxilResourceBase *R);
  79. HRESULT LoadModule(IDxcBlob *pBlob, const DxilPartHeader *pPart);
  80. // Common code
  81. ID3D12ShaderReflectionConstantBuffer* _GetConstantBufferByIndex(UINT Index);
  82. ID3D12ShaderReflectionConstantBuffer* _GetConstantBufferByName(LPCSTR Name);
  83. HRESULT _GetResourceBindingDesc(UINT ResourceIndex,
  84. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc,
  85. PublicAPI api = PublicAPI::D3D12);
  86. ID3D12ShaderReflectionVariable* _GetVariableByName(LPCSTR Name);
  87. HRESULT _GetResourceBindingDescByName(LPCSTR Name,
  88. D3D12_SHADER_INPUT_BIND_DESC *pDesc,
  89. PublicAPI api = PublicAPI::D3D12);
  90. };
  91. class DxilShaderReflection : public DxilModuleReflection, public ID3D12ShaderReflection {
  92. private:
  93. DXC_MICROCOM_TM_REF_FIELDS()
  94. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> m_InputSignature;
  95. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> m_OutputSignature;
  96. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> m_PatchConstantSignature;
  97. std::vector<std::unique_ptr<char[]>> m_UpperCaseNames;
  98. void SetCBufferUsage();
  99. void CreateReflectionObjectsForSignature(
  100. const DxilSignature &Sig,
  101. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> &Descs);
  102. LPCSTR CreateUpperCase(LPCSTR pValue);
  103. void MarkUsedSignatureElements();
  104. public:
  105. PublicAPI m_PublicAPI;
  106. void SetPublicAPI(PublicAPI value) { m_PublicAPI = value; }
  107. static PublicAPI IIDToAPI(REFIID iid) {
  108. PublicAPI api = PublicAPI::D3D12;
  109. if (IsEqualIID(IID_ID3D11ShaderReflection_43, iid))
  110. api = PublicAPI::D3D11_43;
  111. else if (IsEqualIID(IID_ID3D11ShaderReflection_47, iid))
  112. api = PublicAPI::D3D11_47;
  113. return api;
  114. }
  115. DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
  116. DXC_MICROCOM_TM_CTOR(DxilShaderReflection)
  117. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  118. HRESULT hr = DoBasicQueryInterface<ID3D12ShaderReflection>(this, iid, ppvObject);
  119. if (hr == E_NOINTERFACE) {
  120. // ID3D11ShaderReflection is identical to ID3D12ShaderReflection, except
  121. // for some shorter data structures in some out parameters.
  122. PublicAPI api = IIDToAPI(iid);
  123. if (api == m_PublicAPI) {
  124. *ppvObject = (ID3D12ShaderReflection *)this;
  125. this->AddRef();
  126. hr = S_OK;
  127. }
  128. }
  129. return hr;
  130. }
  131. HRESULT Load(IDxcBlob *pBlob, const DxilPartHeader *pPart);
  132. // ID3D12ShaderReflection
  133. STDMETHODIMP GetDesc(THIS_ _Out_ D3D12_SHADER_DESC *pDesc);
  134. STDMETHODIMP_(ID3D12ShaderReflectionConstantBuffer*) GetConstantBufferByIndex(THIS_ _In_ UINT Index);
  135. STDMETHODIMP_(ID3D12ShaderReflectionConstantBuffer*) GetConstantBufferByName(THIS_ _In_ LPCSTR Name);
  136. STDMETHODIMP GetResourceBindingDesc(THIS_ _In_ UINT ResourceIndex,
  137. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc);
  138. STDMETHODIMP GetInputParameterDesc(THIS_ _In_ UINT ParameterIndex,
  139. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc);
  140. STDMETHODIMP GetOutputParameterDesc(THIS_ _In_ UINT ParameterIndex,
  141. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc);
  142. STDMETHODIMP GetPatchConstantParameterDesc(THIS_ _In_ UINT ParameterIndex,
  143. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc);
  144. STDMETHODIMP_(ID3D12ShaderReflectionVariable*) GetVariableByName(THIS_ _In_ LPCSTR Name);
  145. STDMETHODIMP GetResourceBindingDescByName(THIS_ _In_ LPCSTR Name,
  146. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc);
  147. STDMETHODIMP_(UINT) GetMovInstructionCount(THIS);
  148. STDMETHODIMP_(UINT) GetMovcInstructionCount(THIS);
  149. STDMETHODIMP_(UINT) GetConversionInstructionCount(THIS);
  150. STDMETHODIMP_(UINT) GetBitwiseInstructionCount(THIS);
  151. STDMETHODIMP_(D3D_PRIMITIVE) GetGSInputPrimitive(THIS);
  152. STDMETHODIMP_(BOOL) IsSampleFrequencyShader(THIS);
  153. STDMETHODIMP_(UINT) GetNumInterfaceSlots(THIS);
  154. STDMETHODIMP GetMinFeatureLevel(THIS_ _Out_ enum D3D_FEATURE_LEVEL* pLevel);
  155. STDMETHODIMP_(UINT) GetThreadGroupSize(THIS_
  156. _Out_opt_ UINT* pSizeX,
  157. _Out_opt_ UINT* pSizeY,
  158. _Out_opt_ UINT* pSizeZ);
  159. STDMETHODIMP_(UINT64) GetRequiresFlags(THIS);
  160. };
  161. class CFunctionReflection;
  162. class DxilLibraryReflection : public DxilModuleReflection, public ID3D12LibraryReflection {
  163. private:
  164. DXC_MICROCOM_TM_REF_FIELDS()
  165. // Storage, and function by name:
  166. typedef DenseMap<StringRef, std::unique_ptr<CFunctionReflection> > FunctionMap;
  167. typedef DenseMap<const Function*, CFunctionReflection*> FunctionsByPtr;
  168. FunctionMap m_FunctionMap;
  169. FunctionsByPtr m_FunctionsByPtr;
  170. // Enable indexing into functions in deterministic order:
  171. std::vector<CFunctionReflection*> m_FunctionVector;
  172. void AddResourceUseToFunctions(DxilResourceBase &resource, unsigned resIndex);
  173. void AddResourceDependencies();
  174. public:
  175. DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
  176. DXC_MICROCOM_TM_CTOR(DxilLibraryReflection)
  177. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  178. return DoBasicQueryInterface<ID3D12LibraryReflection>(this, iid, ppvObject);
  179. }
  180. HRESULT Load(IDxcBlob *pBlob, const DxilPartHeader *pPart);
  181. // ID3D12LibraryReflection
  182. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_LIBRARY_DESC * pDesc);
  183. STDMETHOD_(ID3D12FunctionReflection *, GetFunctionByIndex)(THIS_ _In_ INT FunctionIndex);
  184. };
  185. _Use_decl_annotations_
  186. HRESULT DxilContainerReflection::Load(IDxcBlob *pContainer) {
  187. if (pContainer == nullptr) {
  188. m_container.Release();
  189. m_pHeader = nullptr;
  190. m_headerLen = 0;
  191. return S_OK;
  192. }
  193. uint32_t bufLen = pContainer->GetBufferSize();
  194. const DxilContainerHeader *pHeader =
  195. IsDxilContainerLike(pContainer->GetBufferPointer(), bufLen);
  196. if (pHeader == nullptr) {
  197. return E_INVALIDARG;
  198. }
  199. if (!IsValidDxilContainer(pHeader, bufLen)) {
  200. return E_INVALIDARG;
  201. }
  202. m_container = pContainer;
  203. m_headerLen = bufLen;
  204. m_pHeader = pHeader;
  205. return S_OK;
  206. }
  207. _Use_decl_annotations_
  208. HRESULT DxilContainerReflection::GetPartCount(UINT32 *pResult) {
  209. if (pResult == nullptr) return E_POINTER;
  210. if (!IsLoaded()) return E_NOT_VALID_STATE;
  211. *pResult = m_pHeader->PartCount;
  212. return S_OK;
  213. }
  214. _Use_decl_annotations_
  215. HRESULT DxilContainerReflection::GetPartKind(UINT32 idx, _Out_ UINT32 *pResult) {
  216. if (pResult == nullptr) return E_POINTER;
  217. if (!IsLoaded()) return E_NOT_VALID_STATE;
  218. if (idx >= m_pHeader->PartCount) return E_BOUNDS;
  219. const DxilPartHeader *pPart = GetDxilContainerPart(m_pHeader, idx);
  220. *pResult = pPart->PartFourCC;
  221. return S_OK;
  222. }
  223. _Use_decl_annotations_
  224. HRESULT DxilContainerReflection::GetPartContent(UINT32 idx, _COM_Outptr_ IDxcBlob **ppResult) {
  225. if (ppResult == nullptr) return E_POINTER;
  226. *ppResult = nullptr;
  227. if (!IsLoaded()) return E_NOT_VALID_STATE;
  228. if (idx >= m_pHeader->PartCount) return E_BOUNDS;
  229. const DxilPartHeader *pPart = GetDxilContainerPart(m_pHeader, idx);
  230. const char *pData = GetDxilPartData(pPart);
  231. uint32_t offset = (uint32_t)(pData - (char*)m_container->GetBufferPointer()); // Offset from the beginning.
  232. uint32_t length = pPart->PartSize;
  233. DxcThreadMalloc TM(m_pMalloc);
  234. return DxcCreateBlobFromBlob(m_container, offset, length, ppResult);
  235. }
  236. _Use_decl_annotations_
  237. HRESULT DxilContainerReflection::FindFirstPartKind(UINT32 kind, _Out_ UINT32 *pResult) {
  238. if (pResult == nullptr) return E_POINTER;
  239. *pResult = 0;
  240. if (!IsLoaded()) return E_NOT_VALID_STATE;
  241. DxilPartIterator it = std::find_if(begin(m_pHeader), end(m_pHeader), DxilPartIsType(kind));
  242. if (it == end(m_pHeader)) return HRESULT_FROM_WIN32(ERROR_NOT_FOUND);
  243. *pResult = it.index;
  244. return S_OK;
  245. }
  246. _Use_decl_annotations_
  247. HRESULT DxilContainerReflection::GetPartReflection(UINT32 idx, REFIID iid, void **ppvObject) {
  248. if (ppvObject == nullptr) return E_POINTER;
  249. *ppvObject = nullptr;
  250. if (!IsLoaded()) return E_NOT_VALID_STATE;
  251. if (idx >= m_pHeader->PartCount) return E_BOUNDS;
  252. const DxilPartHeader *pPart = GetDxilContainerPart(m_pHeader, idx);
  253. if (pPart->PartFourCC != DFCC_DXIL && pPart->PartFourCC != DFCC_ShaderDebugInfoDXIL) {
  254. return E_NOTIMPL;
  255. }
  256. DxcThreadMalloc TM(m_pMalloc);
  257. HRESULT hr = S_OK;
  258. const DxilProgramHeader *pProgramHeader =
  259. reinterpret_cast<const DxilProgramHeader*>(GetDxilPartData(pPart));
  260. if (!IsValidDxilProgramHeader(pProgramHeader, pPart->PartSize)) {
  261. return E_INVALIDARG;
  262. }
  263. DXIL::ShaderKind SK = GetVersionShaderType(pProgramHeader->ProgramVersion);
  264. if (SK == DXIL::ShaderKind::Library) {
  265. CComPtr<DxilLibraryReflection> pReflection = DxilLibraryReflection::Alloc(m_pMalloc);
  266. IFCOOM(pReflection.p);
  267. IFC(pReflection->Load(m_container, pPart));
  268. IFC(pReflection.p->QueryInterface(iid, ppvObject));
  269. } else {
  270. CComPtr<DxilShaderReflection> pReflection = DxilShaderReflection::Alloc(m_pMalloc);
  271. IFCOOM(pReflection.p);
  272. PublicAPI api = DxilShaderReflection::IIDToAPI(iid);
  273. pReflection->SetPublicAPI(api);
  274. IFC(pReflection->Load(m_container, pPart));
  275. IFC(pReflection.p->QueryInterface(iid, ppvObject));
  276. }
  277. Cleanup:
  278. return hr;
  279. }
  280. void hlsl::CreateDxcContainerReflection(IDxcContainerReflection **ppResult) {
  281. CComPtr<DxilContainerReflection> pReflection = DxilContainerReflection::Alloc(DxcGetThreadMallocNoRef());
  282. *ppResult = pReflection.Detach();
  283. if (*ppResult == nullptr) throw std::bad_alloc();
  284. }
  285. ///////////////////////////////////////////////////////////////////////////////
  286. // DxilShaderReflection implementation - helper objects. //
  287. class CShaderReflectionType;
  288. class CShaderReflectionVariable;
  289. class CShaderReflectionConstantBuffer;
  290. class CShaderReflection;
  291. struct D3D11_INTERNALSHADER_RESOURCE_DEF;
  292. class CShaderReflectionType : public ID3D12ShaderReflectionType
  293. {
  294. protected:
  295. D3D12_SHADER_TYPE_DESC m_Desc;
  296. std::string m_Name;
  297. std::vector<StringRef> m_MemberNames;
  298. std::vector<CShaderReflectionType*> m_MemberTypes;
  299. CShaderReflectionType* m_pSubType;
  300. CShaderReflectionType* m_pBaseClass;
  301. std::vector<CShaderReflectionType*> m_Interfaces;
  302. ULONG_PTR m_Identity;
  303. public:
  304. // Internal
  305. HRESULT Initialize(
  306. DxilModule &M,
  307. llvm::Type *type,
  308. DxilFieldAnnotation &typeAnnotation,
  309. unsigned int baseOffset,
  310. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
  311. // ID3D12ShaderReflectionType
  312. STDMETHOD(GetDesc)(D3D12_SHADER_TYPE_DESC *pDesc);
  313. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByIndex)(UINT Index);
  314. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByName)(LPCSTR Name);
  315. STDMETHOD_(LPCSTR, GetMemberTypeName)(UINT Index);
  316. STDMETHOD(IsEqual)(THIS_ ID3D12ShaderReflectionType* pType);
  317. STDMETHOD_(ID3D12ShaderReflectionType*, GetSubType)(THIS);
  318. STDMETHOD_(ID3D12ShaderReflectionType*, GetBaseClass)(THIS);
  319. STDMETHOD_(UINT, GetNumInterfaces)(THIS);
  320. STDMETHOD_(ID3D12ShaderReflectionType*, GetInterfaceByIndex)(THIS_ UINT uIndex);
  321. STDMETHOD(IsOfType)(THIS_ ID3D12ShaderReflectionType* pType);
  322. STDMETHOD(ImplementsInterface)(THIS_ ID3D12ShaderReflectionType* pBase);
  323. bool CheckEqual(_In_ CShaderReflectionType *pOther) {
  324. return m_Identity == pOther->m_Identity;
  325. }
  326. };
  327. class CShaderReflectionVariable : public ID3D12ShaderReflectionVariable
  328. {
  329. protected:
  330. D3D12_SHADER_VARIABLE_DESC m_Desc;
  331. CShaderReflectionType *m_pType;
  332. CShaderReflectionConstantBuffer *m_pBuffer;
  333. BYTE *m_pDefaultValue;
  334. public:
  335. void Initialize(CShaderReflectionConstantBuffer *pBuffer,
  336. D3D12_SHADER_VARIABLE_DESC *pDesc,
  337. CShaderReflectionType *pType, BYTE *pDefaultValue);
  338. LPCSTR GetName() { return m_Desc.Name; }
  339. // ID3D12ShaderReflectionVariable
  340. STDMETHOD(GetDesc)(D3D12_SHADER_VARIABLE_DESC *pDesc);
  341. STDMETHOD_(ID3D12ShaderReflectionType*, GetType)();
  342. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer*, GetBuffer)();
  343. STDMETHOD_(UINT, GetInterfaceSlot)(THIS_ UINT uArrayIndex);
  344. };
  345. class CShaderReflectionConstantBuffer : public ID3D12ShaderReflectionConstantBuffer
  346. {
  347. protected:
  348. D3D12_SHADER_BUFFER_DESC m_Desc;
  349. std::vector<CShaderReflectionVariable> m_Variables;
  350. public:
  351. CShaderReflectionConstantBuffer() = default;
  352. CShaderReflectionConstantBuffer(CShaderReflectionConstantBuffer &&other) {
  353. m_Desc = other.m_Desc;
  354. std::swap(m_Variables, other.m_Variables);
  355. }
  356. void Initialize(DxilModule &M,
  357. DxilCBuffer &CB,
  358. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
  359. void InitializeStructuredBuffer(DxilModule &M,
  360. DxilResource &R,
  361. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
  362. LPCSTR GetName() { return m_Desc.Name; }
  363. // ID3D12ShaderReflectionConstantBuffer
  364. STDMETHOD(GetDesc)(D3D12_SHADER_BUFFER_DESC *pDesc);
  365. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByIndex)(UINT Index);
  366. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByName)(LPCSTR Name);
  367. };
  368. // Invalid type sentinel definitions
  369. class CInvalidSRType;
  370. class CInvalidSRVariable;
  371. class CInvalidSRConstantBuffer;
  372. class CInvalidSRLibraryFunction;
  373. class CInvalidSRFunctionParameter;
  374. class CInvalidSRType : public ID3D12ShaderReflectionType {
  375. STDMETHOD(GetDesc)(D3D12_SHADER_TYPE_DESC *pDesc) { return E_FAIL; }
  376. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByIndex)(UINT Index);
  377. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByName)(LPCSTR Name);
  378. STDMETHOD_(LPCSTR, GetMemberTypeName)(UINT Index) { return "$Invalid"; }
  379. STDMETHOD(IsEqual)(THIS_ ID3D12ShaderReflectionType* pType) { return E_FAIL; }
  380. STDMETHOD_(ID3D12ShaderReflectionType*, GetSubType)(THIS);
  381. STDMETHOD_(ID3D12ShaderReflectionType*, GetBaseClass)(THIS);
  382. STDMETHOD_(UINT, GetNumInterfaces)(THIS) { return 0; }
  383. STDMETHOD_(ID3D12ShaderReflectionType*, GetInterfaceByIndex)(THIS_ UINT uIndex);
  384. STDMETHOD(IsOfType)(THIS_ ID3D12ShaderReflectionType* pType) { return E_FAIL; }
  385. STDMETHOD(ImplementsInterface)(THIS_ ID3D12ShaderReflectionType* pBase) { return E_FAIL; }
  386. };
  387. static CInvalidSRType g_InvalidSRType;
  388. ID3D12ShaderReflectionType* CInvalidSRType::GetMemberTypeByIndex(UINT) { return &g_InvalidSRType; }
  389. ID3D12ShaderReflectionType* CInvalidSRType::GetMemberTypeByName(LPCSTR) { return &g_InvalidSRType; }
  390. ID3D12ShaderReflectionType* CInvalidSRType::GetSubType() { return &g_InvalidSRType; }
  391. ID3D12ShaderReflectionType* CInvalidSRType::GetBaseClass() { return &g_InvalidSRType; }
  392. ID3D12ShaderReflectionType* CInvalidSRType::GetInterfaceByIndex(UINT) { return &g_InvalidSRType; }
  393. class CInvalidSRVariable : public ID3D12ShaderReflectionVariable {
  394. STDMETHOD(GetDesc)(D3D12_SHADER_VARIABLE_DESC *pDesc) { return E_FAIL; }
  395. STDMETHOD_(ID3D12ShaderReflectionType*, GetType)() { return &g_InvalidSRType; }
  396. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer*, GetBuffer)();
  397. STDMETHOD_(UINT, GetInterfaceSlot)(THIS_ UINT uIndex) { return UINT_MAX; }
  398. };
  399. static CInvalidSRVariable g_InvalidSRVariable;
  400. class CInvalidSRConstantBuffer : public ID3D12ShaderReflectionConstantBuffer {
  401. STDMETHOD(GetDesc)(D3D12_SHADER_BUFFER_DESC *pDesc) { return E_FAIL; }
  402. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByIndex)(UINT Index) { return &g_InvalidSRVariable; }
  403. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByName)(LPCSTR Name) { return &g_InvalidSRVariable; }
  404. };
  405. static CInvalidSRConstantBuffer g_InvalidSRConstantBuffer;
  406. class CInvalidFunctionParameter : public ID3D12FunctionParameterReflection {
  407. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_PARAMETER_DESC * pDesc) { return E_FAIL; }
  408. };
  409. CInvalidFunctionParameter g_InvalidFunctionParameter;
  410. class CInvalidFunction : public ID3D12FunctionReflection {
  411. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_FUNCTION_DESC * pDesc) { return E_FAIL; }
  412. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByIndex)(THIS_ _In_ UINT BufferIndex) { return &g_InvalidSRConstantBuffer; }
  413. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByName)(THIS_ _In_ LPCSTR Name) { return &g_InvalidSRConstantBuffer; }
  414. STDMETHOD(GetResourceBindingDesc)(THIS_ _In_ UINT ResourceIndex,
  415. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc) { return E_FAIL; }
  416. STDMETHOD_(ID3D12ShaderReflectionVariable *, GetVariableByName)(THIS_ _In_ LPCSTR Name) { return nullptr; }
  417. STDMETHOD(GetResourceBindingDescByName)(THIS_ _In_ LPCSTR Name,
  418. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc) { return E_FAIL; }
  419. // Use D3D_RETURN_PARAMETER_INDEX to get description of the return value.
  420. STDMETHOD_(ID3D12FunctionParameterReflection *, GetFunctionParameter)(THIS_ _In_ INT ParameterIndex) { return &g_InvalidFunctionParameter; }
  421. };
  422. CInvalidFunction g_InvalidFunction;
  423. void CShaderReflectionVariable::Initialize(
  424. CShaderReflectionConstantBuffer *pBuffer, D3D12_SHADER_VARIABLE_DESC *pDesc,
  425. CShaderReflectionType *pType, BYTE *pDefaultValue) {
  426. m_pBuffer = pBuffer;
  427. memcpy(&m_Desc, pDesc, sizeof(m_Desc));
  428. m_pType = pType;
  429. m_pDefaultValue = pDefaultValue;
  430. }
  431. HRESULT CShaderReflectionVariable::GetDesc(D3D12_SHADER_VARIABLE_DESC *pDesc) {
  432. if (!pDesc) return E_POINTER;
  433. memcpy(pDesc, &m_Desc, sizeof(m_Desc));
  434. return S_OK;
  435. }
  436. ID3D12ShaderReflectionType *CShaderReflectionVariable::GetType() {
  437. return m_pType;
  438. }
  439. ID3D12ShaderReflectionConstantBuffer *CShaderReflectionVariable::GetBuffer() {
  440. return m_pBuffer;
  441. }
  442. UINT CShaderReflectionVariable::GetInterfaceSlot(UINT uArrayIndex) {
  443. return UINT_MAX;
  444. }
  445. ID3D12ShaderReflectionConstantBuffer *CInvalidSRVariable::GetBuffer() {
  446. return &g_InvalidSRConstantBuffer;
  447. }
  448. STDMETHODIMP CShaderReflectionType::GetDesc(D3D12_SHADER_TYPE_DESC *pDesc)
  449. {
  450. if (!pDesc) return E_POINTER;
  451. memcpy(pDesc, &m_Desc, sizeof(m_Desc));
  452. return S_OK;
  453. }
  454. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetMemberTypeByIndex(UINT Index)
  455. {
  456. if (Index >= m_MemberTypes.size()) {
  457. return &g_InvalidSRType;
  458. }
  459. return m_MemberTypes[Index];
  460. }
  461. STDMETHODIMP_(LPCSTR) CShaderReflectionType::GetMemberTypeName(UINT Index)
  462. {
  463. if (Index >= m_MemberTypes.size()) {
  464. return nullptr;
  465. }
  466. return (LPCSTR) m_MemberNames[Index].bytes_begin();
  467. }
  468. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetMemberTypeByName(LPCSTR Name)
  469. {
  470. UINT memberCount = m_Desc.Members;
  471. for( UINT mm = 0; mm < memberCount; ++mm ) {
  472. if( m_MemberNames[mm] == Name ) {
  473. return m_MemberTypes[mm];
  474. }
  475. }
  476. return nullptr;
  477. }
  478. STDMETHODIMP CShaderReflectionType::IsEqual(THIS_ ID3D12ShaderReflectionType* pType)
  479. {
  480. // TODO: implement this check, if users actually depend on it
  481. return S_FALSE;
  482. }
  483. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetSubType(THIS)
  484. {
  485. // TODO: implement `class`-related features, if requested
  486. return nullptr;
  487. }
  488. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetBaseClass(THIS)
  489. {
  490. // TODO: implement `class`-related features, if requested
  491. return nullptr;
  492. }
  493. STDMETHODIMP_(UINT) CShaderReflectionType::GetNumInterfaces(THIS)
  494. {
  495. // HLSL interfaces have been deprecated
  496. return 0;
  497. }
  498. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetInterfaceByIndex(THIS_ UINT uIndex)
  499. {
  500. // HLSL interfaces have been deprecated
  501. return nullptr;
  502. }
  503. STDMETHODIMP CShaderReflectionType::IsOfType(THIS_ ID3D12ShaderReflectionType* pType)
  504. {
  505. // TODO: implement `class`-related features, if requested
  506. return S_FALSE;
  507. }
  508. STDMETHODIMP CShaderReflectionType::ImplementsInterface(THIS_ ID3D12ShaderReflectionType* pBase)
  509. {
  510. // HLSL interfaces have been deprecated
  511. return S_FALSE;
  512. }
  513. // Helper routine for types that don't have an obvious mapping
  514. // to the existing shader reflection interface.
  515. static bool ProcessUnhandledObjectType(
  516. llvm::StructType *structType,
  517. D3D_SHADER_VARIABLE_TYPE *outObjectType)
  518. {
  519. // Don't actually make this a hard error, but instead report the problem using a suitable debug message.
  520. #ifdef DBG
  521. OutputDebugFormatA("DxilContainerReflection.cpp: error: unhandled object type '%s'.\n", structType->getName().str().c_str());
  522. #endif
  523. *outObjectType = D3D_SVT_VOID;
  524. return true;
  525. }
  526. // Helper routine to try to detect if a type represents an HLSL "object" type
  527. // (a texture, sampler, buffer, etc.), and to extract the coresponding shader
  528. // reflection type.
  529. static bool TryToDetectObjectType(
  530. llvm::StructType *structType,
  531. D3D_SHADER_VARIABLE_TYPE *outObjectType)
  532. {
  533. // Note: This logic is largely duplicated from `dxilutil::IsHLSLObjectType`
  534. // with the addition of returning the appropriate reflection type tag.
  535. //
  536. // That logic looks error-prone, since it relies on string tests against
  537. // type names, including cases that just test against a prefix.
  538. // This code doesn't try to be any more robust.
  539. StringRef name = structType->getName();
  540. if(name.startswith("dx.types.wave_t") )
  541. {
  542. return ProcessUnhandledObjectType(structType, outObjectType);
  543. }
  544. // Strip off some prefixes we are likely to see.
  545. name = name.ltrim("class.");
  546. name = name.ltrim("struct.");
  547. // Slice types occur as intermediates (they aren not objects)
  548. if(name.endswith("_slice_type")) { return false; }
  549. // We might check for an exact name match, or a prefix match
  550. #define EXACT_MATCH(NAME, TAG) \
  551. else if(name == #NAME) do { *outObjectType = TAG; return true; } while(0)
  552. #define PREFIX_MATCH(NAME, TAG) \
  553. else if(name.startswith(#NAME)) do { *outObjectType = TAG; return true; } while(0)
  554. if(0) {}
  555. EXACT_MATCH(SamplerState, D3D_SVT_SAMPLER);
  556. EXACT_MATCH(SamplerComparisonState, D3D_SVT_SAMPLER);
  557. // Note: GS output stream types are supported in the reflection interface.
  558. else if(name.startswith("TriangleStream")) { return ProcessUnhandledObjectType(structType, outObjectType); }
  559. else if(name.startswith("PointStream")) { return ProcessUnhandledObjectType(structType, outObjectType); }
  560. else if(name.startswith("LineStream")) { return ProcessUnhandledObjectType(structType, outObjectType); }
  561. PREFIX_MATCH(AppendStructuredBuffer, D3D_SVT_APPEND_STRUCTURED_BUFFER);
  562. PREFIX_MATCH(ConsumeStructuredBuffer, D3D_SVT_CONSUME_STRUCTURED_BUFFER);
  563. PREFIX_MATCH(ConstantBuffer, D3D_SVT_CBUFFER);
  564. // Note: the `HLModule` code does this trick to avoid checking more names
  565. // than it has to, but it doesn't seem 100% correct to do this.
  566. // TODO: consider just listing the `RasterizerOrdered` cases explicitly,
  567. // just as we do for the `RW` cases already.
  568. name = name.ltrim("RasterizerOrdered");
  569. if(0) {}
  570. EXACT_MATCH(ByteAddressBuffer, D3D_SVT_BYTEADDRESS_BUFFER);
  571. EXACT_MATCH(RWByteAddressBuffer, D3D_SVT_RWBYTEADDRESS_BUFFER);
  572. PREFIX_MATCH(Buffer, D3D_SVT_BUFFER);
  573. PREFIX_MATCH(RWBuffer, D3D_SVT_RWBUFFER);
  574. PREFIX_MATCH(StructuredBuffer, D3D_SVT_STRUCTURED_BUFFER);
  575. PREFIX_MATCH(RWStructuredBuffer, D3D_SVT_RWSTRUCTURED_BUFFER);
  576. PREFIX_MATCH(Texture1D, D3D_SVT_TEXTURE1D);
  577. PREFIX_MATCH(RWTexture1D, D3D_SVT_RWTEXTURE1D);
  578. PREFIX_MATCH(Texture1DArray, D3D_SVT_TEXTURE1DARRAY);
  579. PREFIX_MATCH(RWTexture1DArray, D3D_SVT_RWTEXTURE1DARRAY);
  580. PREFIX_MATCH(Texture2D, D3D_SVT_TEXTURE2D);
  581. PREFIX_MATCH(RWTexture2D, D3D_SVT_RWTEXTURE2D);
  582. PREFIX_MATCH(Texture2DArray, D3D_SVT_TEXTURE2DARRAY);
  583. PREFIX_MATCH(RWTexture2DArray, D3D_SVT_RWTEXTURE2DARRAY);
  584. PREFIX_MATCH(Texture3D, D3D_SVT_TEXTURE3D);
  585. PREFIX_MATCH(RWTexture3D, D3D_SVT_RWTEXTURE3D);
  586. PREFIX_MATCH(TextureCube, D3D_SVT_TEXTURECUBE);
  587. PREFIX_MATCH(TextureCubeArray, D3D_SVT_TEXTURECUBEARRAY);
  588. PREFIX_MATCH(Texture2DMS, D3D_SVT_TEXTURE2DMS);
  589. PREFIX_MATCH(Texture2DMSArray, D3D_SVT_TEXTURE2DMSARRAY);
  590. #undef EXACT_MATCH
  591. #undef PREFIX_MATCH
  592. // Default: not an object type
  593. return false;
  594. }
  595. // Helper to determine if an LLVM type represents an HLSL
  596. // object type (uses the `TryToDetectObjectType()` function
  597. // defined previously).
  598. static bool IsObjectType(
  599. llvm::Type* inType)
  600. {
  601. llvm::Type* type = inType;
  602. while(type->isArrayTy())
  603. {
  604. type = type->getArrayElementType();
  605. }
  606. llvm::StructType* structType = dyn_cast<StructType>(type);
  607. if(!structType)
  608. return false;
  609. D3D_SHADER_VARIABLE_TYPE ignored;
  610. return TryToDetectObjectType(structType, &ignored);
  611. }
  612. // Main logic for translating an LLVM type and associated
  613. // annotations into a D3D shader reflection type.
  614. HRESULT CShaderReflectionType::Initialize(
  615. DxilModule &M,
  616. llvm::Type *inType,
  617. DxilFieldAnnotation &typeAnnotation,
  618. unsigned int baseOffset,
  619. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes)
  620. {
  621. DXASSERT_NOMSG(inType);
  622. // Set a bunch of fields to default values, to avoid duplication.
  623. m_Desc.Rows = 0;
  624. m_Desc.Columns = 0;
  625. m_Desc.Elements = 0;
  626. m_Desc.Members = 0;
  627. // Extract offset relative to parent.
  628. // Note: the `baseOffset` is used in the case where the type in
  629. // question is a field in a constant buffer, since then both the
  630. // field and the variable store the same offset information, and
  631. // we need to zero out the value in the type to avoid the user
  632. // of the reflection interface seeing 2x the correct value.
  633. m_Desc.Offset = typeAnnotation.GetCBufferOffset() - baseOffset;
  634. // Arrays don't seem to be represented directly in the reflection
  635. // data, but only as the `Elements` field being non-zero.
  636. // We "unwrap" any array type here, and then proceed to look
  637. // at the element type.
  638. llvm::Type* type = inType;
  639. while(type->isArrayTy())
  640. {
  641. llvm::Type* elementType = type->getArrayElementType();
  642. // Note: At this point an HLSL matrix type may appear as an ordinary
  643. // array (not wrapped in a `struct`), so `dxilutil::IsHLSLMatrixType()`
  644. // is not sufficient. Instead we need to check the field annotation.
  645. //
  646. // We might have an array of matrices, though, so we only exit if
  647. // the field annotation says we have a matrix, and we've bottomed
  648. // out and the element type isn't itself an array.
  649. if(typeAnnotation.HasMatrixAnnotation() && !elementType->isArrayTy())
  650. {
  651. break;
  652. }
  653. // Non-array types should have `Elements` be zero, so as soon as we
  654. // find that we have our first real array (not a matrix), we initialize `Elements`
  655. if(!m_Desc.Elements) m_Desc.Elements = 1;
  656. // It isn't clear what is the desired behavior for multi-dimensional arrays,
  657. // but for now we do the expedient thing of multiplying out all their
  658. // dimensions.
  659. m_Desc.Elements *= type->getArrayNumElements();
  660. type = elementType;
  661. }
  662. // Default to a scalar type, just to avoid some duplication later.
  663. m_Desc.Class = D3D_SVC_SCALAR;
  664. // Look at the annotation to try to determine the basic type of value.
  665. //
  666. // Note that DXIL supports some types that don't currently have equivalents
  667. // in the reflection interface, so we try to muddle through here.
  668. D3D_SHADER_VARIABLE_TYPE componentType = D3D_SVT_VOID;
  669. switch(typeAnnotation.GetCompType().GetKind())
  670. {
  671. case hlsl::DXIL::ComponentType::Invalid:
  672. break;
  673. case hlsl::DXIL::ComponentType::I1:
  674. componentType = D3D_SVT_BOOL;
  675. m_Name = "bool";
  676. break;
  677. case hlsl::DXIL::ComponentType::I16:
  678. componentType = D3D_SVT_MIN16INT;
  679. m_Name = "min16int";
  680. break;
  681. case hlsl::DXIL::ComponentType::U16:
  682. componentType = D3D_SVT_MIN16UINT;
  683. m_Name = "min16uint";
  684. break;
  685. case hlsl::DXIL::ComponentType::I64:
  686. #ifdef DBG
  687. OutputDebugStringA("DxilContainerReflection.cpp: warning: component of type 'I64' being reflected as if 'I32'\n");
  688. #endif
  689. case hlsl::DXIL::ComponentType::I32:
  690. componentType = D3D_SVT_INT;
  691. m_Name = "int";
  692. break;
  693. case hlsl::DXIL::ComponentType::U64:
  694. #ifdef DBG
  695. OutputDebugStringA("DxilContainerReflection.cpp: warning: component of type 'U64' being reflected as if 'U32'\n");
  696. #endif
  697. case hlsl::DXIL::ComponentType::U32:
  698. componentType = D3D_SVT_UINT;
  699. m_Name = "uint";
  700. break;
  701. case hlsl::DXIL::ComponentType::F16:
  702. case hlsl::DXIL::ComponentType::SNormF16:
  703. case hlsl::DXIL::ComponentType::UNormF16:
  704. componentType = D3D_SVT_MIN16FLOAT;
  705. m_Name = "min16float";
  706. break;
  707. case hlsl::DXIL::ComponentType::F32:
  708. case hlsl::DXIL::ComponentType::SNormF32:
  709. case hlsl::DXIL::ComponentType::UNormF32:
  710. componentType = D3D_SVT_FLOAT;
  711. m_Name = "float";
  712. break;
  713. case hlsl::DXIL::ComponentType::F64:
  714. case hlsl::DXIL::ComponentType::SNormF64:
  715. case hlsl::DXIL::ComponentType::UNormF64:
  716. componentType = D3D_SVT_DOUBLE;
  717. m_Name = "double";
  718. break;
  719. default:
  720. #ifdef DBG
  721. OutputDebugStringA("DxilContainerReflection.cpp: error: unknown component type\n");
  722. #endif
  723. break;
  724. }
  725. m_Desc.Type = componentType;
  726. // A matrix type is encoded as a vector type, plus annotations, so we
  727. // need to check for this case before other vector cases.
  728. if(typeAnnotation.HasMatrixAnnotation())
  729. {
  730. // We can extract the details from the annotation.
  731. DxilMatrixAnnotation const& matrixAnnotation = typeAnnotation.GetMatrixAnnotation();
  732. switch(matrixAnnotation.Orientation)
  733. {
  734. default:
  735. #ifdef DBG
  736. OutputDebugStringA("DxilContainerReflection.cpp: error: unknown matrix orientation\n");
  737. #endif
  738. // Note: column-major layout is the default
  739. case hlsl::MatrixOrientation::Undefined:
  740. case hlsl::MatrixOrientation::ColumnMajor:
  741. m_Desc.Class = D3D_SVC_MATRIX_COLUMNS;
  742. break;
  743. case hlsl::MatrixOrientation::RowMajor:
  744. m_Desc.Class = D3D_SVC_MATRIX_ROWS;
  745. break;
  746. }
  747. m_Desc.Rows = matrixAnnotation.Rows;
  748. m_Desc.Columns = matrixAnnotation.Cols;
  749. m_Name += std::to_string(matrixAnnotation.Rows) + "x" + std::to_string(matrixAnnotation.Cols);
  750. }
  751. else if( type->isVectorTy() )
  752. {
  753. // We assume that LLVM vectors either represent matrices (handled above)
  754. // or HLSL vectors.
  755. //
  756. // Note: the reflection interface encodes an N-vector as if it had 1 row
  757. // and N columns.
  758. m_Desc.Class = D3D_SVC_VECTOR;
  759. m_Desc.Rows = 1;
  760. m_Desc.Columns = type->getVectorNumElements();
  761. m_Name += std::to_string(type->getVectorNumElements());
  762. }
  763. else if( type->isStructTy() )
  764. {
  765. // A struct type might be an ordinary user-defined `struct`,
  766. // or one of the builtin in HLSL "object" types.
  767. StructType *structType = cast<StructType>(type);
  768. // We use our function to try to detect an object type
  769. // based on its name.
  770. if(TryToDetectObjectType(structType, &m_Desc.Type))
  771. {
  772. m_Desc.Class = D3D_SVC_OBJECT;
  773. }
  774. else
  775. {
  776. // Otherwise we have a struct and need to recurse on its fields.
  777. m_Desc.Class = D3D_SVC_STRUCT;
  778. m_Desc.Rows = 1;
  779. // Try to "clean" the type name for use in reflection data
  780. llvm::StringRef name = structType->getName();
  781. name = name.ltrim("dx.alignment.legacy.");
  782. name = name.ltrim("struct.");
  783. m_Name = name;
  784. // Fields may have annotations, and we need to look at these
  785. // in order to decode their types properly.
  786. DxilTypeSystem &typeSys = M.GetTypeSystem();
  787. DxilStructAnnotation *structAnnotation = typeSys.GetStructAnnotation(structType);
  788. // There is no annotation for empty structs
  789. unsigned int fieldCount = 0;
  790. if (structAnnotation)
  791. fieldCount = type->getStructNumElements();
  792. // The DXBC reflection info computes `Columns` for a
  793. // `struct` type from the fields (see below)
  794. UINT columnCounter = 0;
  795. for(unsigned int ff = 0; ff < fieldCount; ++ff)
  796. {
  797. DxilFieldAnnotation& fieldAnnotation = structAnnotation->GetFieldAnnotation(ff);
  798. llvm::Type* fieldType = structType->getStructElementType(ff);
  799. // Skip fields with object types, since applications may not expect to see them here.
  800. //
  801. // TODO: should skipping be context-dependent, since we might not be inside
  802. // a constant buffer?
  803. if( IsObjectType(fieldType) )
  804. {
  805. continue;
  806. }
  807. CShaderReflectionType *fieldReflectionType = new CShaderReflectionType();
  808. allTypes.push_back(std::unique_ptr<CShaderReflectionType>(fieldReflectionType));
  809. fieldReflectionType->Initialize(M, fieldType, fieldAnnotation, 0, allTypes);
  810. m_MemberTypes.push_back(fieldReflectionType);
  811. m_MemberNames.push_back(fieldAnnotation.GetFieldName().c_str());
  812. // Effectively, we want to add one to `Columns` for every scalar nested recursively
  813. // inside this `struct` type (ignoring objects, which we filtered above). We should
  814. // be able to compute this as the product of the `Columns`, `Rows` and `Elements`
  815. // of each field, with the caveat that some of these may be zero, but shoud be
  816. // treated as one.
  817. columnCounter +=
  818. (fieldReflectionType->m_Desc.Columns ? fieldReflectionType->m_Desc.Columns : 1)
  819. * (fieldReflectionType->m_Desc.Rows ? fieldReflectionType->m_Desc.Rows : 1)
  820. * (fieldReflectionType->m_Desc.Elements ? fieldReflectionType->m_Desc.Elements : 1);
  821. }
  822. m_Desc.Columns = columnCounter;
  823. // Because we might have skipped fields during enumeration,
  824. // the `Members` count in the description might not be the same
  825. // as the field count of the original LLVM type.
  826. m_Desc.Members = m_MemberTypes.size();
  827. }
  828. }
  829. else if( type->isPointerTy() )
  830. {
  831. #ifdef DBG
  832. OutputDebugStringA("DxilContainerReflection.cpp: error: cannot reflect pointer type\n");
  833. #endif
  834. }
  835. else if( type->isVoidTy() )
  836. {
  837. // Name for `void` wasn't handle in the component-type `switch` above
  838. m_Name = "void";
  839. m_Desc.Class = D3D_SVC_SCALAR;
  840. m_Desc.Rows = 1;
  841. m_Desc.Columns = 1;
  842. }
  843. else
  844. {
  845. // Assume we have a scalar at this point.
  846. m_Desc.Class = D3D_SVC_SCALAR;
  847. m_Desc.Rows = 1;
  848. m_Desc.Columns = 1;
  849. // Special-case naming
  850. switch(m_Desc.Type)
  851. {
  852. default:
  853. break;
  854. case D3D_SVT_UINT:
  855. // Scalar `uint` gets reflected as `dword`, while vectors/matrices use `uint`...
  856. m_Name = "dword";
  857. break;
  858. }
  859. }
  860. // TODO: are there other cases to be handled?
  861. m_Desc.Name = m_Name.c_str();
  862. return S_OK;
  863. }
  864. void CShaderReflectionConstantBuffer::Initialize(
  865. DxilModule &M,
  866. DxilCBuffer &CB,
  867. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes) {
  868. ZeroMemory(&m_Desc, sizeof(m_Desc));
  869. m_Desc.Name = CB.GetGlobalName().c_str();
  870. m_Desc.Size = CB.GetSize() / CB.GetRangeSize();
  871. m_Desc.Size = (m_Desc.Size + 0x0f) & ~(0x0f); // Round up to 16 bytes for reflection.
  872. m_Desc.Type = D3D_CT_CBUFFER;
  873. m_Desc.uFlags = 0;
  874. Type *Ty = CB.GetGlobalSymbol()->getType()->getPointerElementType();
  875. // For ConstantBuffer<> buf[2], the array size is in Resource binding count
  876. // part.
  877. if (Ty->isArrayTy())
  878. Ty = Ty->getArrayElementType();
  879. DxilTypeSystem &typeSys = M.GetTypeSystem();
  880. StructType *ST = cast<StructType>(Ty);
  881. DxilStructAnnotation *annotation =
  882. typeSys.GetStructAnnotation(cast<StructType>(ST));
  883. // Dxil from dxbc doesn't have annotation.
  884. if (!annotation)
  885. return;
  886. m_Desc.Variables = ST->getNumContainedTypes();
  887. unsigned lastIndex = ST->getNumContainedTypes() - 1;
  888. for (unsigned i = 0; i < ST->getNumContainedTypes(); ++i) {
  889. DxilFieldAnnotation &fieldAnnotation = annotation->GetFieldAnnotation(i);
  890. D3D12_SHADER_VARIABLE_DESC VarDesc;
  891. ZeroMemory(&VarDesc, sizeof(VarDesc));
  892. VarDesc.uFlags |= D3D_SVF_USED; // Will update in SetCBufferUsage.
  893. CShaderReflectionVariable Var;
  894. //Create reflection type.
  895. CShaderReflectionType *pVarType = new CShaderReflectionType();
  896. allTypes.push_back(std::unique_ptr<CShaderReflectionType>(pVarType));
  897. pVarType->Initialize(M, ST->getContainedType(i), fieldAnnotation, fieldAnnotation.GetCBufferOffset(), allTypes);
  898. BYTE *pDefaultValue = nullptr;
  899. VarDesc.Name = fieldAnnotation.GetFieldName().c_str();
  900. VarDesc.StartOffset = fieldAnnotation.GetCBufferOffset();
  901. if (i < lastIndex) {
  902. DxilFieldAnnotation &nextFieldAnnotation =
  903. annotation->GetFieldAnnotation(i + 1);
  904. VarDesc.Size = nextFieldAnnotation.GetCBufferOffset() - fieldAnnotation.GetCBufferOffset();
  905. }
  906. else {
  907. VarDesc.Size = CB.GetSize() - fieldAnnotation.GetCBufferOffset();
  908. }
  909. Var.Initialize(this, &VarDesc, pVarType, pDefaultValue);
  910. m_Variables.push_back(Var);
  911. }
  912. }
  913. static unsigned CalcTypeSize(Type *Ty) {
  914. // Assume aligned values.
  915. if (Ty->isIntegerTy() || Ty->isFloatTy()) {
  916. return Ty->getPrimitiveSizeInBits() / 8;
  917. }
  918. else if (Ty->isArrayTy()) {
  919. ArrayType *AT = dyn_cast<ArrayType>(Ty);
  920. return AT->getNumElements() * CalcTypeSize(AT->getArrayElementType());
  921. }
  922. else if (Ty->isStructTy()) {
  923. StructType *ST = dyn_cast<StructType>(Ty);
  924. unsigned i = 0, c = ST->getStructNumElements();
  925. unsigned result = 0;
  926. for (; i < c; ++i) {
  927. result += CalcTypeSize(ST->getStructElementType(i));
  928. // TODO: align!
  929. }
  930. return result;
  931. }
  932. else if (Ty->isVectorTy()) {
  933. VectorType *VT = dyn_cast<VectorType>(Ty);
  934. return VT->getVectorNumElements() * CalcTypeSize(VT->getVectorElementType());
  935. }
  936. else {
  937. DXASSERT_NOMSG(false);
  938. return 0;
  939. }
  940. }
  941. static unsigned CalcResTypeSize(DxilModule &M, DxilResource &R) {
  942. UNREFERENCED_PARAMETER(M);
  943. Type *Ty = R.GetGlobalSymbol()->getType()->getPointerElementType();
  944. return CalcTypeSize(Ty);
  945. }
  946. void CShaderReflectionConstantBuffer::InitializeStructuredBuffer(
  947. DxilModule &M,
  948. DxilResource &R,
  949. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes) {
  950. ZeroMemory(&m_Desc, sizeof(m_Desc));
  951. m_Desc.Name = R.GetGlobalName().c_str();
  952. //m_Desc.Size = R.GetSize();
  953. m_Desc.Type = D3D11_CT_RESOURCE_BIND_INFO;
  954. m_Desc.uFlags = 0;
  955. m_Desc.Variables = 1;
  956. D3D12_SHADER_VARIABLE_DESC VarDesc;
  957. ZeroMemory(&VarDesc, sizeof(VarDesc));
  958. VarDesc.Name = "$Element";
  959. VarDesc.Size = CalcResTypeSize(M, R); // aligned bytes
  960. VarDesc.StartTexture = UINT_MAX;
  961. VarDesc.StartSampler = UINT_MAX;
  962. VarDesc.uFlags |= D3D_SVF_USED; // TODO: not necessarily true
  963. CShaderReflectionVariable Var;
  964. CShaderReflectionType *pVarType = nullptr;
  965. // Create reflection type, if we have the necessary annotation info
  966. // Extract the `struct` that wraps element type of the buffer resource
  967. Type *Ty = R.GetGlobalSymbol()->getType()->getPointerElementType();
  968. if(Ty->isArrayTy())
  969. Ty = Ty->getArrayElementType();
  970. StructType *ST = cast<StructType>(Ty);
  971. // Look up struct type annotation on the element type
  972. DxilTypeSystem &typeSys = M.GetTypeSystem();
  973. DxilStructAnnotation *annotation =
  974. typeSys.GetStructAnnotation(cast<StructType>(ST));
  975. // Dxil from dxbc doesn't have annotation.
  976. if(annotation)
  977. {
  978. // Actually create the reflection type.
  979. pVarType = new CShaderReflectionType();
  980. allTypes.push_back(std::unique_ptr<CShaderReflectionType>(pVarType));
  981. // The user-visible element type is the first field of the wrapepr `struct`
  982. Type *fieldType = ST->getElementType(0);
  983. DxilFieldAnnotation &fieldAnnotation = annotation->GetFieldAnnotation(0);
  984. pVarType->Initialize(M, fieldType, fieldAnnotation, fieldAnnotation.GetCBufferOffset(), allTypes);
  985. }
  986. BYTE *pDefaultValue = nullptr;
  987. Var.Initialize(this, &VarDesc, pVarType, pDefaultValue);
  988. m_Variables.push_back(Var);
  989. m_Desc.Size = VarDesc.Size;
  990. }
  991. HRESULT CShaderReflectionConstantBuffer::GetDesc(D3D12_SHADER_BUFFER_DESC *pDesc) {
  992. if (!pDesc)
  993. return E_POINTER;
  994. memcpy(pDesc, &m_Desc, sizeof(m_Desc));
  995. return S_OK;
  996. }
  997. ID3D12ShaderReflectionVariable *
  998. CShaderReflectionConstantBuffer::GetVariableByIndex(UINT Index) {
  999. if (Index >= m_Variables.size()) {
  1000. return &g_InvalidSRVariable;
  1001. }
  1002. return &m_Variables[Index];
  1003. }
  1004. ID3D12ShaderReflectionVariable *
  1005. CShaderReflectionConstantBuffer::GetVariableByName(LPCSTR Name) {
  1006. UINT index;
  1007. if (NULL == Name) {
  1008. return &g_InvalidSRVariable;
  1009. }
  1010. for (index = 0; index < m_Variables.size(); ++index) {
  1011. if (0 == strcmp(m_Variables[index].GetName(), Name)) {
  1012. return &m_Variables[index];
  1013. }
  1014. }
  1015. return &g_InvalidSRVariable;
  1016. }
  1017. ///////////////////////////////////////////////////////////////////////////////
  1018. // DxilShaderReflection implementation. //
  1019. static DxilResource *DxilResourceFromBase(DxilResourceBase *RB) {
  1020. DxilResourceBase::Class C = RB->GetClass();
  1021. if (C == DXIL::ResourceClass::UAV || C == DXIL::ResourceClass::SRV)
  1022. return (DxilResource *)RB;
  1023. return nullptr;
  1024. }
  1025. static D3D_SHADER_INPUT_TYPE ResourceToShaderInputType(DxilResourceBase *RB) {
  1026. DxilResource *R = DxilResourceFromBase(RB);
  1027. bool isUAV = RB->GetClass() == DxilResourceBase::Class::UAV;
  1028. switch (RB->GetKind()) {
  1029. case DxilResource::Kind::CBuffer:
  1030. return D3D_SIT_CBUFFER;
  1031. case DxilResource::Kind::Sampler:
  1032. return D3D_SIT_SAMPLER;
  1033. case DxilResource::Kind::RawBuffer:
  1034. return isUAV ? D3D_SIT_UAV_RWBYTEADDRESS : D3D_SIT_BYTEADDRESS;
  1035. case DxilResource::Kind::StructuredBuffer: {
  1036. if (!isUAV) return D3D_SIT_STRUCTURED;
  1037. // TODO: D3D_SIT_UAV_CONSUME_STRUCTURED, D3D_SIT_UAV_APPEND_STRUCTURED?
  1038. if (R->HasCounter()) return D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER;
  1039. return D3D_SIT_UAV_RWSTRUCTURED;
  1040. }
  1041. case DxilResource::Kind::TBuffer:
  1042. case DxilResource::Kind::TypedBuffer:
  1043. case DxilResource::Kind::Texture1D:
  1044. case DxilResource::Kind::Texture1DArray:
  1045. case DxilResource::Kind::Texture2D:
  1046. case DxilResource::Kind::Texture2DArray:
  1047. case DxilResource::Kind::Texture2DMS:
  1048. case DxilResource::Kind::Texture2DMSArray:
  1049. case DxilResource::Kind::Texture3D:
  1050. case DxilResource::Kind::TextureCube:
  1051. case DxilResource::Kind::TextureCubeArray:
  1052. return isUAV ? D3D_SIT_UAV_RWTYPED : D3D_SIT_TEXTURE;
  1053. case DxilResource::Kind::RTAccelerationStructure:
  1054. return (D3D_SHADER_INPUT_TYPE)D3D_SIT_RTACCELERATIONSTRUCTURE;
  1055. default:
  1056. return (D3D_SHADER_INPUT_TYPE)-1;
  1057. }
  1058. }
  1059. static D3D_RESOURCE_RETURN_TYPE ResourceToReturnType(DxilResourceBase *RB) {
  1060. DxilResource *R = DxilResourceFromBase(RB);
  1061. if (R != nullptr) {
  1062. CompType CT = R->GetCompType();
  1063. if (CT.GetKind() == CompType::Kind::F64) return D3D_RETURN_TYPE_DOUBLE;
  1064. if (CT.IsUNorm()) return D3D_RETURN_TYPE_UNORM;
  1065. if (CT.IsSNorm()) return D3D_RETURN_TYPE_SNORM;
  1066. if (CT.IsSIntTy()) return D3D_RETURN_TYPE_SINT;
  1067. if (CT.IsUIntTy()) return D3D_RETURN_TYPE_UINT;
  1068. if (CT.IsFloatTy()) return D3D_RETURN_TYPE_FLOAT;
  1069. // D3D_RETURN_TYPE_CONTINUED: Return type is a multiple-dword type, such as a
  1070. // double or uint64, and the component is continued from the previous
  1071. // component that was declared. The first component represents the lower bits.
  1072. return D3D_RETURN_TYPE_MIXED;
  1073. }
  1074. return (D3D_RESOURCE_RETURN_TYPE)0;
  1075. }
  1076. static D3D_SRV_DIMENSION ResourceToDimension(DxilResourceBase *RB) {
  1077. switch (RB->GetKind()) {
  1078. case DxilResource::Kind::StructuredBuffer:
  1079. case DxilResource::Kind::TypedBuffer:
  1080. case DxilResource::Kind::TBuffer:
  1081. return D3D_SRV_DIMENSION_BUFFER;
  1082. case DxilResource::Kind::Texture1D:
  1083. return D3D_SRV_DIMENSION_TEXTURE1D;
  1084. case DxilResource::Kind::Texture1DArray:
  1085. return D3D_SRV_DIMENSION_TEXTURE1DARRAY;
  1086. case DxilResource::Kind::Texture2D:
  1087. return D3D_SRV_DIMENSION_TEXTURE2D;
  1088. case DxilResource::Kind::Texture2DArray:
  1089. return D3D_SRV_DIMENSION_TEXTURE2DARRAY;
  1090. case DxilResource::Kind::Texture2DMS:
  1091. return D3D_SRV_DIMENSION_TEXTURE2DMS;
  1092. case DxilResource::Kind::Texture2DMSArray:
  1093. return D3D_SRV_DIMENSION_TEXTURE2DMSARRAY;
  1094. case DxilResource::Kind::Texture3D:
  1095. return D3D_SRV_DIMENSION_TEXTURE3D;
  1096. case DxilResource::Kind::TextureCube:
  1097. return D3D_SRV_DIMENSION_TEXTURECUBE;
  1098. case DxilResource::Kind::TextureCubeArray:
  1099. return D3D_SRV_DIMENSION_TEXTURECUBEARRAY;
  1100. case DxilResource::Kind::RawBuffer:
  1101. return D3D11_SRV_DIMENSION_BUFFER; // D3D11_SRV_DIMENSION_BUFFEREX?
  1102. default:
  1103. return D3D_SRV_DIMENSION_UNKNOWN;
  1104. }
  1105. }
  1106. static UINT ResourceToFlags(DxilResourceBase *RB) {
  1107. UINT result = 0;
  1108. DxilResource *R = DxilResourceFromBase(RB);
  1109. if (R != nullptr &&
  1110. (R->IsAnyTexture() || R->GetKind() == DXIL::ResourceKind::TypedBuffer)) {
  1111. llvm::Type *RetTy = R->GetRetType();
  1112. if (VectorType *VT = dyn_cast<VectorType>(RetTy)) {
  1113. unsigned vecSize = VT->getNumElements();
  1114. switch (vecSize) {
  1115. case 4:
  1116. result |= D3D_SIF_TEXTURE_COMPONENTS;
  1117. break;
  1118. case 3:
  1119. result |= D3D_SIF_TEXTURE_COMPONENT_1;
  1120. break;
  1121. case 2:
  1122. result |= D3D_SIF_TEXTURE_COMPONENT_0;
  1123. break;
  1124. }
  1125. }
  1126. }
  1127. // D3D_SIF_USERPACKED
  1128. if (RB->GetClass() == DXIL::ResourceClass::Sampler) {
  1129. DxilSampler *S = static_cast<DxilSampler *>(RB);
  1130. if (S->GetSamplerKind() == DXIL::SamplerKind::Comparison)
  1131. result |= D3D_SIF_COMPARISON_SAMPLER;
  1132. }
  1133. return result;
  1134. }
  1135. void DxilModuleReflection::CreateReflectionObjectForResource(DxilResourceBase *RB) {
  1136. DxilResourceBase::Class C = RB->GetClass();
  1137. DxilResource *R =
  1138. (C == DXIL::ResourceClass::UAV || C == DXIL::ResourceClass::SRV)
  1139. ? (DxilResource *)RB
  1140. : nullptr;
  1141. D3D12_SHADER_INPUT_BIND_DESC inputBind;
  1142. ZeroMemory(&inputBind, sizeof(inputBind));
  1143. inputBind.BindCount = RB->GetRangeSize();
  1144. if (RB->GetRangeSize() == UINT_MAX)
  1145. inputBind.BindCount = 0;
  1146. inputBind.BindPoint = RB->GetLowerBound();
  1147. inputBind.Dimension = ResourceToDimension(RB);
  1148. inputBind.Name = RB->GetGlobalName().c_str();
  1149. inputBind.Type = ResourceToShaderInputType(RB);
  1150. if (R == nullptr) {
  1151. inputBind.NumSamples = 0;
  1152. }
  1153. else {
  1154. inputBind.NumSamples = R->GetSampleCount();
  1155. if (inputBind.NumSamples == 0) {
  1156. if (R->IsStructuredBuffer()) {
  1157. inputBind.NumSamples = CalcResTypeSize(*m_pDxilModule, *R);
  1158. }
  1159. else if (!R->IsRawBuffer()) {
  1160. inputBind.NumSamples = 0xFFFFFFFF;
  1161. }
  1162. }
  1163. }
  1164. inputBind.ReturnType = ResourceToReturnType(RB);
  1165. inputBind.Space = RB->GetSpaceID();
  1166. inputBind.uFlags = ResourceToFlags(RB);
  1167. inputBind.uID = RB->GetID();
  1168. m_Resources.push_back(inputBind);
  1169. }
  1170. // Find the imm offset part from a value.
  1171. // It must exist unless offset is 0.
  1172. static unsigned GetCBOffset(Value *V) {
  1173. if (ConstantInt *Imm = dyn_cast<ConstantInt>(V))
  1174. return Imm->getLimitedValue();
  1175. else if (UnaryInstruction *UI = dyn_cast<UnaryInstruction>(V)) {
  1176. return 0;
  1177. } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V)) {
  1178. switch (BO->getOpcode()) {
  1179. case Instruction::Add: {
  1180. unsigned left = GetCBOffset(BO->getOperand(0));
  1181. unsigned right = GetCBOffset(BO->getOperand(1));
  1182. return left + right;
  1183. } break;
  1184. case Instruction::Or: {
  1185. unsigned left = GetCBOffset(BO->getOperand(0));
  1186. unsigned right = GetCBOffset(BO->getOperand(1));
  1187. return left | right;
  1188. } break;
  1189. default:
  1190. return 0;
  1191. }
  1192. } else {
  1193. return 0;
  1194. }
  1195. }
  1196. void CollectInPhiChain(PHINode *cbUser, std::vector<unsigned> &cbufUsage,
  1197. unsigned offset, std::unordered_set<Value *> &userSet) {
  1198. if (userSet.count(cbUser) > 0)
  1199. return;
  1200. userSet.insert(cbUser);
  1201. for (User *cbU : cbUser->users()) {
  1202. if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(cbU)) {
  1203. for (unsigned idx : EV->getIndices()) {
  1204. cbufUsage.emplace_back(offset + idx * 4);
  1205. }
  1206. } else {
  1207. PHINode *phi = cast<PHINode>(cbU);
  1208. CollectInPhiChain(phi, cbufUsage, offset, userSet);
  1209. }
  1210. }
  1211. }
  1212. static void CollectCBufUsage(Value *cbHandle,
  1213. std::vector<unsigned> &cbufUsage) {
  1214. for (User *U : cbHandle->users()) {
  1215. CallInst *CI = cast<CallInst>(U);
  1216. ConstantInt *opcodeV =
  1217. cast<ConstantInt>(CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx));
  1218. DXIL::OpCode opcode = static_cast<DXIL::OpCode>(opcodeV->getLimitedValue());
  1219. if (opcode == DXIL::OpCode::CBufferLoadLegacy) {
  1220. DxilInst_CBufferLoadLegacy cbload(CI);
  1221. Value *resIndex = cbload.get_regIndex();
  1222. unsigned offset = GetCBOffset(resIndex);
  1223. // 16 bytes align.
  1224. offset <<= 4;
  1225. for (User *cbU : U->users()) {
  1226. if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(cbU)) {
  1227. for (unsigned idx : EV->getIndices()) {
  1228. cbufUsage.emplace_back(offset + idx * 4);
  1229. }
  1230. } else {
  1231. PHINode *phi = cast<PHINode>(cbU);
  1232. std::unordered_set<Value *> userSet;
  1233. CollectInPhiChain(phi, cbufUsage, offset, userSet);
  1234. }
  1235. }
  1236. } else if (opcode == DXIL::OpCode::CBufferLoad) {
  1237. DxilInst_CBufferLoad cbload(CI);
  1238. Value *byteOffset = cbload.get_byteOffset();
  1239. unsigned offset = GetCBOffset(byteOffset);
  1240. cbufUsage.emplace_back(offset);
  1241. } else {
  1242. //
  1243. DXASSERT(0, "invalid opcode");
  1244. }
  1245. }
  1246. }
  1247. static void SetCBufVarUsage(CShaderReflectionConstantBuffer &cb,
  1248. std::vector<unsigned> usage) {
  1249. D3D12_SHADER_BUFFER_DESC Desc;
  1250. if (FAILED(cb.GetDesc(&Desc)))
  1251. return;
  1252. unsigned size = Desc.Variables;
  1253. std::sort(usage.begin(), usage.end());
  1254. for (unsigned i = 0; i < size; i++) {
  1255. ID3D12ShaderReflectionVariable *pVar = cb.GetVariableByIndex(i);
  1256. D3D12_SHADER_VARIABLE_DESC VarDesc;
  1257. if (FAILED(pVar->GetDesc(&VarDesc)))
  1258. continue;
  1259. if (!pVar)
  1260. continue;
  1261. unsigned begin = VarDesc.StartOffset;
  1262. unsigned end = begin + VarDesc.Size;
  1263. auto beginIt = std::find_if(usage.begin(), usage.end(),
  1264. [&](unsigned v) { return v >= begin; });
  1265. auto endIt = std::find_if(usage.begin(), usage.end(),
  1266. [&](unsigned v) { return v >= end; });
  1267. bool used = beginIt != endIt;
  1268. // Clear used.
  1269. if (!used) {
  1270. CShaderReflectionType *pVarType = (CShaderReflectionType *)pVar->GetType();
  1271. BYTE *pDefaultValue = nullptr;
  1272. VarDesc.uFlags &= ~D3D_SVF_USED;
  1273. CShaderReflectionVariable *pCVarDesc = (CShaderReflectionVariable*)pVar;
  1274. pCVarDesc->Initialize(&cb, &VarDesc, pVarType, pDefaultValue);
  1275. }
  1276. }
  1277. }
  1278. void DxilShaderReflection::SetCBufferUsage() {
  1279. hlsl::OP *hlslOP = m_pDxilModule->GetOP();
  1280. LLVMContext &Ctx = m_pDxilModule->GetCtx();
  1281. // Indexes >= cbuffer size from DxilModule are SRV or UAV structured buffers.
  1282. // We only collect usage for actual cbuffers, so don't go clearing usage on other buffers.
  1283. unsigned cbSize = std::min(m_CBs.size(), m_pDxilModule->GetCBuffers().size());
  1284. std::vector< std::vector<unsigned> > cbufUsage(cbSize);
  1285. Function *createHandle = hlslOP->GetOpFunc(DXIL::OpCode::CreateHandle, Type::getVoidTy(Ctx));
  1286. if (createHandle->user_empty()) {
  1287. createHandle->eraseFromParent();
  1288. return;
  1289. }
  1290. // Find all cb handles.
  1291. for (User *U : createHandle->users()) {
  1292. DxilInst_CreateHandle handle(cast<CallInst>(U));
  1293. Value *resClass = handle.get_resourceClass();
  1294. ConstantInt *immResClass = cast<ConstantInt>(resClass);
  1295. if (immResClass->getLimitedValue() == (unsigned)DXIL::ResourceClass::CBuffer) {
  1296. ConstantInt *cbID = cast<ConstantInt>(handle.get_rangeId());
  1297. CollectCBufUsage(U, cbufUsage[cbID->getLimitedValue()]);
  1298. }
  1299. }
  1300. for (unsigned i=0;i<cbSize;i++) {
  1301. SetCBufVarUsage(*m_CBs[i], cbufUsage[i]);
  1302. }
  1303. }
  1304. void DxilModuleReflection::CreateReflectionObjects() {
  1305. DXASSERT_NOMSG(m_pDxilModule != nullptr);
  1306. // Create constant buffers, resources and signatures.
  1307. for (auto && cb : m_pDxilModule->GetCBuffers()) {
  1308. std::unique_ptr<CShaderReflectionConstantBuffer> rcb(new CShaderReflectionConstantBuffer());
  1309. rcb->Initialize(*m_pDxilModule, *(cb.get()), m_Types);
  1310. m_CBs.emplace_back(std::move(rcb));
  1311. }
  1312. // TODO: add tbuffers into m_CBs
  1313. for (auto && uav : m_pDxilModule->GetUAVs()) {
  1314. if (uav->GetKind() != DxilResource::Kind::StructuredBuffer) {
  1315. continue;
  1316. }
  1317. std::unique_ptr<CShaderReflectionConstantBuffer> rcb(new CShaderReflectionConstantBuffer());
  1318. rcb->InitializeStructuredBuffer(*m_pDxilModule, *(uav.get()), m_Types);
  1319. m_CBs.emplace_back(std::move(rcb));
  1320. }
  1321. for (auto && srv : m_pDxilModule->GetSRVs()) {
  1322. if (srv->GetKind() != DxilResource::Kind::StructuredBuffer) {
  1323. continue;
  1324. }
  1325. std::unique_ptr<CShaderReflectionConstantBuffer> rcb(new CShaderReflectionConstantBuffer());
  1326. rcb->InitializeStructuredBuffer(*m_pDxilModule, *(srv.get()), m_Types);
  1327. m_CBs.emplace_back(std::move(rcb));
  1328. }
  1329. // Populate all resources.
  1330. for (auto && cbRes : m_pDxilModule->GetCBuffers()) {
  1331. CreateReflectionObjectForResource(cbRes.get());
  1332. }
  1333. for (auto && samplerRes : m_pDxilModule->GetSamplers()) {
  1334. CreateReflectionObjectForResource(samplerRes.get());
  1335. }
  1336. for (auto && srvRes : m_pDxilModule->GetSRVs()) {
  1337. CreateReflectionObjectForResource(srvRes.get());
  1338. }
  1339. for (auto && uavRes : m_pDxilModule->GetUAVs()) {
  1340. CreateReflectionObjectForResource(uavRes.get());
  1341. }
  1342. }
  1343. static D3D_REGISTER_COMPONENT_TYPE CompTypeToRegisterComponentType(CompType CT) {
  1344. switch (CT.GetKind()) {
  1345. case DXIL::ComponentType::F16:
  1346. case DXIL::ComponentType::F32:
  1347. return D3D_REGISTER_COMPONENT_FLOAT32;
  1348. case DXIL::ComponentType::I1:
  1349. case DXIL::ComponentType::U16:
  1350. case DXIL::ComponentType::U32:
  1351. return D3D_REGISTER_COMPONENT_UINT32;
  1352. case DXIL::ComponentType::I16:
  1353. case DXIL::ComponentType::I32:
  1354. return D3D_REGISTER_COMPONENT_SINT32;
  1355. default:
  1356. return D3D_REGISTER_COMPONENT_UNKNOWN;
  1357. }
  1358. }
  1359. static D3D_MIN_PRECISION CompTypeToMinPrecision(CompType CT) {
  1360. switch (CT.GetKind()) {
  1361. case DXIL::ComponentType::F16:
  1362. return D3D_MIN_PRECISION_FLOAT_16;
  1363. case DXIL::ComponentType::I16:
  1364. return D3D_MIN_PRECISION_SINT_16;
  1365. case DXIL::ComponentType::U16:
  1366. return D3D_MIN_PRECISION_UINT_16;
  1367. default:
  1368. return D3D_MIN_PRECISION_DEFAULT;
  1369. }
  1370. }
  1371. D3D_NAME SemanticToSystemValueType(const Semantic *S, DXIL::TessellatorDomain domain) {
  1372. switch (S->GetKind()) {
  1373. case Semantic::Kind::ClipDistance:
  1374. return D3D_NAME_CLIP_DISTANCE;
  1375. case Semantic::Kind::Arbitrary:
  1376. return D3D_NAME_UNDEFINED;
  1377. case Semantic::Kind::VertexID:
  1378. return D3D_NAME_VERTEX_ID;
  1379. case Semantic::Kind::InstanceID:
  1380. return D3D_NAME_INSTANCE_ID;
  1381. case Semantic::Kind::Position:
  1382. return D3D_NAME_POSITION;
  1383. case Semantic::Kind::Coverage:
  1384. return D3D_NAME_COVERAGE;
  1385. case Semantic::Kind::InnerCoverage:
  1386. return D3D_NAME_INNER_COVERAGE;
  1387. case Semantic::Kind::PrimitiveID:
  1388. return D3D_NAME_PRIMITIVE_ID;
  1389. case Semantic::Kind::SampleIndex:
  1390. return D3D_NAME_SAMPLE_INDEX;
  1391. case Semantic::Kind::IsFrontFace:
  1392. return D3D_NAME_IS_FRONT_FACE;
  1393. case Semantic::Kind::RenderTargetArrayIndex:
  1394. return D3D_NAME_RENDER_TARGET_ARRAY_INDEX;
  1395. case Semantic::Kind::ViewPortArrayIndex:
  1396. return D3D_NAME_VIEWPORT_ARRAY_INDEX;
  1397. case Semantic::Kind::CullDistance:
  1398. return D3D_NAME_CULL_DISTANCE;
  1399. case Semantic::Kind::Target:
  1400. return D3D_NAME_TARGET;
  1401. case Semantic::Kind::Depth:
  1402. return D3D_NAME_DEPTH;
  1403. case Semantic::Kind::DepthLessEqual:
  1404. return D3D_NAME_DEPTH_LESS_EQUAL;
  1405. case Semantic::Kind::DepthGreaterEqual:
  1406. return D3D_NAME_DEPTH_GREATER_EQUAL;
  1407. case Semantic::Kind::StencilRef:
  1408. return D3D_NAME_STENCIL_REF;
  1409. case Semantic::Kind::TessFactor: {
  1410. switch (domain) {
  1411. case DXIL::TessellatorDomain::IsoLine:
  1412. return D3D_NAME_FINAL_LINE_DETAIL_TESSFACTOR;
  1413. case DXIL::TessellatorDomain::Tri:
  1414. return D3D_NAME_FINAL_TRI_EDGE_TESSFACTOR;
  1415. case DXIL::TessellatorDomain::Quad:
  1416. return D3D_NAME_FINAL_QUAD_EDGE_TESSFACTOR;
  1417. default:
  1418. return D3D_NAME_UNDEFINED;
  1419. }
  1420. }
  1421. case Semantic::Kind::InsideTessFactor:
  1422. switch (domain) {
  1423. case DXIL::TessellatorDomain::Tri:
  1424. return D3D_NAME_FINAL_TRI_INSIDE_TESSFACTOR;
  1425. case DXIL::TessellatorDomain::Quad:
  1426. return D3D_NAME_FINAL_QUAD_INSIDE_TESSFACTOR;
  1427. default:
  1428. return D3D_NAME_UNDEFINED;
  1429. }
  1430. case Semantic::Kind::DispatchThreadID:
  1431. case Semantic::Kind::GroupID:
  1432. case Semantic::Kind::GroupIndex:
  1433. case Semantic::Kind::GroupThreadID:
  1434. case Semantic::Kind::DomainLocation:
  1435. case Semantic::Kind::OutputControlPointID:
  1436. case Semantic::Kind::GSInstanceID:
  1437. case Semantic::Kind::Invalid:
  1438. default:
  1439. return D3D_NAME_UNDEFINED;
  1440. }
  1441. }
  // Returns the complement of V restricted to its low four bits; any bits of
  // V above bit 3 are discarded.
  static uint8_t NegMask(uint8_t V) {
    return static_cast<uint8_t>(~V & 0xF);
  }
  1446. void DxilShaderReflection::CreateReflectionObjectsForSignature(
  1447. const DxilSignature &Sig,
  1448. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> &Descs) {
  1449. bool clipDistanceSeen = false;
  1450. for (auto && SigElem : Sig.GetElements()) {
  1451. D3D12_SIGNATURE_PARAMETER_DESC Desc;
  1452. // TODO: why do we have multiple SV_ClipDistance elements?
  1453. if (SigElem->GetSemantic()->GetKind() == DXIL::SemanticKind::ClipDistance) {
  1454. if (clipDistanceSeen) continue;
  1455. clipDistanceSeen = true;
  1456. }
  1457. Desc.ComponentType = CompTypeToRegisterComponentType(SigElem->GetCompType());
  1458. Desc.Mask = SigElem->GetColsAsMask();
  1459. // D3D11_43 does not have MinPrecison.
  1460. if (m_PublicAPI != PublicAPI::D3D11_43)
  1461. Desc.MinPrecision = CompTypeToMinPrecision(SigElem->GetCompType());
  1462. Desc.ReadWriteMask = Sig.IsInput() ? 0 : Desc.Mask; // Start with output-never-written/input-never-read.
  1463. Desc.Register = SigElem->GetStartRow();
  1464. Desc.Stream = SigElem->GetOutputStream();
  1465. Desc.SystemValueType = SemanticToSystemValueType(SigElem->GetSemantic(), m_pDxilModule->GetTessellatorDomain());
  1466. Desc.SemanticName = SigElem->GetName();
  1467. if (!SigElem->GetSemantic()->IsArbitrary())
  1468. Desc.SemanticName = CreateUpperCase(Desc.SemanticName);
  1469. const std::vector<unsigned> &indexVec = SigElem->GetSemanticIndexVec();
  1470. for (unsigned semIdx = 0; semIdx < indexVec.size(); ++semIdx) {
  1471. Desc.SemanticIndex = indexVec[semIdx];
  1472. if (Desc.SystemValueType == D3D_NAME_FINAL_LINE_DETAIL_TESSFACTOR &&
  1473. Desc.SemanticIndex == 1)
  1474. Desc.SystemValueType = D3D_NAME_FINAL_LINE_DETAIL_TESSFACTOR;
  1475. Descs.push_back(Desc);
  1476. }
  1477. }
  1478. }
  1479. LPCSTR DxilShaderReflection::CreateUpperCase(LPCSTR pValue) {
  1480. // Restricted only to [a-z] ASCII.
  1481. LPCSTR pCursor = pValue;
  1482. while (*pCursor != '\0') {
  1483. if ('a' <= *pCursor && *pCursor <= 'z') {
  1484. break;
  1485. }
  1486. ++pCursor;
  1487. }
  1488. if (*pCursor == '\0')
  1489. return pValue;
  1490. std::unique_ptr<char[]> pUpperStr = std::make_unique<char[]>(strlen(pValue) + 1);
  1491. char *pWrite = pUpperStr.get();
  1492. pCursor = pValue;
  1493. for (;;) {
  1494. *pWrite = *pCursor;
  1495. if ('a' <= *pWrite && *pWrite <= 'z') {
  1496. *pWrite += ('A' - 'a');
  1497. }
  1498. if (*pWrite == '\0') break;
  1499. ++pWrite;
  1500. ++pCursor;
  1501. }
  1502. m_UpperCaseNames.push_back(std::move(pUpperStr));
  1503. return m_UpperCaseNames.back().get();
  1504. }
  1505. HRESULT DxilModuleReflection::LoadModule(IDxcBlob *pBlob,
  1506. const DxilPartHeader *pPart) {
  1507. DXASSERT_NOMSG(pBlob != nullptr);
  1508. DXASSERT_NOMSG(pPart != nullptr);
  1509. m_pContainer = pBlob;
  1510. const char *pData = GetDxilPartData(pPart);
  1511. try {
  1512. const char *pBitcode;
  1513. uint32_t bitcodeLength;
  1514. GetDxilProgramBitcode((DxilProgramHeader *)pData, &pBitcode, &bitcodeLength);
  1515. std::unique_ptr<MemoryBuffer> pMemBuffer =
  1516. MemoryBuffer::getMemBufferCopy(StringRef(pBitcode, bitcodeLength));
  1517. #if 0 // We materialize eagerly, because we'll need to walk instructions to look for usage information.
  1518. ErrorOr<std::unique_ptr<Module>> module =
  1519. getLazyBitcodeModule(std::move(pMemBuffer), Context);
  1520. #else
  1521. ErrorOr<std::unique_ptr<Module>> module =
  1522. parseBitcodeFile(pMemBuffer->getMemBufferRef(), Context, nullptr);
  1523. #endif
  1524. if (!module) {
  1525. return E_INVALIDARG;
  1526. }
  1527. std::swap(m_pModule, module.get());
  1528. m_pDxilModule = &m_pModule->GetOrCreateDxilModule();
  1529. CreateReflectionObjects();
  1530. return S_OK;
  1531. }
  1532. CATCH_CPP_RETURN_HRESULT();
  1533. };
  1534. HRESULT DxilShaderReflection::Load(IDxcBlob *pBlob,
  1535. const DxilPartHeader *pPart) {
  1536. IFR(LoadModule(pBlob, pPart));
  1537. try {
  1538. // Set cbuf usage.
  1539. SetCBufferUsage();
  1540. // Populate input/output/patch constant signatures.
  1541. CreateReflectionObjectsForSignature(m_pDxilModule->GetInputSignature(), m_InputSignature);
  1542. CreateReflectionObjectsForSignature(m_pDxilModule->GetOutputSignature(), m_OutputSignature);
  1543. CreateReflectionObjectsForSignature(m_pDxilModule->GetPatchConstantSignature(), m_PatchConstantSignature);
  1544. MarkUsedSignatureElements();
  1545. return S_OK;
  1546. }
  1547. CATCH_CPP_RETURN_HRESULT();
  1548. }
  1549. _Use_decl_annotations_
  1550. HRESULT DxilShaderReflection::GetDesc(D3D12_SHADER_DESC *pDesc) {
  1551. IFR(ZeroMemoryToOut(pDesc));
  1552. const DxilModule &M = *m_pDxilModule;
  1553. const ShaderModel *pSM = M.GetShaderModel();
  1554. pDesc->Version = EncodeVersion(pSM->GetKind(), pSM->GetMajor(), pSM->GetMinor());
  1555. // Unset: LPCSTR Creator; // Creator string
  1556. // Unset: UINT Flags; // Shader compilation/parse flags
  1557. pDesc->ConstantBuffers = m_CBs.size();
  1558. pDesc->BoundResources = m_Resources.size();
  1559. pDesc->InputParameters = m_InputSignature.size();
  1560. pDesc->OutputParameters = m_OutputSignature.size();
  1561. pDesc->PatchConstantParameters = m_PatchConstantSignature.size();
  1562. // Unset: UINT InstructionCount; // Number of emitted instructions
  1563. // Unset: UINT TempRegisterCount; // Number of temporary registers used
  1564. // Unset: UINT TempArrayCount; // Number of temporary arrays used
  1565. // Unset: UINT DefCount; // Number of constant defines
  1566. // Unset: UINT DclCount; // Number of declarations (input + output)
  1567. // Unset: UINT TextureNormalInstructions; // Number of non-categorized texture instructions
  1568. // Unset: UINT TextureLoadInstructions; // Number of texture load instructions
  1569. // Unset: UINT TextureCompInstructions; // Number of texture comparison instructions
  1570. // Unset: UINT TextureBiasInstructions; // Number of texture bias instructions
  1571. // Unset: UINT TextureGradientInstructions; // Number of texture gradient instructions
  1572. // Unset: UINT FloatInstructionCount; // Number of floating point arithmetic instructions used
  1573. // Unset: UINT IntInstructionCount; // Number of signed integer arithmetic instructions used
  1574. // Unset: UINT UintInstructionCount; // Number of unsigned integer arithmetic instructions used
  1575. // Unset: UINT StaticFlowControlCount; // Number of static flow control instructions used
  1576. // Unset: UINT DynamicFlowControlCount; // Number of dynamic flow control instructions used
  1577. // Unset: UINT MacroInstructionCount; // Number of macro instructions used
  1578. // Unset: UINT ArrayInstructionCount; // Number of array instructions used
  1579. // Unset: UINT CutInstructionCount; // Number of cut instructions used
  1580. // Unset: UINT EmitInstructionCount; // Number of emit instructions used
  1581. pDesc->GSOutputTopology = (D3D_PRIMITIVE_TOPOLOGY)M.GetStreamPrimitiveTopology();
  1582. pDesc->GSMaxOutputVertexCount = M.GetMaxVertexCount();
  1583. if (pSM->IsHS())
  1584. pDesc->InputPrimitive = (D3D_PRIMITIVE)(D3D_PRIMITIVE_1_CONTROL_POINT_PATCH + M.GetInputControlPointCount() - 1);
  1585. else
  1586. pDesc->InputPrimitive = (D3D_PRIMITIVE)M.GetInputPrimitive();
  1587. pDesc->cGSInstanceCount = M.GetGSInstanceCount();
  1588. if (pSM->IsHS())
  1589. pDesc->cControlPoints = M.GetOutputControlPointCount();
  1590. else if (pSM->IsDS())
  1591. pDesc->cControlPoints = M.GetInputControlPointCount();
  1592. pDesc->HSOutputPrimitive = (D3D_TESSELLATOR_OUTPUT_PRIMITIVE)M.GetTessellatorOutputPrimitive();
  1593. pDesc->HSPartitioning = (D3D_TESSELLATOR_PARTITIONING)M.GetTessellatorPartitioning();
  1594. pDesc->TessellatorDomain = (D3D_TESSELLATOR_DOMAIN)M.GetTessellatorDomain();
  1595. // instruction counts
  1596. // Unset: UINT cBarrierInstructions; // Number of barrier instructions in a compute shader
  1597. // Unset: UINT cInterlockedInstructions; // Number of interlocked instructions
  1598. // Unset: UINT cTextureStoreInstructions; // Number of texture writes
  1599. return S_OK;
  1600. }
  1601. static bool GetUnsignedVal(Value *V, uint32_t *pValue) {
  1602. ConstantInt *CI = dyn_cast<ConstantInt>(V);
  1603. if (!CI) return false;
  1604. uint64_t u = CI->getZExtValue();
  1605. if (u > UINT32_MAX) return false;
  1606. *pValue = (uint32_t)u;
  1607. return true;
  1608. }
  1609. void DxilShaderReflection::MarkUsedSignatureElements() {
  1610. Function *F = m_pDxilModule->GetEntryFunction();
  1611. DXASSERT(F != nullptr, "else module load should have failed");
  1612. // For every loadInput/storeOutput, update the corresponding ReadWriteMask.
  1613. // F is a pointer to a Function instance
  1614. unsigned elementCount = m_InputSignature.size() + m_OutputSignature.size() +
  1615. m_PatchConstantSignature.size();
  1616. unsigned markedElementCount = 0;
  1617. for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
  1618. DxilInst_LoadInput LI(&*I);
  1619. DxilInst_StoreOutput SO(&*I);
  1620. DxilInst_LoadPatchConstant LPC(&*I);
  1621. DxilInst_StorePatchConstant SPC(&*I);
  1622. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> *pDescs;
  1623. const DxilSignature *pSig;
  1624. uint32_t col, row, sigId;
  1625. if (LI) {
  1626. if (!GetUnsignedVal(LI.get_inputSigId(), &sigId)) continue;
  1627. if (!GetUnsignedVal(LI.get_colIndex(), &col)) continue;
  1628. if (!GetUnsignedVal(LI.get_rowIndex(), &row)) continue;
  1629. pDescs = &m_InputSignature;
  1630. pSig = &m_pDxilModule->GetInputSignature();
  1631. }
  1632. else if (SO) {
  1633. if (!GetUnsignedVal(SO.get_outputSigId(), &sigId)) continue;
  1634. if (!GetUnsignedVal(SO.get_colIndex(), &col)) continue;
  1635. if (!GetUnsignedVal(SO.get_rowIndex(), &row)) continue;
  1636. pDescs = &m_OutputSignature;
  1637. pSig = &m_pDxilModule->GetOutputSignature();
  1638. }
  1639. else if (SPC) {
  1640. if (!GetUnsignedVal(SPC.get_outputSigID(), &sigId)) continue;
  1641. if (!GetUnsignedVal(SPC.get_col(), &col)) continue;
  1642. if (!GetUnsignedVal(SPC.get_row(), &row)) continue;
  1643. pDescs = &m_PatchConstantSignature;
  1644. pSig = &m_pDxilModule->GetPatchConstantSignature();
  1645. }
  1646. else if (LPC) {
  1647. if (!GetUnsignedVal(LPC.get_inputSigId(), &sigId)) continue;
  1648. if (!GetUnsignedVal(LPC.get_col(), &col)) continue;
  1649. if (!GetUnsignedVal(LPC.get_row(), &row)) continue;
  1650. pDescs = &m_PatchConstantSignature;
  1651. pSig = &m_pDxilModule->GetPatchConstantSignature();
  1652. }
  1653. else {
  1654. continue;
  1655. }
  1656. if (sigId >= pDescs->size()) continue;
  1657. D3D12_SIGNATURE_PARAMETER_DESC *pDesc = &(*pDescs)[sigId];
  1658. // Consider being more fine-grained about masks.
  1659. // We report sometimes-read on input as always-read.
  1660. unsigned UsedMask = pSig->IsInput() ? pDesc->Mask : NegMask(pDesc->Mask);
  1661. if (pDesc->ReadWriteMask == UsedMask)
  1662. continue;
  1663. pDesc->ReadWriteMask = UsedMask;
  1664. ++markedElementCount;
  1665. if (markedElementCount == elementCount)
  1666. return;
  1667. }
  1668. }
  1669. _Use_decl_annotations_
  1670. ID3D12ShaderReflectionConstantBuffer* DxilShaderReflection::GetConstantBufferByIndex(UINT Index) {
  1671. return DxilModuleReflection::_GetConstantBufferByIndex(Index);
  1672. }
  1673. ID3D12ShaderReflectionConstantBuffer* DxilModuleReflection::_GetConstantBufferByIndex(UINT Index) {
  1674. if (Index >= m_CBs.size()) {
  1675. return &g_InvalidSRConstantBuffer;
  1676. }
  1677. return m_CBs[Index].get();
  1678. }
  1679. _Use_decl_annotations_
  1680. ID3D12ShaderReflectionConstantBuffer* DxilShaderReflection::GetConstantBufferByName(LPCSTR Name) {
  1681. return DxilModuleReflection::_GetConstantBufferByName(Name);
  1682. }
  1683. ID3D12ShaderReflectionConstantBuffer* DxilModuleReflection::_GetConstantBufferByName(LPCSTR Name) {
  1684. if (!Name) {
  1685. return &g_InvalidSRConstantBuffer;
  1686. }
  1687. for (UINT index = 0; index < m_CBs.size(); ++index) {
  1688. if (0 == strcmp(m_CBs[index]->GetName(), Name)) {
  1689. return m_CBs[index].get();
  1690. }
  1691. }
  1692. return &g_InvalidSRConstantBuffer;
  1693. }
  1694. _Use_decl_annotations_
  1695. HRESULT DxilShaderReflection::GetResourceBindingDesc(UINT ResourceIndex,
  1696. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc) {
  1697. return DxilModuleReflection::_GetResourceBindingDesc(ResourceIndex, pDesc, m_PublicAPI);
  1698. }
  1699. HRESULT DxilModuleReflection::_GetResourceBindingDesc(UINT ResourceIndex,
  1700. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc, PublicAPI api) {
  1701. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1702. IFRBOOL(ResourceIndex < m_Resources.size(), E_INVALIDARG);
  1703. if (api != PublicAPI::D3D12) {
  1704. memcpy(pDesc, &m_Resources[ResourceIndex], sizeof(D3D11_SHADER_INPUT_BIND_DESC));
  1705. }
  1706. else {
  1707. *pDesc = m_Resources[ResourceIndex];
  1708. }
  1709. return S_OK;
  1710. }
  1711. _Use_decl_annotations_
  1712. HRESULT DxilShaderReflection::GetInputParameterDesc(UINT ParameterIndex,
  1713. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc) {
  1714. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1715. IFRBOOL(ParameterIndex < m_InputSignature.size(), E_INVALIDARG);
  1716. if (m_PublicAPI != PublicAPI::D3D11_43)
  1717. *pDesc = m_InputSignature[ParameterIndex];
  1718. else
  1719. memcpy(pDesc, &m_InputSignature[ParameterIndex],
  1720. // D3D11_43 does not have MinPrecison.
  1721. sizeof(D3D12_SIGNATURE_PARAMETER_DESC) - sizeof(D3D_MIN_PRECISION));
  1722. return S_OK;
  1723. }
  1724. _Use_decl_annotations_
  1725. HRESULT DxilShaderReflection::GetOutputParameterDesc(UINT ParameterIndex,
  1726. D3D12_SIGNATURE_PARAMETER_DESC *pDesc) {
  1727. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1728. IFRBOOL(ParameterIndex < m_OutputSignature.size(), E_INVALIDARG);
  1729. if (m_PublicAPI != PublicAPI::D3D11_43)
  1730. *pDesc = m_OutputSignature[ParameterIndex];
  1731. else
  1732. memcpy(pDesc, &m_OutputSignature[ParameterIndex],
  1733. // D3D11_43 does not have MinPrecison.
  1734. sizeof(D3D12_SIGNATURE_PARAMETER_DESC) - sizeof(D3D_MIN_PRECISION));
  1735. return S_OK;
  1736. }
  1737. _Use_decl_annotations_
  1738. HRESULT DxilShaderReflection::GetPatchConstantParameterDesc(UINT ParameterIndex,
  1739. D3D12_SIGNATURE_PARAMETER_DESC *pDesc) {
  1740. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1741. IFRBOOL(ParameterIndex < m_PatchConstantSignature.size(), E_INVALIDARG);
  1742. if (m_PublicAPI != PublicAPI::D3D11_43)
  1743. *pDesc = m_PatchConstantSignature[ParameterIndex];
  1744. else
  1745. memcpy(pDesc, &m_PatchConstantSignature[ParameterIndex],
  1746. // D3D11_43 does not have MinPrecison.
  1747. sizeof(D3D12_SIGNATURE_PARAMETER_DESC) - sizeof(D3D_MIN_PRECISION));
  1748. return S_OK;
  1749. }
  1750. _Use_decl_annotations_
  1751. ID3D12ShaderReflectionVariable* DxilShaderReflection::GetVariableByName(LPCSTR Name) {
  1752. return DxilModuleReflection::_GetVariableByName(Name);
  1753. }
  1754. ID3D12ShaderReflectionVariable* DxilModuleReflection::_GetVariableByName(LPCSTR Name) {
  1755. if (Name != nullptr) {
  1756. // Iterate through all cbuffers to find the variable.
  1757. for (UINT i = 0; i < m_CBs.size(); i++) {
  1758. ID3D12ShaderReflectionVariable *pVar = m_CBs[i]->GetVariableByName(Name);
  1759. if (pVar != &g_InvalidSRVariable) {
  1760. return pVar;
  1761. }
  1762. }
  1763. }
  1764. return &g_InvalidSRVariable;
  1765. }
  1766. _Use_decl_annotations_
  1767. HRESULT DxilShaderReflection::GetResourceBindingDescByName(LPCSTR Name,
  1768. D3D12_SHADER_INPUT_BIND_DESC *pDesc) {
  1769. return DxilModuleReflection::_GetResourceBindingDescByName(Name, pDesc, m_PublicAPI);
  1770. }
  1771. HRESULT DxilModuleReflection::_GetResourceBindingDescByName(LPCSTR Name,
  1772. D3D12_SHADER_INPUT_BIND_DESC *pDesc, PublicAPI api) {
  1773. IFRBOOL(Name != nullptr, E_INVALIDARG);
  1774. for (UINT i = 0; i < m_Resources.size(); i++) {
  1775. if (strcmp(m_Resources[i].Name, Name) == 0) {
  1776. if (api != PublicAPI::D3D12) {
  1777. memcpy(pDesc, &m_Resources[i], sizeof(D3D11_SHADER_INPUT_BIND_DESC));
  1778. }
  1779. else {
  1780. *pDesc = m_Resources[i];
  1781. }
  1782. return S_OK;
  1783. }
  1784. }
  1785. return HRESULT_FROM_WIN32(ERROR_NOT_FOUND);
  1786. }
  1787. UINT DxilShaderReflection::GetMovInstructionCount() { return 0; }
  1788. UINT DxilShaderReflection::GetMovcInstructionCount() { return 0; }
  1789. UINT DxilShaderReflection::GetConversionInstructionCount() { return 0; }
  1790. UINT DxilShaderReflection::GetBitwiseInstructionCount() { return 0; }
  1791. D3D_PRIMITIVE DxilShaderReflection::GetGSInputPrimitive() {
  1792. if (!m_pDxilModule->GetShaderModel()->IsGS())
  1793. return D3D_PRIMITIVE::D3D10_PRIMITIVE_UNDEFINED;
  1794. return (D3D_PRIMITIVE)m_pDxilModule->GetInputPrimitive();
  1795. }
  1796. BOOL DxilShaderReflection::IsSampleFrequencyShader() {
  1797. // TODO: determine correct value
  1798. return FALSE;
  1799. }
  1800. UINT DxilShaderReflection::GetNumInterfaceSlots() { return 0; }
  1801. _Use_decl_annotations_
  1802. HRESULT DxilShaderReflection::GetMinFeatureLevel(enum D3D_FEATURE_LEVEL* pLevel) {
  1803. IFR(AssignToOut(D3D_FEATURE_LEVEL_12_0, pLevel));
  1804. return S_OK;
  1805. }
  1806. _Use_decl_annotations_
  1807. UINT DxilShaderReflection::GetThreadGroupSize(UINT *pSizeX, UINT *pSizeY, UINT *pSizeZ) {
  1808. if (!m_pDxilModule->GetShaderModel()->IsCS()) {
  1809. AssignToOutOpt((UINT)0, pSizeX);
  1810. AssignToOutOpt((UINT)0, pSizeY);
  1811. AssignToOutOpt((UINT)0, pSizeZ);
  1812. return 0;
  1813. }
  1814. unsigned x = m_pDxilModule->GetNumThreads(0);
  1815. unsigned y = m_pDxilModule->GetNumThreads(1);
  1816. unsigned z = m_pDxilModule->GetNumThreads(2);
  1817. AssignToOutOpt(x, pSizeX);
  1818. AssignToOutOpt(y, pSizeY);
  1819. AssignToOutOpt(z, pSizeZ);
  1820. return x * y * z;
  1821. }
  1822. UINT64 DxilShaderReflection::GetRequiresFlags() {
  1823. UINT64 result = 0;
  1824. uint64_t features = m_pDxilModule->m_ShaderFlags.GetFeatureInfo();
  1825. if (features & ShaderFeatureInfo_Doubles) result |= D3D_SHADER_REQUIRES_DOUBLES;
  1826. if (features & ShaderFeatureInfo_UAVsAtEveryStage) result |= D3D_SHADER_REQUIRES_UAVS_AT_EVERY_STAGE;
  1827. if (features & ShaderFeatureInfo_64UAVs) result |= D3D_SHADER_REQUIRES_64_UAVS;
  1828. if (features & ShaderFeatureInfo_MinimumPrecision) result |= D3D_SHADER_REQUIRES_MINIMUM_PRECISION;
  1829. if (features & ShaderFeatureInfo_11_1_DoubleExtensions) result |= D3D_SHADER_REQUIRES_11_1_DOUBLE_EXTENSIONS;
  1830. if (features & ShaderFeatureInfo_11_1_ShaderExtensions) result |= D3D_SHADER_REQUIRES_11_1_SHADER_EXTENSIONS;
  1831. if (features & ShaderFeatureInfo_LEVEL9ComparisonFiltering) result |= D3D_SHADER_REQUIRES_LEVEL_9_COMPARISON_FILTERING;
  1832. if (features & ShaderFeatureInfo_TiledResources) result |= D3D_SHADER_REQUIRES_TILED_RESOURCES;
  1833. if (features & ShaderFeatureInfo_StencilRef) result |= D3D_SHADER_REQUIRES_STENCIL_REF;
  1834. if (features & ShaderFeatureInfo_InnerCoverage) result |= D3D_SHADER_REQUIRES_INNER_COVERAGE;
  1835. if (features & ShaderFeatureInfo_TypedUAVLoadAdditionalFormats) result |= D3D_SHADER_REQUIRES_TYPED_UAV_LOAD_ADDITIONAL_FORMATS;
  1836. if (features & ShaderFeatureInfo_ROVs) result |= D3D_SHADER_REQUIRES_ROVS;
  1837. if (features & ShaderFeatureInfo_ViewportAndRTArrayIndexFromAnyShaderFeedingRasterizer) result |= D3D_SHADER_REQUIRES_VIEWPORT_AND_RT_ARRAY_INDEX_FROM_ANY_SHADER_FEEDING_RASTERIZER;
  1838. return result;
  1839. }
  1840. // ID3D12FunctionReflection
  1841. class CFunctionReflection : public ID3D12FunctionReflection {
  1842. protected:
  1843. DxilLibraryReflection * m_pLibraryReflection = nullptr;
  1844. const Function *m_pFunction;
  1845. const DxilFunctionProps *m_pProps; // nullptr if non-shader library function or patch constant function
  1846. std::string m_Name;
  1847. typedef SmallSetVector<UINT32, 8> ResourceUseSet;
  1848. ResourceUseSet m_UsedResources;
  1849. ResourceUseSet m_UsedCBs;
  1850. public:
  1851. void Initialize(DxilLibraryReflection* pLibraryReflection, Function *pFunction) {
  1852. DXASSERT_NOMSG(pLibraryReflection);
  1853. DXASSERT_NOMSG(pFunction);
  1854. m_pLibraryReflection = pLibraryReflection;
  1855. m_pFunction = pFunction;
  1856. const DxilModule &M = *m_pLibraryReflection->m_pDxilModule;
  1857. m_Name = m_pFunction->getName().str();
  1858. m_pProps = nullptr;
  1859. if (M.HasDxilFunctionProps(m_pFunction)) {
  1860. m_pProps = &M.GetDxilFunctionProps(m_pFunction);
  1861. }
  1862. }
  1863. void AddResourceReference(UINT resIndex) {
  1864. m_UsedResources.insert(resIndex);
  1865. }
  1866. void AddCBReference(UINT cbIndex) {
  1867. m_UsedCBs.insert(cbIndex);
  1868. }
  1869. // ID3D12FunctionReflection
  1870. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_FUNCTION_DESC * pDesc);
  1871. // BufferIndex relative to used constant buffers here
  1872. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByIndex)(THIS_ _In_ UINT BufferIndex);
  1873. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByName)(THIS_ _In_ LPCSTR Name);
  1874. STDMETHOD(GetResourceBindingDesc)(THIS_ _In_ UINT ResourceIndex,
  1875. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc);
  1876. STDMETHOD_(ID3D12ShaderReflectionVariable *, GetVariableByName)(THIS_ _In_ LPCSTR Name);
  1877. STDMETHOD(GetResourceBindingDescByName)(THIS_ _In_ LPCSTR Name,
  1878. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc);
  1879. // Use D3D_RETURN_PARAMETER_INDEX to get description of the return value.
  1880. STDMETHOD_(ID3D12FunctionParameterReflection *, GetFunctionParameter)(THIS_ _In_ INT ParameterIndex) {
  1881. return &g_InvalidFunctionParameter;
  1882. }
  1883. };
  1884. _Use_decl_annotations_
  1885. HRESULT CFunctionReflection::GetDesc(D3D12_FUNCTION_DESC *pDesc) {
  1886. DXASSERT_NOMSG(m_pLibraryReflection);
  1887. IFR(ZeroMemoryToOut(pDesc));
  1888. const ShaderModel* pSM = m_pLibraryReflection->m_pDxilModule->GetShaderModel();
  1889. DXIL::ShaderKind kind = DXIL::ShaderKind::Library;
  1890. if (m_pProps) {
  1891. kind = m_pProps->shaderKind;
  1892. }
  1893. pDesc->Version = EncodeVersion(kind, pSM->GetMajor(), pSM->GetMinor());
  1894. //Unset: LPCSTR Creator; // Creator string
  1895. //Unset: UINT Flags; // Shader compilation/parse flags
  1896. pDesc->ConstantBuffers = (UINT)m_UsedCBs.size();
  1897. pDesc->BoundResources = (UINT)m_UsedResources.size();
  1898. //Unset: UINT InstructionCount; // Number of emitted instructions
  1899. //Unset: UINT TempRegisterCount; // Number of temporary registers used
  1900. //Unset: UINT TempArrayCount; // Number of temporary arrays used
  1901. //Unset: UINT DefCount; // Number of constant defines
  1902. //Unset: UINT DclCount; // Number of declarations (input + output)
  1903. //Unset: UINT TextureNormalInstructions; // Number of non-categorized texture instructions
  1904. //Unset: UINT TextureLoadInstructions; // Number of texture load instructions
  1905. //Unset: UINT TextureCompInstructions; // Number of texture comparison instructions
  1906. //Unset: UINT TextureBiasInstructions; // Number of texture bias instructions
  1907. //Unset: UINT TextureGradientInstructions; // Number of texture gradient instructions
  1908. //Unset: UINT FloatInstructionCount; // Number of floating point arithmetic instructions used
  1909. //Unset: UINT IntInstructionCount; // Number of signed integer arithmetic instructions used
  1910. //Unset: UINT UintInstructionCount; // Number of unsigned integer arithmetic instructions used
  1911. //Unset: UINT StaticFlowControlCount; // Number of static flow control instructions used
  1912. //Unset: UINT DynamicFlowControlCount; // Number of dynamic flow control instructions used
  1913. //Unset: UINT MacroInstructionCount; // Number of macro instructions used
  1914. //Unset: UINT ArrayInstructionCount; // Number of array instructions used
  1915. //Unset: UINT MovInstructionCount; // Number of mov instructions used
  1916. //Unset: UINT MovcInstructionCount; // Number of movc instructions used
  1917. //Unset: UINT ConversionInstructionCount; // Number of type conversion instructions used
  1918. //Unset: UINT BitwiseInstructionCount; // Number of bitwise arithmetic instructions used
  1919. //Unset: D3D_FEATURE_LEVEL MinFeatureLevel; // Min target of the function byte code
  1920. //Unset: UINT64 RequiredFeatureFlags; // Required feature flags
  1921. pDesc->Name = m_Name.c_str();
  1922. //Unset: INT FunctionParameterCount; // Number of logical parameters in the function signature (not including return)
  1923. //Unset: BOOL HasReturn; // TRUE, if function returns a value, false - it is a subroutine
  1924. //Unset: BOOL Has10Level9VertexShader; // TRUE, if there is a 10L9 VS blob
  1925. //Unset: BOOL Has10Level9PixelShader; // TRUE, if there is a 10L9 PS blob
  1926. return S_OK;
  1927. }
  1928. // BufferIndex is relative to used constant buffers here
  1929. ID3D12ShaderReflectionConstantBuffer *CFunctionReflection::GetConstantBufferByIndex(UINT BufferIndex) {
  1930. DXASSERT_NOMSG(m_pLibraryReflection);
  1931. if (BufferIndex >= m_UsedCBs.size())
  1932. return &g_InvalidSRConstantBuffer;
  1933. return m_pLibraryReflection->_GetConstantBufferByIndex(m_UsedCBs[BufferIndex]);
  1934. }
  1935. ID3D12ShaderReflectionConstantBuffer *CFunctionReflection::GetConstantBufferByName(LPCSTR Name) {
  1936. DXASSERT_NOMSG(m_pLibraryReflection);
  1937. return m_pLibraryReflection->_GetConstantBufferByName(Name);
  1938. }
  1939. HRESULT CFunctionReflection::GetResourceBindingDesc(UINT ResourceIndex,
  1940. D3D12_SHADER_INPUT_BIND_DESC * pDesc) {
  1941. DXASSERT_NOMSG(m_pLibraryReflection);
  1942. if (ResourceIndex >= m_UsedResources.size())
  1943. return E_INVALIDARG;
  1944. return m_pLibraryReflection->_GetResourceBindingDesc(m_UsedResources[ResourceIndex], pDesc);
  1945. }
  1946. ID3D12ShaderReflectionVariable * CFunctionReflection::GetVariableByName(LPCSTR Name) {
  1947. DXASSERT_NOMSG(m_pLibraryReflection);
  1948. return m_pLibraryReflection->_GetVariableByName(Name);
  1949. }
  1950. HRESULT CFunctionReflection::GetResourceBindingDescByName(LPCSTR Name,
  1951. D3D12_SHADER_INPUT_BIND_DESC * pDesc) {
  1952. DXASSERT_NOMSG(m_pLibraryReflection);
  1953. return m_pLibraryReflection->_GetResourceBindingDescByName(Name, pDesc);
  1954. }
  1955. // DxilLibraryReflection
  1956. // From DxilContainerAssembler:
  1957. static llvm::Function *FindUsingFunction(llvm::Value *User) {
  1958. if (llvm::Instruction *I = dyn_cast<llvm::Instruction>(User)) {
  1959. // Instruction should be inside a basic block, which is in a function
  1960. return cast<llvm::Function>(I->getParent()->getParent());
  1961. }
  1962. // User can be either instruction, constant, or operator. But User is an
  1963. // operator only if constant is a scalar value, not resource pointer.
  1964. llvm::Constant *CU = cast<llvm::Constant>(User);
  1965. if (!CU->user_empty())
  1966. return FindUsingFunction(*CU->user_begin());
  1967. else
  1968. return nullptr;
  1969. }
  1970. void DxilLibraryReflection::AddResourceUseToFunctions(DxilResourceBase &resource, unsigned resIndex) {
  1971. Constant *var = resource.GetGlobalSymbol();
  1972. if (var) {
  1973. for (auto user : var->users()) {
  1974. // Find the function.
  1975. if (llvm::Function *F = FindUsingFunction(user)) {
  1976. auto funcReflector = m_FunctionsByPtr[F];
  1977. funcReflector->AddResourceReference(resIndex);
  1978. if (resource.GetClass() == DXIL::ResourceClass::CBuffer) {
  1979. funcReflector->AddCBReference(resource.GetID());
  1980. }
  1981. }
  1982. }
  1983. }
  1984. }
  1985. void DxilLibraryReflection::AddResourceDependencies() {
  1986. std::map<StringRef, CFunctionReflection*> orderedMap;
  1987. for (auto &F : m_pModule->functions()) {
  1988. if (F.isDeclaration())
  1989. continue;
  1990. auto &func = m_FunctionMap[F.getName()];
  1991. DXASSERT(!func.get(), "otherwise duplicate named functions");
  1992. func.reset(new CFunctionReflection());
  1993. func->Initialize(this, &F);
  1994. m_FunctionsByPtr[&F] = func.get();
  1995. orderedMap[F.getName()] = func.get();
  1996. }
  1997. // Fill in function vector sorted by name
  1998. m_FunctionVector.clear();
  1999. m_FunctionVector.reserve(orderedMap.size());
  2000. for (auto &it : orderedMap) {
  2001. m_FunctionVector.push_back(it.second);
  2002. }
  2003. UINT resIndex = 0;
  2004. for (auto &resource : m_Resources) {
  2005. switch ((UINT32)resource.Type) {
  2006. case D3D_SIT_CBUFFER:
  2007. AddResourceUseToFunctions(m_pDxilModule->GetCBuffer(resource.uID), resIndex);
  2008. break;
  2009. case D3D_SIT_TBUFFER: // TODO: Handle when TBuffers are added to CB list
  2010. case D3D_SIT_TEXTURE:
  2011. case D3D_SIT_STRUCTURED:
  2012. case D3D_SIT_BYTEADDRESS:
  2013. case D3D_SIT_RTACCELERATIONSTRUCTURE:
  2014. AddResourceUseToFunctions(m_pDxilModule->GetSRV(resource.uID), resIndex);
  2015. break;
  2016. case D3D_SIT_UAV_RWTYPED:
  2017. case D3D_SIT_UAV_RWSTRUCTURED:
  2018. case D3D_SIT_UAV_RWBYTEADDRESS:
  2019. case D3D_SIT_UAV_APPEND_STRUCTURED:
  2020. case D3D_SIT_UAV_CONSUME_STRUCTURED:
  2021. case D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER:
  2022. AddResourceUseToFunctions(m_pDxilModule->GetUAV(resource.uID), resIndex);
  2023. break;
  2024. case D3D_SIT_SAMPLER:
  2025. AddResourceUseToFunctions(m_pDxilModule->GetSampler(resource.uID), resIndex);
  2026. break;
  2027. }
  2028. resIndex++;
  2029. }
  2030. }
  2031. // ID3D12LibraryReflection
  2032. HRESULT DxilLibraryReflection::Load(IDxcBlob *pBlob,
  2033. const DxilPartHeader *pPart) {
  2034. IFR(LoadModule(pBlob, pPart));
  2035. try {
  2036. AddResourceDependencies();
  2037. return S_OK;
  2038. }
  2039. CATCH_CPP_RETURN_HRESULT();
  2040. }
  2041. _Use_decl_annotations_
  2042. HRESULT DxilLibraryReflection::GetDesc(D3D12_LIBRARY_DESC * pDesc) {
  2043. IFR(ZeroMemoryToOut(pDesc));
  2044. //Unset: LPCSTR Creator; // The name of the originator of the library.
  2045. //Unset: UINT Flags; // Compilation flags.
  2046. //UINT FunctionCount; // Number of functions exported from the library.
  2047. pDesc->FunctionCount = (UINT)m_FunctionVector.size();
  2048. return S_OK;
  2049. }
  2050. _Use_decl_annotations_
  2051. ID3D12FunctionReflection *DxilLibraryReflection::GetFunctionByIndex(INT FunctionIndex) {
  2052. if ((UINT)FunctionIndex >= m_FunctionVector.size())
  2053. return &g_InvalidFunction;
  2054. return m_FunctionVector[FunctionIndex];
  2055. }
  2056. #else // LLVM_ON_WIN32
  2057. void hlsl::CreateDxcContainerReflection(IDxcContainerReflection **ppResult) {
  2058. *ppResult = nullptr;
  2059. }
  2060. DEFINE_CROSS_PLATFORM_UUIDOF(IDxcContainerReflection)
  2061. #endif // LLVM_ON_WIN32