DxilContainerReflection.cpp 88 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilContainerReflection.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides support for reading DXIL container structures. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/Bitcode/ReaderWriter.h"
  12. #include "llvm/IR/LLVMContext.h"
  13. #include "llvm/IR/InstIterator.h"
  14. #include "llvm/IR/Operator.h"
  15. #include "dxc/HLSL/DxilContainer.h"
  16. #include "dxc/HLSL/DxilModule.h"
  17. #include "dxc/HLSL/DxilShaderModel.h"
  18. #include "dxc/HLSL/DxilOperations.h"
  19. #include "dxc/HLSL/DxilInstructions.h"
  20. #include "dxc/Support/Global.h"
  21. #include "dxc/Support/Unicode.h"
  22. #include "dxc/Support/WinIncludes.h"
  23. #include "dxc/Support/microcom.h"
  24. #include "dxc/Support/FileIOHelper.h"
  25. #include "dxc/Support/dxcapi.impl.h"
  26. #include "dxc/HLSL/DxilFunctionProps.h"
  27. #include <unordered_set>
  28. #include "llvm/ADT/SetVector.h"
  29. #include "dxc/dxcapi.h"
  30. #ifdef LLVM_ON_WIN32
  31. #include "d3d12shader.h" // for compatibility
  32. #include "d3d11shader.h" // for compatibility
  33. const GUID IID_ID3D11ShaderReflection_43 = {
  34. 0x0a233719,
  35. 0x3960,
  36. 0x4578,
  37. {0x9d, 0x7c, 0x20, 0x3b, 0x8b, 0x1d, 0x9c, 0xc1}};
  38. const GUID IID_ID3D11ShaderReflection_47 = {
  39. 0x8d536ca1,
  40. 0x0cca,
  41. 0x4956,
  42. {0xa8, 0x37, 0x78, 0x69, 0x63, 0x75, 0x55, 0x84}};
  43. using namespace llvm;
  44. using namespace hlsl;
  45. class DxilContainerReflection : public IDxcContainerReflection {
  46. private:
  47. DXC_MICROCOM_TM_REF_FIELDS()
  48. CComPtr<IDxcBlob> m_container;
  49. const DxilContainerHeader *m_pHeader = nullptr;
  50. uint32_t m_headerLen = 0;
  51. bool IsLoaded() const { return m_pHeader != nullptr; }
  52. public:
  53. DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
  54. DXC_MICROCOM_TM_CTOR(DxilContainerReflection)
  55. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  56. return DoBasicQueryInterface<IDxcContainerReflection>(this, iid, ppvObject);
  57. }
  58. HRESULT STDMETHODCALLTYPE Load(_In_ IDxcBlob *pContainer) override;
  59. HRESULT STDMETHODCALLTYPE GetPartCount(_Out_ UINT32 *pResult) override;
  60. HRESULT STDMETHODCALLTYPE GetPartKind(UINT32 idx, _Out_ UINT32 *pResult) override;
  61. HRESULT STDMETHODCALLTYPE GetPartContent(UINT32 idx, _COM_Outptr_ IDxcBlob **ppResult) override;
  62. HRESULT STDMETHODCALLTYPE FindFirstPartKind(UINT32 kind, _Out_ UINT32 *pResult) override;
  63. HRESULT STDMETHODCALLTYPE GetPartReflection(UINT32 idx, REFIID iid, _COM_Outptr_ void **ppvObject) override;
  64. };
  // Reflection helper classes defined later in this file.
  class CShaderReflectionType;
  class CShaderReflectionConstantBuffer;
  // Which public reflection interface a caller asked for. The D3D11 revisions
  // share ID3D12ShaderReflection's methods but use shorter structures for some
  // out parameters.
  enum class PublicAPI {
    D3D12 = 0,    // ID3D12ShaderReflection
    D3D11_47 = 1, // IID_ID3D11ShaderReflection_47
    D3D11_43 = 2, // IID_ID3D11ShaderReflection_43
  };
  68. class DxilModuleReflection {
  69. public:
  70. CComPtr<IDxcBlob> m_pContainer;
  71. LLVMContext Context;
  72. std::unique_ptr<Module> m_pModule; // Must come after LLVMContext, otherwise unique_ptr will over-delete.
  73. DxilModule *m_pDxilModule = nullptr;
  74. std::vector<std::unique_ptr<CShaderReflectionConstantBuffer>> m_CBs;
  75. std::vector<D3D12_SHADER_INPUT_BIND_DESC> m_Resources;
  76. std::vector<std::unique_ptr<CShaderReflectionType>> m_Types;
  77. void CreateReflectionObjects();
  78. void CreateReflectionObjectForResource(DxilResourceBase *R);
  79. HRESULT LoadModule(IDxcBlob *pBlob, const DxilPartHeader *pPart);
  80. // Common code
  81. ID3D12ShaderReflectionConstantBuffer* _GetConstantBufferByIndex(UINT Index);
  82. ID3D12ShaderReflectionConstantBuffer* _GetConstantBufferByName(LPCSTR Name);
  83. HRESULT _GetResourceBindingDesc(UINT ResourceIndex,
  84. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc,
  85. PublicAPI api = PublicAPI::D3D12);
  86. ID3D12ShaderReflectionVariable* _GetVariableByName(LPCSTR Name);
  87. HRESULT _GetResourceBindingDescByName(LPCSTR Name,
  88. D3D12_SHADER_INPUT_BIND_DESC *pDesc,
  89. PublicAPI api = PublicAPI::D3D12);
  90. };
  91. class DxilShaderReflection : public DxilModuleReflection, public ID3D12ShaderReflection {
  92. private:
  93. DXC_MICROCOM_TM_REF_FIELDS()
  94. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> m_InputSignature;
  95. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> m_OutputSignature;
  96. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> m_PatchConstantSignature;
  97. std::vector<std::unique_ptr<char[]>> m_UpperCaseNames;
  98. void SetCBufferUsage();
  99. void CreateReflectionObjectsForSignature(
  100. const DxilSignature &Sig,
  101. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> &Descs);
  102. LPCSTR CreateUpperCase(LPCSTR pValue);
  103. void MarkUsedSignatureElements();
  104. public:
  105. PublicAPI m_PublicAPI;
  106. void SetPublicAPI(PublicAPI value) { m_PublicAPI = value; }
  107. static PublicAPI IIDToAPI(REFIID iid) {
  108. PublicAPI api = PublicAPI::D3D12;
  109. if (IsEqualIID(IID_ID3D11ShaderReflection_43, iid))
  110. api = PublicAPI::D3D11_43;
  111. else if (IsEqualIID(IID_ID3D11ShaderReflection_47, iid))
  112. api = PublicAPI::D3D11_47;
  113. return api;
  114. }
  115. DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
  116. DXC_MICROCOM_TM_CTOR(DxilShaderReflection)
  117. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  118. HRESULT hr = DoBasicQueryInterface<ID3D12ShaderReflection>(this, iid, ppvObject);
  119. if (hr == E_NOINTERFACE) {
  120. // ID3D11ShaderReflection is identical to ID3D12ShaderReflection, except
  121. // for some shorter data structures in some out parameters.
  122. PublicAPI api = IIDToAPI(iid);
  123. if (api == m_PublicAPI) {
  124. *ppvObject = (ID3D12ShaderReflection *)this;
  125. this->AddRef();
  126. hr = S_OK;
  127. }
  128. }
  129. return hr;
  130. }
  131. HRESULT Load(IDxcBlob *pBlob, const DxilPartHeader *pPart);
  132. // ID3D12ShaderReflection
  133. STDMETHODIMP GetDesc(THIS_ _Out_ D3D12_SHADER_DESC *pDesc);
  134. STDMETHODIMP_(ID3D12ShaderReflectionConstantBuffer*) GetConstantBufferByIndex(THIS_ _In_ UINT Index);
  135. STDMETHODIMP_(ID3D12ShaderReflectionConstantBuffer*) GetConstantBufferByName(THIS_ _In_ LPCSTR Name);
  136. STDMETHODIMP GetResourceBindingDesc(THIS_ _In_ UINT ResourceIndex,
  137. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc);
  138. STDMETHODIMP GetInputParameterDesc(THIS_ _In_ UINT ParameterIndex,
  139. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc);
  140. STDMETHODIMP GetOutputParameterDesc(THIS_ _In_ UINT ParameterIndex,
  141. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc);
  142. STDMETHODIMP GetPatchConstantParameterDesc(THIS_ _In_ UINT ParameterIndex,
  143. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc);
  144. STDMETHODIMP_(ID3D12ShaderReflectionVariable*) GetVariableByName(THIS_ _In_ LPCSTR Name);
  145. STDMETHODIMP GetResourceBindingDescByName(THIS_ _In_ LPCSTR Name,
  146. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc);
  147. STDMETHODIMP_(UINT) GetMovInstructionCount(THIS);
  148. STDMETHODIMP_(UINT) GetMovcInstructionCount(THIS);
  149. STDMETHODIMP_(UINT) GetConversionInstructionCount(THIS);
  150. STDMETHODIMP_(UINT) GetBitwiseInstructionCount(THIS);
  151. STDMETHODIMP_(D3D_PRIMITIVE) GetGSInputPrimitive(THIS);
  152. STDMETHODIMP_(BOOL) IsSampleFrequencyShader(THIS);
  153. STDMETHODIMP_(UINT) GetNumInterfaceSlots(THIS);
  154. STDMETHODIMP GetMinFeatureLevel(THIS_ _Out_ enum D3D_FEATURE_LEVEL* pLevel);
  155. STDMETHODIMP_(UINT) GetThreadGroupSize(THIS_
  156. _Out_opt_ UINT* pSizeX,
  157. _Out_opt_ UINT* pSizeY,
  158. _Out_opt_ UINT* pSizeZ);
  159. STDMETHODIMP_(UINT64) GetRequiresFlags(THIS);
  160. };
  161. class CFunctionReflection;
  162. class DxilLibraryReflection : public DxilModuleReflection, public ID3D12LibraryReflection {
  163. private:
  164. DXC_MICROCOM_TM_REF_FIELDS()
  165. // Storage, and function by name:
  166. typedef DenseMap<StringRef, std::unique_ptr<CFunctionReflection> > FunctionMap;
  167. typedef DenseMap<const Function*, CFunctionReflection*> FunctionsByPtr;
  168. FunctionMap m_FunctionMap;
  169. FunctionsByPtr m_FunctionsByPtr;
  170. // Enable indexing into functions in deterministic order:
  171. std::vector<CFunctionReflection*> m_FunctionVector;
  172. void AddResourceUseToFunctions(DxilResourceBase &resource, unsigned resIndex);
  173. void AddResourceDependencies();
  174. void SetCBufferUsage();
  175. public:
  176. DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
  177. DXC_MICROCOM_TM_CTOR(DxilLibraryReflection)
  178. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  179. return DoBasicQueryInterface<ID3D12LibraryReflection>(this, iid, ppvObject);
  180. }
  181. HRESULT Load(IDxcBlob *pBlob, const DxilPartHeader *pPart);
  182. // ID3D12LibraryReflection
  183. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_LIBRARY_DESC * pDesc);
  184. STDMETHOD_(ID3D12FunctionReflection *, GetFunctionByIndex)(THIS_ _In_ INT FunctionIndex);
  185. };
  186. _Use_decl_annotations_
  187. HRESULT DxilContainerReflection::Load(IDxcBlob *pContainer) {
  188. if (pContainer == nullptr) {
  189. m_container.Release();
  190. m_pHeader = nullptr;
  191. m_headerLen = 0;
  192. return S_OK;
  193. }
  194. uint32_t bufLen = pContainer->GetBufferSize();
  195. const DxilContainerHeader *pHeader =
  196. IsDxilContainerLike(pContainer->GetBufferPointer(), bufLen);
  197. if (pHeader == nullptr) {
  198. return E_INVALIDARG;
  199. }
  200. if (!IsValidDxilContainer(pHeader, bufLen)) {
  201. return E_INVALIDARG;
  202. }
  203. m_container = pContainer;
  204. m_headerLen = bufLen;
  205. m_pHeader = pHeader;
  206. return S_OK;
  207. }
  208. _Use_decl_annotations_
  209. HRESULT DxilContainerReflection::GetPartCount(UINT32 *pResult) {
  210. if (pResult == nullptr) return E_POINTER;
  211. if (!IsLoaded()) return E_NOT_VALID_STATE;
  212. *pResult = m_pHeader->PartCount;
  213. return S_OK;
  214. }
  215. _Use_decl_annotations_
  216. HRESULT DxilContainerReflection::GetPartKind(UINT32 idx, _Out_ UINT32 *pResult) {
  217. if (pResult == nullptr) return E_POINTER;
  218. if (!IsLoaded()) return E_NOT_VALID_STATE;
  219. if (idx >= m_pHeader->PartCount) return E_BOUNDS;
  220. const DxilPartHeader *pPart = GetDxilContainerPart(m_pHeader, idx);
  221. *pResult = pPart->PartFourCC;
  222. return S_OK;
  223. }
  224. _Use_decl_annotations_
  225. HRESULT DxilContainerReflection::GetPartContent(UINT32 idx, _COM_Outptr_ IDxcBlob **ppResult) {
  226. if (ppResult == nullptr) return E_POINTER;
  227. *ppResult = nullptr;
  228. if (!IsLoaded()) return E_NOT_VALID_STATE;
  229. if (idx >= m_pHeader->PartCount) return E_BOUNDS;
  230. const DxilPartHeader *pPart = GetDxilContainerPart(m_pHeader, idx);
  231. const char *pData = GetDxilPartData(pPart);
  232. uint32_t offset = (uint32_t)(pData - (char*)m_container->GetBufferPointer()); // Offset from the beginning.
  233. uint32_t length = pPart->PartSize;
  234. DxcThreadMalloc TM(m_pMalloc);
  235. return DxcCreateBlobFromBlob(m_container, offset, length, ppResult);
  236. }
  237. _Use_decl_annotations_
  238. HRESULT DxilContainerReflection::FindFirstPartKind(UINT32 kind, _Out_ UINT32 *pResult) {
  239. if (pResult == nullptr) return E_POINTER;
  240. *pResult = 0;
  241. if (!IsLoaded()) return E_NOT_VALID_STATE;
  242. DxilPartIterator it = std::find_if(begin(m_pHeader), end(m_pHeader), DxilPartIsType(kind));
  243. if (it == end(m_pHeader)) return HRESULT_FROM_WIN32(ERROR_NOT_FOUND);
  244. *pResult = it.index;
  245. return S_OK;
  246. }
  247. _Use_decl_annotations_
  248. HRESULT DxilContainerReflection::GetPartReflection(UINT32 idx, REFIID iid, void **ppvObject) {
  249. if (ppvObject == nullptr) return E_POINTER;
  250. *ppvObject = nullptr;
  251. if (!IsLoaded()) return E_NOT_VALID_STATE;
  252. if (idx >= m_pHeader->PartCount) return E_BOUNDS;
  253. const DxilPartHeader *pPart = GetDxilContainerPart(m_pHeader, idx);
  254. if (pPart->PartFourCC != DFCC_DXIL && pPart->PartFourCC != DFCC_ShaderDebugInfoDXIL) {
  255. return E_NOTIMPL;
  256. }
  257. DxcThreadMalloc TM(m_pMalloc);
  258. HRESULT hr = S_OK;
  259. const DxilProgramHeader *pProgramHeader =
  260. reinterpret_cast<const DxilProgramHeader*>(GetDxilPartData(pPart));
  261. if (!IsValidDxilProgramHeader(pProgramHeader, pPart->PartSize)) {
  262. return E_INVALIDARG;
  263. }
  264. DXIL::ShaderKind SK = GetVersionShaderType(pProgramHeader->ProgramVersion);
  265. if (SK == DXIL::ShaderKind::Library) {
  266. CComPtr<DxilLibraryReflection> pReflection = DxilLibraryReflection::Alloc(m_pMalloc);
  267. IFCOOM(pReflection.p);
  268. IFC(pReflection->Load(m_container, pPart));
  269. IFC(pReflection.p->QueryInterface(iid, ppvObject));
  270. } else {
  271. CComPtr<DxilShaderReflection> pReflection = DxilShaderReflection::Alloc(m_pMalloc);
  272. IFCOOM(pReflection.p);
  273. PublicAPI api = DxilShaderReflection::IIDToAPI(iid);
  274. pReflection->SetPublicAPI(api);
  275. IFC(pReflection->Load(m_container, pPart));
  276. IFC(pReflection.p->QueryInterface(iid, ppvObject));
  277. }
  278. Cleanup:
  279. return hr;
  280. }
  281. void hlsl::CreateDxcContainerReflection(IDxcContainerReflection **ppResult) {
  282. CComPtr<DxilContainerReflection> pReflection = DxilContainerReflection::Alloc(DxcGetThreadMallocNoRef());
  283. *ppResult = pReflection.Detach();
  284. if (*ppResult == nullptr) throw std::bad_alloc();
  285. }
  286. ///////////////////////////////////////////////////////////////////////////////
  287. // DxilShaderReflection implementation - helper objects. //
  288. class CShaderReflectionType;
  289. class CShaderReflectionVariable;
  290. class CShaderReflectionConstantBuffer;
  291. class CShaderReflection;
  292. struct D3D11_INTERNALSHADER_RESOURCE_DEF;
  293. class CShaderReflectionType : public ID3D12ShaderReflectionType
  294. {
  295. protected:
  296. D3D12_SHADER_TYPE_DESC m_Desc;
  297. std::string m_Name;
  298. std::vector<StringRef> m_MemberNames;
  299. std::vector<CShaderReflectionType*> m_MemberTypes;
  300. CShaderReflectionType* m_pSubType;
  301. CShaderReflectionType* m_pBaseClass;
  302. std::vector<CShaderReflectionType*> m_Interfaces;
  303. ULONG_PTR m_Identity;
  304. public:
  305. // Internal
  306. HRESULT Initialize(
  307. DxilModule &M,
  308. llvm::Type *type,
  309. DxilFieldAnnotation &typeAnnotation,
  310. unsigned int baseOffset,
  311. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
  312. // ID3D12ShaderReflectionType
  313. STDMETHOD(GetDesc)(D3D12_SHADER_TYPE_DESC *pDesc);
  314. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByIndex)(UINT Index);
  315. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByName)(LPCSTR Name);
  316. STDMETHOD_(LPCSTR, GetMemberTypeName)(UINT Index);
  317. STDMETHOD(IsEqual)(THIS_ ID3D12ShaderReflectionType* pType);
  318. STDMETHOD_(ID3D12ShaderReflectionType*, GetSubType)(THIS);
  319. STDMETHOD_(ID3D12ShaderReflectionType*, GetBaseClass)(THIS);
  320. STDMETHOD_(UINT, GetNumInterfaces)(THIS);
  321. STDMETHOD_(ID3D12ShaderReflectionType*, GetInterfaceByIndex)(THIS_ UINT uIndex);
  322. STDMETHOD(IsOfType)(THIS_ ID3D12ShaderReflectionType* pType);
  323. STDMETHOD(ImplementsInterface)(THIS_ ID3D12ShaderReflectionType* pBase);
  324. bool CheckEqual(_In_ CShaderReflectionType *pOther) {
  325. return m_Identity == pOther->m_Identity;
  326. }
  327. };
  328. class CShaderReflectionVariable : public ID3D12ShaderReflectionVariable
  329. {
  330. protected:
  331. D3D12_SHADER_VARIABLE_DESC m_Desc;
  332. CShaderReflectionType *m_pType;
  333. CShaderReflectionConstantBuffer *m_pBuffer;
  334. BYTE *m_pDefaultValue;
  335. public:
  336. void Initialize(CShaderReflectionConstantBuffer *pBuffer,
  337. D3D12_SHADER_VARIABLE_DESC *pDesc,
  338. CShaderReflectionType *pType, BYTE *pDefaultValue);
  339. LPCSTR GetName() { return m_Desc.Name; }
  340. // ID3D12ShaderReflectionVariable
  341. STDMETHOD(GetDesc)(D3D12_SHADER_VARIABLE_DESC *pDesc);
  342. STDMETHOD_(ID3D12ShaderReflectionType*, GetType)();
  343. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer*, GetBuffer)();
  344. STDMETHOD_(UINT, GetInterfaceSlot)(THIS_ UINT uArrayIndex);
  345. };
  346. class CShaderReflectionConstantBuffer : public ID3D12ShaderReflectionConstantBuffer
  347. {
  348. protected:
  349. D3D12_SHADER_BUFFER_DESC m_Desc;
  350. std::vector<CShaderReflectionVariable> m_Variables;
  351. public:
  352. CShaderReflectionConstantBuffer() = default;
  353. CShaderReflectionConstantBuffer(CShaderReflectionConstantBuffer &&other) {
  354. m_Desc = other.m_Desc;
  355. std::swap(m_Variables, other.m_Variables);
  356. }
  357. void Initialize(DxilModule &M,
  358. DxilCBuffer &CB,
  359. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
  360. void InitializeStructuredBuffer(DxilModule &M,
  361. DxilResource &R,
  362. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes);
  363. LPCSTR GetName() { return m_Desc.Name; }
  364. // ID3D12ShaderReflectionConstantBuffer
  365. STDMETHOD(GetDesc)(D3D12_SHADER_BUFFER_DESC *pDesc);
  366. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByIndex)(UINT Index);
  367. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByName)(LPCSTR Name);
  368. };
  // Invalid type sentinel definitions: forward declarations for the sentinel
  // objects returned in place of null by failed lookups.
  // NOTE(review): the sentinels defined below are named CInvalidFunction and
  // CInvalidFunctionParameter; CInvalidSRLibraryFunction and
  // CInvalidSRFunctionParameter appear unused in this section — confirm.
  class CInvalidSRConstantBuffer;
  class CInvalidSRFunctionParameter;
  class CInvalidSRLibraryFunction;
  class CInvalidSRType;
  class CInvalidSRVariable;
  375. class CInvalidSRType : public ID3D12ShaderReflectionType {
  376. STDMETHOD(GetDesc)(D3D12_SHADER_TYPE_DESC *pDesc) { return E_FAIL; }
  377. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByIndex)(UINT Index);
  378. STDMETHOD_(ID3D12ShaderReflectionType*, GetMemberTypeByName)(LPCSTR Name);
  379. STDMETHOD_(LPCSTR, GetMemberTypeName)(UINT Index) { return "$Invalid"; }
  380. STDMETHOD(IsEqual)(THIS_ ID3D12ShaderReflectionType* pType) { return E_FAIL; }
  381. STDMETHOD_(ID3D12ShaderReflectionType*, GetSubType)(THIS);
  382. STDMETHOD_(ID3D12ShaderReflectionType*, GetBaseClass)(THIS);
  383. STDMETHOD_(UINT, GetNumInterfaces)(THIS) { return 0; }
  384. STDMETHOD_(ID3D12ShaderReflectionType*, GetInterfaceByIndex)(THIS_ UINT uIndex);
  385. STDMETHOD(IsOfType)(THIS_ ID3D12ShaderReflectionType* pType) { return E_FAIL; }
  386. STDMETHOD(ImplementsInterface)(THIS_ ID3D12ShaderReflectionType* pBase) { return E_FAIL; }
  387. };
  388. static CInvalidSRType g_InvalidSRType;
  389. ID3D12ShaderReflectionType* CInvalidSRType::GetMemberTypeByIndex(UINT) { return &g_InvalidSRType; }
  390. ID3D12ShaderReflectionType* CInvalidSRType::GetMemberTypeByName(LPCSTR) { return &g_InvalidSRType; }
  391. ID3D12ShaderReflectionType* CInvalidSRType::GetSubType() { return &g_InvalidSRType; }
  392. ID3D12ShaderReflectionType* CInvalidSRType::GetBaseClass() { return &g_InvalidSRType; }
  393. ID3D12ShaderReflectionType* CInvalidSRType::GetInterfaceByIndex(UINT) { return &g_InvalidSRType; }
  394. class CInvalidSRVariable : public ID3D12ShaderReflectionVariable {
  395. STDMETHOD(GetDesc)(D3D12_SHADER_VARIABLE_DESC *pDesc) { return E_FAIL; }
  396. STDMETHOD_(ID3D12ShaderReflectionType*, GetType)() { return &g_InvalidSRType; }
  397. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer*, GetBuffer)();
  398. STDMETHOD_(UINT, GetInterfaceSlot)(THIS_ UINT uIndex) { return UINT_MAX; }
  399. };
  400. static CInvalidSRVariable g_InvalidSRVariable;
  401. class CInvalidSRConstantBuffer : public ID3D12ShaderReflectionConstantBuffer {
  402. STDMETHOD(GetDesc)(D3D12_SHADER_BUFFER_DESC *pDesc) { return E_FAIL; }
  403. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByIndex)(UINT Index) { return &g_InvalidSRVariable; }
  404. STDMETHOD_(ID3D12ShaderReflectionVariable*, GetVariableByName)(LPCSTR Name) { return &g_InvalidSRVariable; }
  405. };
  406. static CInvalidSRConstantBuffer g_InvalidSRConstantBuffer;
  407. class CInvalidFunctionParameter : public ID3D12FunctionParameterReflection {
  408. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_PARAMETER_DESC * pDesc) { return E_FAIL; }
  409. };
  410. CInvalidFunctionParameter g_InvalidFunctionParameter;
  411. class CInvalidFunction : public ID3D12FunctionReflection {
  412. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_FUNCTION_DESC * pDesc) { return E_FAIL; }
  413. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByIndex)(THIS_ _In_ UINT BufferIndex) { return &g_InvalidSRConstantBuffer; }
  414. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByName)(THIS_ _In_ LPCSTR Name) { return &g_InvalidSRConstantBuffer; }
  415. STDMETHOD(GetResourceBindingDesc)(THIS_ _In_ UINT ResourceIndex,
  416. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc) { return E_FAIL; }
  417. STDMETHOD_(ID3D12ShaderReflectionVariable *, GetVariableByName)(THIS_ _In_ LPCSTR Name) { return nullptr; }
  418. STDMETHOD(GetResourceBindingDescByName)(THIS_ _In_ LPCSTR Name,
  419. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc) { return E_FAIL; }
  420. // Use D3D_RETURN_PARAMETER_INDEX to get description of the return value.
  421. STDMETHOD_(ID3D12FunctionParameterReflection *, GetFunctionParameter)(THIS_ _In_ INT ParameterIndex) { return &g_InvalidFunctionParameter; }
  422. };
  423. CInvalidFunction g_InvalidFunction;
  424. void CShaderReflectionVariable::Initialize(
  425. CShaderReflectionConstantBuffer *pBuffer, D3D12_SHADER_VARIABLE_DESC *pDesc,
  426. CShaderReflectionType *pType, BYTE *pDefaultValue) {
  427. m_pBuffer = pBuffer;
  428. memcpy(&m_Desc, pDesc, sizeof(m_Desc));
  429. m_pType = pType;
  430. m_pDefaultValue = pDefaultValue;
  431. }
  432. HRESULT CShaderReflectionVariable::GetDesc(D3D12_SHADER_VARIABLE_DESC *pDesc) {
  433. if (!pDesc) return E_POINTER;
  434. memcpy(pDesc, &m_Desc, sizeof(m_Desc));
  435. return S_OK;
  436. }
  437. ID3D12ShaderReflectionType *CShaderReflectionVariable::GetType() {
  438. return m_pType;
  439. }
  440. ID3D12ShaderReflectionConstantBuffer *CShaderReflectionVariable::GetBuffer() {
  441. return m_pBuffer;
  442. }
  443. UINT CShaderReflectionVariable::GetInterfaceSlot(UINT uArrayIndex) {
  444. return UINT_MAX;
  445. }
  446. ID3D12ShaderReflectionConstantBuffer *CInvalidSRVariable::GetBuffer() {
  447. return &g_InvalidSRConstantBuffer;
  448. }
  449. STDMETHODIMP CShaderReflectionType::GetDesc(D3D12_SHADER_TYPE_DESC *pDesc)
  450. {
  451. if (!pDesc) return E_POINTER;
  452. memcpy(pDesc, &m_Desc, sizeof(m_Desc));
  453. return S_OK;
  454. }
  455. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetMemberTypeByIndex(UINT Index)
  456. {
  457. if (Index >= m_MemberTypes.size()) {
  458. return &g_InvalidSRType;
  459. }
  460. return m_MemberTypes[Index];
  461. }
  462. STDMETHODIMP_(LPCSTR) CShaderReflectionType::GetMemberTypeName(UINT Index)
  463. {
  464. if (Index >= m_MemberTypes.size()) {
  465. return nullptr;
  466. }
  467. return (LPCSTR) m_MemberNames[Index].bytes_begin();
  468. }
  469. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetMemberTypeByName(LPCSTR Name)
  470. {
  471. UINT memberCount = m_Desc.Members;
  472. for( UINT mm = 0; mm < memberCount; ++mm ) {
  473. if( m_MemberNames[mm] == Name ) {
  474. return m_MemberTypes[mm];
  475. }
  476. }
  477. return nullptr;
  478. }
  479. STDMETHODIMP CShaderReflectionType::IsEqual(THIS_ ID3D12ShaderReflectionType* pType)
  480. {
  481. // TODO: implement this check, if users actually depend on it
  482. return S_FALSE;
  483. }
  484. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetSubType(THIS)
  485. {
  486. // TODO: implement `class`-related features, if requested
  487. return nullptr;
  488. }
  489. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetBaseClass(THIS)
  490. {
  491. // TODO: implement `class`-related features, if requested
  492. return nullptr;
  493. }
  494. STDMETHODIMP_(UINT) CShaderReflectionType::GetNumInterfaces(THIS)
  495. {
  496. // HLSL interfaces have been deprecated
  497. return 0;
  498. }
  499. STDMETHODIMP_(ID3D12ShaderReflectionType*) CShaderReflectionType::GetInterfaceByIndex(THIS_ UINT uIndex)
  500. {
  501. // HLSL interfaces have been deprecated
  502. return nullptr;
  503. }
  504. STDMETHODIMP CShaderReflectionType::IsOfType(THIS_ ID3D12ShaderReflectionType* pType)
  505. {
  506. // TODO: implement `class`-related features, if requested
  507. return S_FALSE;
  508. }
  509. STDMETHODIMP CShaderReflectionType::ImplementsInterface(THIS_ ID3D12ShaderReflectionType* pBase)
  510. {
  511. // HLSL interfaces have been deprecated
  512. return S_FALSE;
  513. }
  514. // Helper routine for types that don't have an obvious mapping
  515. // to the existing shader reflection interface.
  516. static bool ProcessUnhandledObjectType(
  517. llvm::StructType *structType,
  518. D3D_SHADER_VARIABLE_TYPE *outObjectType)
  519. {
  520. // Don't actually make this a hard error, but instead report the problem using a suitable debug message.
  521. #ifdef DBG
  522. OutputDebugFormatA("DxilContainerReflection.cpp: error: unhandled object type '%s'.\n", structType->getName().str().c_str());
  523. #endif
  524. *outObjectType = D3D_SVT_VOID;
  525. return true;
  526. }
  527. // Helper routine to try to detect if a type represents an HLSL "object" type
  528. // (a texture, sampler, buffer, etc.), and to extract the coresponding shader
  529. // reflection type.
  530. static bool TryToDetectObjectType(
  531. llvm::StructType *structType,
  532. D3D_SHADER_VARIABLE_TYPE *outObjectType)
  533. {
  534. // Note: This logic is largely duplicated from `HLModule::IsHLSLObjectType`
  535. // with the addition of returning the appropriate reflection type tag.
  536. //
  537. // That logic looks error-prone, since it relies on string tests against
  538. // type names, including cases that just test against a prefix.
  539. // This code doesn't try to be any more robust.
  540. StringRef name = structType->getName();
  541. if(name.startswith("dx.types.wave_t") )
  542. {
  543. return ProcessUnhandledObjectType(structType, outObjectType);
  544. }
  545. // Strip off some prefixes we are likely to see.
  546. name = name.ltrim("class.");
  547. name = name.ltrim("struct.");
  548. // Slice types occur as intermediates (they aren not objects)
  549. if(name.endswith("_slice_type")) { return false; }
  550. // We might check for an exact name match, or a prefix match
  551. #define EXACT_MATCH(NAME, TAG) \
  552. else if(name == #NAME) do { *outObjectType = TAG; return true; } while(0)
  553. #define PREFIX_MATCH(NAME, TAG) \
  554. else if(name.startswith(#NAME)) do { *outObjectType = TAG; return true; } while(0)
  555. if(0) {}
  556. EXACT_MATCH(SamplerState, D3D_SVT_SAMPLER);
  557. EXACT_MATCH(SamplerComparisonState, D3D_SVT_SAMPLER);
  558. // Note: GS output stream types are supported in the reflection interface.
  559. else if(name.startswith("TriangleStream")) { return ProcessUnhandledObjectType(structType, outObjectType); }
  560. else if(name.startswith("PointStream")) { return ProcessUnhandledObjectType(structType, outObjectType); }
  561. else if(name.startswith("LineStream")) { return ProcessUnhandledObjectType(structType, outObjectType); }
  562. PREFIX_MATCH(AppendStructuredBuffer, D3D_SVT_APPEND_STRUCTURED_BUFFER);
  563. PREFIX_MATCH(ConsumeStructuredBuffer, D3D_SVT_CONSUME_STRUCTURED_BUFFER);
  564. PREFIX_MATCH(ConstantBuffer, D3D_SVT_CBUFFER);
  565. // Note: the `HLModule` code does this trick to avoid checking more names
  566. // than it has to, but it doesn't seem 100% correct to do this.
  567. // TODO: consider just listing the `RasterizerOrdered` cases explicitly,
  568. // just as we do for the `RW` cases already.
  569. name = name.ltrim("RasterizerOrdered");
  570. if(0) {}
  571. EXACT_MATCH(ByteAddressBuffer, D3D_SVT_BYTEADDRESS_BUFFER);
  572. EXACT_MATCH(RWByteAddressBuffer, D3D_SVT_RWBYTEADDRESS_BUFFER);
  573. PREFIX_MATCH(Buffer, D3D_SVT_BUFFER);
  574. PREFIX_MATCH(RWBuffer, D3D_SVT_RWBUFFER);
  575. PREFIX_MATCH(StructuredBuffer, D3D_SVT_STRUCTURED_BUFFER);
  576. PREFIX_MATCH(RWStructuredBuffer, D3D_SVT_RWSTRUCTURED_BUFFER);
  577. PREFIX_MATCH(Texture1D, D3D_SVT_TEXTURE1D);
  578. PREFIX_MATCH(RWTexture1D, D3D_SVT_RWTEXTURE1D);
  579. PREFIX_MATCH(Texture1DArray, D3D_SVT_TEXTURE1DARRAY);
  580. PREFIX_MATCH(RWTexture1DArray, D3D_SVT_RWTEXTURE1DARRAY);
  581. PREFIX_MATCH(Texture2D, D3D_SVT_TEXTURE2D);
  582. PREFIX_MATCH(RWTexture2D, D3D_SVT_RWTEXTURE2D);
  583. PREFIX_MATCH(Texture2DArray, D3D_SVT_TEXTURE2DARRAY);
  584. PREFIX_MATCH(RWTexture2DArray, D3D_SVT_RWTEXTURE2DARRAY);
  585. PREFIX_MATCH(Texture3D, D3D_SVT_TEXTURE3D);
  586. PREFIX_MATCH(RWTexture3D, D3D_SVT_RWTEXTURE3D);
  587. PREFIX_MATCH(TextureCube, D3D_SVT_TEXTURECUBE);
  588. PREFIX_MATCH(TextureCubeArray, D3D_SVT_TEXTURECUBEARRAY);
  589. PREFIX_MATCH(Texture2DMS, D3D_SVT_TEXTURE2DMS);
  590. PREFIX_MATCH(Texture2DMSArray, D3D_SVT_TEXTURE2DMSARRAY);
  591. #undef EXACT_MATCH
  592. #undef PREFIX_MATCH
  593. // Default: not an object type
  594. return false;
  595. }
  596. // Helper to determine if an LLVM type represents an HLSL
  597. // object type (uses the `TryToDetectObjectType()` function
  598. // defined previously).
  599. static bool IsObjectType(
  600. llvm::Type* inType)
  601. {
  602. llvm::Type* type = inType;
  603. while(type->isArrayTy())
  604. {
  605. type = type->getArrayElementType();
  606. }
  607. llvm::StructType* structType = dyn_cast<StructType>(type);
  608. if(!structType)
  609. return false;
  610. D3D_SHADER_VARIABLE_TYPE ignored;
  611. return TryToDetectObjectType(structType, &ignored);
  612. }
  613. // Main logic for translating an LLVM type and associated
  614. // annotations into a D3D shader reflection type.
  615. HRESULT CShaderReflectionType::Initialize(
  616. DxilModule &M,
  617. llvm::Type *inType,
  618. DxilFieldAnnotation &typeAnnotation,
  619. unsigned int baseOffset,
  620. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes)
  621. {
  622. DXASSERT_NOMSG(inType);
  623. // Set a bunch of fields to default values, to avoid duplication.
  624. m_Desc.Rows = 0;
  625. m_Desc.Columns = 0;
  626. m_Desc.Elements = 0;
  627. m_Desc.Members = 0;
  628. // Extract offset relative to parent.
  629. // Note: the `baseOffset` is used in the case where the type in
  630. // question is a field in a constant buffer, since then both the
  631. // field and the variable store the same offset information, and
  632. // we need to zero out the value in the type to avoid the user
  633. // of the reflection interface seeing 2x the correct value.
  634. m_Desc.Offset = typeAnnotation.GetCBufferOffset() - baseOffset;
  635. // Arrays don't seem to be represented directly in the reflection
  636. // data, but only as the `Elements` field being non-zero.
  637. // We "unwrap" any array type here, and then proceed to look
  638. // at the element type.
  639. llvm::Type* type = inType;
  640. while(type->isArrayTy())
  641. {
  642. llvm::Type* elementType = type->getArrayElementType();
  643. // Note: At this point an HLSL matrix type may appear as an ordinary
  644. // array (not wrapped in a `struct`), so `HLMatrixLower::IsMatrixType()`
  645. // is not sufficient. Instead we need to check the field annotation.
  646. //
  647. // We might have an array of matrices, though, so we only exit if
  648. // the field annotation says we have a matrix, and we've bottomed
  649. // out and the element type isn't itself an array.
  650. if(typeAnnotation.HasMatrixAnnotation() && !elementType->isArrayTy())
  651. {
  652. break;
  653. }
  654. // Non-array types should have `Elements` be zero, so as soon as we
  655. // find that we have our first real array (not a matrix), we initialize `Elements`
  656. if(!m_Desc.Elements) m_Desc.Elements = 1;
  657. // It isn't clear what is the desired behavior for multi-dimensional arrays,
  658. // but for now we do the expedient thing of multiplying out all their
  659. // dimensions.
  660. m_Desc.Elements *= type->getArrayNumElements();
  661. type = elementType;
  662. }
  663. // Default to a scalar type, just to avoid some duplication later.
  664. m_Desc.Class = D3D_SVC_SCALAR;
  665. // Look at the annotation to try to determine the basic type of value.
  666. //
  667. // Note that DXIL supports some types that don't currently have equivalents
  668. // in the reflection interface, so we try to muddle through here.
  669. D3D_SHADER_VARIABLE_TYPE componentType = D3D_SVT_VOID;
  670. switch(typeAnnotation.GetCompType().GetKind())
  671. {
  672. case hlsl::DXIL::ComponentType::Invalid:
  673. break;
  674. case hlsl::DXIL::ComponentType::I1:
  675. componentType = D3D_SVT_BOOL;
  676. m_Name = "bool";
  677. break;
  678. case hlsl::DXIL::ComponentType::I16:
  679. componentType = D3D_SVT_MIN16INT;
  680. m_Name = "min16int";
  681. break;
  682. case hlsl::DXIL::ComponentType::U16:
  683. componentType = D3D_SVT_MIN16UINT;
  684. m_Name = "min16uint";
  685. break;
  686. case hlsl::DXIL::ComponentType::I64:
  687. #ifdef DBG
  688. OutputDebugStringA("DxilContainerReflection.cpp: warning: component of type 'I64' being reflected as if 'I32'\n");
  689. #endif
  690. case hlsl::DXIL::ComponentType::I32:
  691. componentType = D3D_SVT_INT;
  692. m_Name = "int";
  693. break;
  694. case hlsl::DXIL::ComponentType::U64:
  695. #ifdef DBG
  696. OutputDebugStringA("DxilContainerReflection.cpp: warning: component of type 'U64' being reflected as if 'U32'\n");
  697. #endif
  698. case hlsl::DXIL::ComponentType::U32:
  699. componentType = D3D_SVT_UINT;
  700. m_Name = "uint";
  701. break;
  702. case hlsl::DXIL::ComponentType::F16:
  703. case hlsl::DXIL::ComponentType::SNormF16:
  704. case hlsl::DXIL::ComponentType::UNormF16:
  705. componentType = D3D_SVT_MIN16FLOAT;
  706. m_Name = "min16float";
  707. break;
  708. case hlsl::DXIL::ComponentType::F32:
  709. case hlsl::DXIL::ComponentType::SNormF32:
  710. case hlsl::DXIL::ComponentType::UNormF32:
  711. componentType = D3D_SVT_FLOAT;
  712. m_Name = "float";
  713. break;
  714. case hlsl::DXIL::ComponentType::F64:
  715. case hlsl::DXIL::ComponentType::SNormF64:
  716. case hlsl::DXIL::ComponentType::UNormF64:
  717. componentType = D3D_SVT_DOUBLE;
  718. m_Name = "double";
  719. break;
  720. default:
  721. #ifdef DBG
  722. OutputDebugStringA("DxilContainerReflection.cpp: error: unknown component type\n");
  723. #endif
  724. break;
  725. }
  726. m_Desc.Type = componentType;
  727. // A matrix type is encoded as a vector type, plus annotations, so we
  728. // need to check for this case before other vector cases.
  729. if(typeAnnotation.HasMatrixAnnotation())
  730. {
  731. // We can extract the details from the annotation.
  732. DxilMatrixAnnotation const& matrixAnnotation = typeAnnotation.GetMatrixAnnotation();
  733. switch(matrixAnnotation.Orientation)
  734. {
  735. default:
  736. #ifdef DBG
  737. OutputDebugStringA("DxilContainerReflection.cpp: error: unknown matrix orientation\n");
  738. #endif
  739. // Note: column-major layout is the default
  740. case hlsl::MatrixOrientation::Undefined:
  741. case hlsl::MatrixOrientation::ColumnMajor:
  742. m_Desc.Class = D3D_SVC_MATRIX_COLUMNS;
  743. break;
  744. case hlsl::MatrixOrientation::RowMajor:
  745. m_Desc.Class = D3D_SVC_MATRIX_ROWS;
  746. break;
  747. }
  748. m_Desc.Rows = matrixAnnotation.Rows;
  749. m_Desc.Columns = matrixAnnotation.Cols;
  750. m_Name += std::to_string(matrixAnnotation.Rows) + "x" + std::to_string(matrixAnnotation.Cols);
  751. }
  752. else if( type->isVectorTy() )
  753. {
  754. // We assume that LLVM vectors either represent matrices (handled above)
  755. // or HLSL vectors.
  756. //
  757. // Note: the reflection interface encodes an N-vector as if it had 1 row
  758. // and N columns.
  759. m_Desc.Class = D3D_SVC_VECTOR;
  760. m_Desc.Rows = 1;
  761. m_Desc.Columns = type->getVectorNumElements();
  762. m_Name += std::to_string(type->getVectorNumElements());
  763. }
  764. else if( type->isStructTy() )
  765. {
  766. // A struct type might be an ordinary user-defined `struct`,
  767. // or one of the builtin in HLSL "object" types.
  768. StructType *structType = cast<StructType>(type);
  769. // We use our function to try to detect an object type
  770. // based on its name.
  771. if(TryToDetectObjectType(structType, &m_Desc.Type))
  772. {
  773. m_Desc.Class = D3D_SVC_OBJECT;
  774. }
  775. else
  776. {
  777. // Otherwise we have a struct and need to recurse on its fields.
  778. m_Desc.Class = D3D_SVC_STRUCT;
  779. m_Desc.Rows = 1;
  780. // Try to "clean" the type name for use in reflection data
  781. llvm::StringRef name = structType->getName();
  782. name = name.ltrim("dx.alignment.legacy.");
  783. name = name.ltrim("struct.");
  784. m_Name = name;
  785. // Fields may have annotations, and we need to look at these
  786. // in order to decode their types properly.
  787. DxilTypeSystem &typeSys = M.GetTypeSystem();
  788. DxilStructAnnotation *structAnnotation = typeSys.GetStructAnnotation(structType);
  789. // There is no annotation for empty structs
  790. unsigned int fieldCount = 0;
  791. if (structAnnotation)
  792. fieldCount = type->getStructNumElements();
  793. // The DXBC reflection info computes `Columns` for a
  794. // `struct` type from the fields (see below)
  795. UINT columnCounter = 0;
  796. for(unsigned int ff = 0; ff < fieldCount; ++ff)
  797. {
  798. DxilFieldAnnotation& fieldAnnotation = structAnnotation->GetFieldAnnotation(ff);
  799. llvm::Type* fieldType = structType->getStructElementType(ff);
  800. // Skip fields with object types, since applications may not expect to see them here.
  801. //
  802. // TODO: should skipping be context-dependent, since we might not be inside
  803. // a constant buffer?
  804. if( IsObjectType(fieldType) )
  805. {
  806. continue;
  807. }
  808. CShaderReflectionType *fieldReflectionType = new CShaderReflectionType();
  809. allTypes.push_back(std::unique_ptr<CShaderReflectionType>(fieldReflectionType));
  810. fieldReflectionType->Initialize(M, fieldType, fieldAnnotation, 0, allTypes);
  811. m_MemberTypes.push_back(fieldReflectionType);
  812. m_MemberNames.push_back(fieldAnnotation.GetFieldName().c_str());
  813. // Effectively, we want to add one to `Columns` for every scalar nested recursively
  814. // inside this `struct` type (ignoring objects, which we filtered above). We should
  815. // be able to compute this as the product of the `Columns`, `Rows` and `Elements`
  816. // of each field, with the caveat that some of these may be zero, but shoud be
  817. // treated as one.
  818. columnCounter +=
  819. (fieldReflectionType->m_Desc.Columns ? fieldReflectionType->m_Desc.Columns : 1)
  820. * (fieldReflectionType->m_Desc.Rows ? fieldReflectionType->m_Desc.Rows : 1)
  821. * (fieldReflectionType->m_Desc.Elements ? fieldReflectionType->m_Desc.Elements : 1);
  822. }
  823. m_Desc.Columns = columnCounter;
  824. // Because we might have skipped fields during enumeration,
  825. // the `Members` count in the description might not be the same
  826. // as the field count of the original LLVM type.
  827. m_Desc.Members = m_MemberTypes.size();
  828. }
  829. }
  830. else if( type->isPointerTy() )
  831. {
  832. #ifdef DBG
  833. OutputDebugStringA("DxilContainerReflection.cpp: error: cannot reflect pointer type\n");
  834. #endif
  835. }
  836. else if( type->isVoidTy() )
  837. {
  838. // Name for `void` wasn't handle in the component-type `switch` above
  839. m_Name = "void";
  840. m_Desc.Class = D3D_SVC_SCALAR;
  841. m_Desc.Rows = 1;
  842. m_Desc.Columns = 1;
  843. }
  844. else
  845. {
  846. // Assume we have a scalar at this point.
  847. m_Desc.Class = D3D_SVC_SCALAR;
  848. m_Desc.Rows = 1;
  849. m_Desc.Columns = 1;
  850. // Special-case naming
  851. switch(m_Desc.Type)
  852. {
  853. default:
  854. break;
  855. case D3D_SVT_UINT:
  856. // Scalar `uint` gets reflected as `dword`, while vectors/matrices use `uint`...
  857. m_Name = "dword";
  858. break;
  859. }
  860. }
  861. // TODO: are there other cases to be handled?
  862. m_Desc.Name = m_Name.c_str();
  863. return S_OK;
  864. }
  865. void CShaderReflectionConstantBuffer::Initialize(
  866. DxilModule &M,
  867. DxilCBuffer &CB,
  868. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes) {
  869. ZeroMemory(&m_Desc, sizeof(m_Desc));
  870. m_Desc.Name = CB.GetGlobalName().c_str();
  871. m_Desc.Size = CB.GetSize() / CB.GetRangeSize();
  872. m_Desc.Size = (m_Desc.Size + 0x0f) & ~(0x0f); // Round up to 16 bytes for reflection.
  873. m_Desc.Type = D3D_CT_CBUFFER;
  874. m_Desc.uFlags = 0;
  875. Type *Ty = CB.GetGlobalSymbol()->getType()->getPointerElementType();
  876. // For ConstantBuffer<> buf[2], the array size is in Resource binding count
  877. // part.
  878. if (Ty->isArrayTy())
  879. Ty = Ty->getArrayElementType();
  880. DxilTypeSystem &typeSys = M.GetTypeSystem();
  881. StructType *ST = cast<StructType>(Ty);
  882. DxilStructAnnotation *annotation =
  883. typeSys.GetStructAnnotation(cast<StructType>(ST));
  884. // Dxil from dxbc doesn't have annotation.
  885. if (!annotation)
  886. return;
  887. m_Desc.Variables = ST->getNumContainedTypes();
  888. unsigned lastIndex = ST->getNumContainedTypes() - 1;
  889. for (unsigned i = 0; i < ST->getNumContainedTypes(); ++i) {
  890. DxilFieldAnnotation &fieldAnnotation = annotation->GetFieldAnnotation(i);
  891. D3D12_SHADER_VARIABLE_DESC VarDesc;
  892. ZeroMemory(&VarDesc, sizeof(VarDesc));
  893. VarDesc.uFlags |= D3D_SVF_USED; // Will update in SetCBufferUsage.
  894. CShaderReflectionVariable Var;
  895. //Create reflection type.
  896. CShaderReflectionType *pVarType = new CShaderReflectionType();
  897. allTypes.push_back(std::unique_ptr<CShaderReflectionType>(pVarType));
  898. pVarType->Initialize(M, ST->getContainedType(i), fieldAnnotation, fieldAnnotation.GetCBufferOffset(), allTypes);
  899. BYTE *pDefaultValue = nullptr;
  900. VarDesc.Name = fieldAnnotation.GetFieldName().c_str();
  901. VarDesc.StartOffset = fieldAnnotation.GetCBufferOffset();
  902. if (i < lastIndex) {
  903. DxilFieldAnnotation &nextFieldAnnotation =
  904. annotation->GetFieldAnnotation(i + 1);
  905. VarDesc.Size = nextFieldAnnotation.GetCBufferOffset() - fieldAnnotation.GetCBufferOffset();
  906. }
  907. else {
  908. VarDesc.Size = CB.GetSize() - fieldAnnotation.GetCBufferOffset();
  909. }
  910. Var.Initialize(this, &VarDesc, pVarType, pDefaultValue);
  911. m_Variables.push_back(Var);
  912. }
  913. }
  914. static unsigned CalcTypeSize(Type *Ty) {
  915. // Assume aligned values.
  916. if (Ty->isIntegerTy() || Ty->isFloatTy()) {
  917. return Ty->getPrimitiveSizeInBits() / 8;
  918. }
  919. else if (Ty->isArrayTy()) {
  920. ArrayType *AT = dyn_cast<ArrayType>(Ty);
  921. return AT->getNumElements() * CalcTypeSize(AT->getArrayElementType());
  922. }
  923. else if (Ty->isStructTy()) {
  924. StructType *ST = dyn_cast<StructType>(Ty);
  925. unsigned i = 0, c = ST->getStructNumElements();
  926. unsigned result = 0;
  927. for (; i < c; ++i) {
  928. result += CalcTypeSize(ST->getStructElementType(i));
  929. // TODO: align!
  930. }
  931. return result;
  932. }
  933. else if (Ty->isVectorTy()) {
  934. VectorType *VT = dyn_cast<VectorType>(Ty);
  935. return VT->getVectorNumElements() * CalcTypeSize(VT->getVectorElementType());
  936. }
  937. else {
  938. DXASSERT_NOMSG(false);
  939. return 0;
  940. }
  941. }
  942. static unsigned CalcResTypeSize(DxilModule &M, DxilResource &R) {
  943. UNREFERENCED_PARAMETER(M);
  944. Type *Ty = R.GetGlobalSymbol()->getType()->getPointerElementType();
  945. return CalcTypeSize(Ty);
  946. }
  947. void CShaderReflectionConstantBuffer::InitializeStructuredBuffer(
  948. DxilModule &M,
  949. DxilResource &R,
  950. std::vector<std::unique_ptr<CShaderReflectionType>>& allTypes) {
  951. ZeroMemory(&m_Desc, sizeof(m_Desc));
  952. m_Desc.Name = R.GetGlobalName().c_str();
  953. //m_Desc.Size = R.GetSize();
  954. m_Desc.Type = D3D11_CT_RESOURCE_BIND_INFO;
  955. m_Desc.uFlags = 0;
  956. m_Desc.Variables = 1;
  957. D3D12_SHADER_VARIABLE_DESC VarDesc;
  958. ZeroMemory(&VarDesc, sizeof(VarDesc));
  959. VarDesc.Name = "$Element";
  960. VarDesc.Size = CalcResTypeSize(M, R); // aligned bytes
  961. VarDesc.StartTexture = UINT_MAX;
  962. VarDesc.StartSampler = UINT_MAX;
  963. VarDesc.uFlags |= D3D_SVF_USED; // TODO: not necessarily true
  964. CShaderReflectionVariable Var;
  965. CShaderReflectionType *pVarType = nullptr;
  966. // Create reflection type, if we have the necessary annotation info
  967. // Extract the `struct` that wraps element type of the buffer resource
  968. Constant *GV = R.GetGlobalSymbol();
  969. Type *Ty = GV->getType()->getPointerElementType();
  970. if(Ty->isArrayTy())
  971. Ty = Ty->getArrayElementType();
  972. StructType *ST = cast<StructType>(Ty);
  973. // Look up struct type annotation on the element type
  974. DxilTypeSystem &typeSys = M.GetTypeSystem();
  975. DxilStructAnnotation *annotation =
  976. typeSys.GetStructAnnotation(cast<StructType>(ST));
  977. // Dxil from dxbc doesn't have annotation.
  978. if(annotation)
  979. {
  980. // Actually create the reflection type.
  981. pVarType = new CShaderReflectionType();
  982. allTypes.push_back(std::unique_ptr<CShaderReflectionType>(pVarType));
  983. // The user-visible element type is the first field of the wrapepr `struct`
  984. Type *fieldType = ST->getElementType(0);
  985. DxilFieldAnnotation &fieldAnnotation = annotation->GetFieldAnnotation(0);
  986. pVarType->Initialize(M, fieldType, fieldAnnotation, fieldAnnotation.GetCBufferOffset(), allTypes);
  987. }
  988. BYTE *pDefaultValue = nullptr;
  989. Var.Initialize(this, &VarDesc, pVarType, pDefaultValue);
  990. m_Variables.push_back(Var);
  991. m_Desc.Size = VarDesc.Size;
  992. }
  993. HRESULT CShaderReflectionConstantBuffer::GetDesc(D3D12_SHADER_BUFFER_DESC *pDesc) {
  994. if (!pDesc)
  995. return E_POINTER;
  996. memcpy(pDesc, &m_Desc, sizeof(m_Desc));
  997. return S_OK;
  998. }
  999. ID3D12ShaderReflectionVariable *
  1000. CShaderReflectionConstantBuffer::GetVariableByIndex(UINT Index) {
  1001. if (Index >= m_Variables.size()) {
  1002. return &g_InvalidSRVariable;
  1003. }
  1004. return &m_Variables[Index];
  1005. }
  1006. ID3D12ShaderReflectionVariable *
  1007. CShaderReflectionConstantBuffer::GetVariableByName(LPCSTR Name) {
  1008. UINT index;
  1009. if (NULL == Name) {
  1010. return &g_InvalidSRVariable;
  1011. }
  1012. for (index = 0; index < m_Variables.size(); ++index) {
  1013. if (0 == strcmp(m_Variables[index].GetName(), Name)) {
  1014. return &m_Variables[index];
  1015. }
  1016. }
  1017. return &g_InvalidSRVariable;
  1018. }
  1019. ///////////////////////////////////////////////////////////////////////////////
  1020. // DxilShaderReflection implementation. //
  1021. static DxilResource *DxilResourceFromBase(DxilResourceBase *RB) {
  1022. DxilResourceBase::Class C = RB->GetClass();
  1023. if (C == DXIL::ResourceClass::UAV || C == DXIL::ResourceClass::SRV)
  1024. return (DxilResource *)RB;
  1025. return nullptr;
  1026. }
  1027. static D3D_SHADER_INPUT_TYPE ResourceToShaderInputType(DxilResourceBase *RB) {
  1028. DxilResource *R = DxilResourceFromBase(RB);
  1029. bool isUAV = RB->GetClass() == DxilResourceBase::Class::UAV;
  1030. switch (RB->GetKind()) {
  1031. case DxilResource::Kind::CBuffer:
  1032. return D3D_SIT_CBUFFER;
  1033. case DxilResource::Kind::Sampler:
  1034. return D3D_SIT_SAMPLER;
  1035. case DxilResource::Kind::RawBuffer:
  1036. return isUAV ? D3D_SIT_UAV_RWBYTEADDRESS : D3D_SIT_BYTEADDRESS;
  1037. case DxilResource::Kind::StructuredBuffer: {
  1038. if (!isUAV) return D3D_SIT_STRUCTURED;
  1039. // TODO: D3D_SIT_UAV_CONSUME_STRUCTURED, D3D_SIT_UAV_APPEND_STRUCTURED?
  1040. if (R->HasCounter()) return D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER;
  1041. return D3D_SIT_UAV_RWSTRUCTURED;
  1042. }
  1043. case DxilResource::Kind::TBuffer:
  1044. case DxilResource::Kind::TypedBuffer:
  1045. case DxilResource::Kind::Texture1D:
  1046. case DxilResource::Kind::Texture1DArray:
  1047. case DxilResource::Kind::Texture2D:
  1048. case DxilResource::Kind::Texture2DArray:
  1049. case DxilResource::Kind::Texture2DMS:
  1050. case DxilResource::Kind::Texture2DMSArray:
  1051. case DxilResource::Kind::Texture3D:
  1052. case DxilResource::Kind::TextureCube:
  1053. case DxilResource::Kind::TextureCubeArray:
  1054. return isUAV ? D3D_SIT_UAV_RWTYPED : D3D_SIT_TEXTURE;
  1055. case DxilResource::Kind::RTAccelerationStructure:
  1056. return (D3D_SHADER_INPUT_TYPE)D3D_SIT_RTACCELERATIONSTRUCTURE;
  1057. default:
  1058. return (D3D_SHADER_INPUT_TYPE)-1;
  1059. }
  1060. }
  1061. static D3D_RESOURCE_RETURN_TYPE ResourceToReturnType(DxilResourceBase *RB) {
  1062. DxilResource *R = DxilResourceFromBase(RB);
  1063. if (R != nullptr) {
  1064. CompType CT = R->GetCompType();
  1065. if (CT.GetKind() == CompType::Kind::F64) return D3D_RETURN_TYPE_DOUBLE;
  1066. if (CT.IsUNorm()) return D3D_RETURN_TYPE_UNORM;
  1067. if (CT.IsSNorm()) return D3D_RETURN_TYPE_SNORM;
  1068. if (CT.IsSIntTy()) return D3D_RETURN_TYPE_SINT;
  1069. if (CT.IsUIntTy()) return D3D_RETURN_TYPE_UINT;
  1070. if (CT.IsFloatTy()) return D3D_RETURN_TYPE_FLOAT;
  1071. // D3D_RETURN_TYPE_CONTINUED: Return type is a multiple-dword type, such as a
  1072. // double or uint64, and the component is continued from the previous
  1073. // component that was declared. The first component represents the lower bits.
  1074. return D3D_RETURN_TYPE_MIXED;
  1075. }
  1076. return (D3D_RESOURCE_RETURN_TYPE)0;
  1077. }
  1078. static D3D_SRV_DIMENSION ResourceToDimension(DxilResourceBase *RB) {
  1079. switch (RB->GetKind()) {
  1080. case DxilResource::Kind::StructuredBuffer:
  1081. case DxilResource::Kind::TypedBuffer:
  1082. case DxilResource::Kind::TBuffer:
  1083. return D3D_SRV_DIMENSION_BUFFER;
  1084. case DxilResource::Kind::Texture1D:
  1085. return D3D_SRV_DIMENSION_TEXTURE1D;
  1086. case DxilResource::Kind::Texture1DArray:
  1087. return D3D_SRV_DIMENSION_TEXTURE1DARRAY;
  1088. case DxilResource::Kind::Texture2D:
  1089. return D3D_SRV_DIMENSION_TEXTURE2D;
  1090. case DxilResource::Kind::Texture2DArray:
  1091. return D3D_SRV_DIMENSION_TEXTURE2DARRAY;
  1092. case DxilResource::Kind::Texture2DMS:
  1093. return D3D_SRV_DIMENSION_TEXTURE2DMS;
  1094. case DxilResource::Kind::Texture2DMSArray:
  1095. return D3D_SRV_DIMENSION_TEXTURE2DMSARRAY;
  1096. case DxilResource::Kind::Texture3D:
  1097. return D3D_SRV_DIMENSION_TEXTURE3D;
  1098. case DxilResource::Kind::TextureCube:
  1099. return D3D_SRV_DIMENSION_TEXTURECUBE;
  1100. case DxilResource::Kind::TextureCubeArray:
  1101. return D3D_SRV_DIMENSION_TEXTURECUBEARRAY;
  1102. case DxilResource::Kind::RawBuffer:
  1103. return D3D11_SRV_DIMENSION_BUFFER; // D3D11_SRV_DIMENSION_BUFFEREX?
  1104. default:
  1105. return D3D_SRV_DIMENSION_UNKNOWN;
  1106. }
  1107. }
  1108. static UINT ResourceToFlags(DxilResourceBase *RB) {
  1109. UINT result = 0;
  1110. DxilResource *R = DxilResourceFromBase(RB);
  1111. if (R != nullptr &&
  1112. (R->IsAnyTexture() || R->GetKind() == DXIL::ResourceKind::TypedBuffer)) {
  1113. llvm::Type *RetTy = R->GetRetType();
  1114. if (VectorType *VT = dyn_cast<VectorType>(RetTy)) {
  1115. unsigned vecSize = VT->getNumElements();
  1116. switch (vecSize) {
  1117. case 4:
  1118. result |= D3D_SIF_TEXTURE_COMPONENTS;
  1119. break;
  1120. case 3:
  1121. result |= D3D_SIF_TEXTURE_COMPONENT_1;
  1122. break;
  1123. case 2:
  1124. result |= D3D_SIF_TEXTURE_COMPONENT_0;
  1125. break;
  1126. }
  1127. }
  1128. }
  1129. // D3D_SIF_USERPACKED
  1130. if (RB->GetClass() == DXIL::ResourceClass::Sampler) {
  1131. DxilSampler *S = static_cast<DxilSampler *>(RB);
  1132. if (S->GetSamplerKind() == DXIL::SamplerKind::Comparison)
  1133. result |= D3D_SIF_COMPARISON_SAMPLER;
  1134. }
  1135. return result;
  1136. }
  1137. void DxilModuleReflection::CreateReflectionObjectForResource(DxilResourceBase *RB) {
  1138. DxilResourceBase::Class C = RB->GetClass();
  1139. DxilResource *R =
  1140. (C == DXIL::ResourceClass::UAV || C == DXIL::ResourceClass::SRV)
  1141. ? (DxilResource *)RB
  1142. : nullptr;
  1143. D3D12_SHADER_INPUT_BIND_DESC inputBind;
  1144. ZeroMemory(&inputBind, sizeof(inputBind));
  1145. inputBind.BindCount = RB->GetRangeSize();
  1146. if (RB->GetRangeSize() == UINT_MAX)
  1147. inputBind.BindCount = 0;
  1148. inputBind.BindPoint = RB->GetLowerBound();
  1149. inputBind.Dimension = ResourceToDimension(RB);
  1150. inputBind.Name = RB->GetGlobalName().c_str();
  1151. inputBind.Type = ResourceToShaderInputType(RB);
  1152. if (R == nullptr) {
  1153. inputBind.NumSamples = 0;
  1154. }
  1155. else {
  1156. inputBind.NumSamples = R->GetSampleCount();
  1157. if (inputBind.NumSamples == 0) {
  1158. if (R->IsStructuredBuffer()) {
  1159. inputBind.NumSamples = CalcResTypeSize(*m_pDxilModule, *R);
  1160. }
  1161. else if (!R->IsRawBuffer()) {
  1162. inputBind.NumSamples = 0xFFFFFFFF;
  1163. }
  1164. }
  1165. }
  1166. inputBind.ReturnType = ResourceToReturnType(RB);
  1167. inputBind.Space = RB->GetSpaceID();
  1168. inputBind.uFlags = ResourceToFlags(RB);
  1169. inputBind.uID = RB->GetID();
  1170. m_Resources.push_back(inputBind);
  1171. }
  1172. // Find the imm offset part from a value.
  1173. // It must exist unless offset is 0.
  1174. static unsigned GetCBOffset(Value *V) {
  1175. if (ConstantInt *Imm = dyn_cast<ConstantInt>(V))
  1176. return Imm->getLimitedValue();
  1177. else if (UnaryInstruction *UI = dyn_cast<UnaryInstruction>(V)) {
  1178. return 0;
  1179. } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V)) {
  1180. switch (BO->getOpcode()) {
  1181. case Instruction::Add: {
  1182. unsigned left = GetCBOffset(BO->getOperand(0));
  1183. unsigned right = GetCBOffset(BO->getOperand(1));
  1184. return left + right;
  1185. } break;
  1186. case Instruction::Or: {
  1187. unsigned left = GetCBOffset(BO->getOperand(0));
  1188. unsigned right = GetCBOffset(BO->getOperand(1));
  1189. return left | right;
  1190. } break;
  1191. default:
  1192. return 0;
  1193. }
  1194. } else {
  1195. return 0;
  1196. }
  1197. }
  1198. void CollectInPhiChain(PHINode *cbUser, std::vector<unsigned> &cbufUsage,
  1199. unsigned offset, std::unordered_set<Value *> &userSet) {
  1200. if (userSet.count(cbUser) > 0)
  1201. return;
  1202. userSet.insert(cbUser);
  1203. for (User *cbU : cbUser->users()) {
  1204. if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(cbU)) {
  1205. for (unsigned idx : EV->getIndices()) {
  1206. cbufUsage.emplace_back(offset + idx * 4);
  1207. }
  1208. } else {
  1209. PHINode *phi = cast<PHINode>(cbU);
  1210. CollectInPhiChain(phi, cbufUsage, offset, userSet);
  1211. }
  1212. }
  1213. }
  1214. static void CollectCBufUsage(Value *cbHandle,
  1215. std::vector<unsigned> &cbufUsage) {
  1216. for (User *U : cbHandle->users()) {
  1217. CallInst *CI = cast<CallInst>(U);
  1218. ConstantInt *opcodeV =
  1219. cast<ConstantInt>(CI->getArgOperand(DXIL::OperandIndex::kOpcodeIdx));
  1220. DXIL::OpCode opcode = static_cast<DXIL::OpCode>(opcodeV->getLimitedValue());
  1221. if (opcode == DXIL::OpCode::CBufferLoadLegacy) {
  1222. DxilInst_CBufferLoadLegacy cbload(CI);
  1223. Value *resIndex = cbload.get_regIndex();
  1224. unsigned offset = GetCBOffset(resIndex);
  1225. // 16 bytes align.
  1226. offset <<= 4;
  1227. for (User *cbU : U->users()) {
  1228. if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(cbU)) {
  1229. for (unsigned idx : EV->getIndices()) {
  1230. cbufUsage.emplace_back(offset + idx * 4);
  1231. }
  1232. } else {
  1233. PHINode *phi = cast<PHINode>(cbU);
  1234. std::unordered_set<Value *> userSet;
  1235. CollectInPhiChain(phi, cbufUsage, offset, userSet);
  1236. }
  1237. }
  1238. } else if (opcode == DXIL::OpCode::CBufferLoad) {
  1239. DxilInst_CBufferLoad cbload(CI);
  1240. Value *byteOffset = cbload.get_byteOffset();
  1241. unsigned offset = GetCBOffset(byteOffset);
  1242. cbufUsage.emplace_back(offset);
  1243. } else {
  1244. //
  1245. DXASSERT(0, "invalid opcode");
  1246. }
  1247. }
  1248. }
  1249. static void SetCBufVarUsage(CShaderReflectionConstantBuffer &cb,
  1250. std::vector<unsigned> usage) {
  1251. D3D12_SHADER_BUFFER_DESC Desc;
  1252. if (FAILED(cb.GetDesc(&Desc)))
  1253. return;
  1254. unsigned size = Desc.Variables;
  1255. std::sort(usage.begin(), usage.end());
  1256. for (unsigned i = 0; i < size; i++) {
  1257. ID3D12ShaderReflectionVariable *pVar = cb.GetVariableByIndex(i);
  1258. D3D12_SHADER_VARIABLE_DESC VarDesc;
  1259. if (FAILED(pVar->GetDesc(&VarDesc)))
  1260. continue;
  1261. if (!pVar)
  1262. continue;
  1263. unsigned begin = VarDesc.StartOffset;
  1264. unsigned end = begin + VarDesc.Size;
  1265. auto beginIt = std::find_if(usage.begin(), usage.end(),
  1266. [&](unsigned v) { return v >= begin; });
  1267. auto endIt = std::find_if(usage.begin(), usage.end(),
  1268. [&](unsigned v) { return v >= end; });
  1269. bool used = beginIt != endIt;
  1270. // Clear used.
  1271. if (!used) {
  1272. CShaderReflectionType *pVarType = (CShaderReflectionType *)pVar->GetType();
  1273. BYTE *pDefaultValue = nullptr;
  1274. VarDesc.uFlags &= ~D3D_SVF_USED;
  1275. CShaderReflectionVariable *pCVarDesc = (CShaderReflectionVariable*)pVar;
  1276. pCVarDesc->Initialize(&cb, &VarDesc, pVarType, pDefaultValue);
  1277. }
  1278. }
  1279. }
  1280. void DxilShaderReflection::SetCBufferUsage() {
  1281. hlsl::OP *hlslOP = m_pDxilModule->GetOP();
  1282. LLVMContext &Ctx = m_pDxilModule->GetCtx();
  1283. // Indexes >= cbuffer size from DxilModule are SRV or UAV structured buffers.
  1284. // We only collect usage for actual cbuffers, so don't go clearing usage on other buffers.
  1285. unsigned cbSize = std::min(m_CBs.size(), m_pDxilModule->GetCBuffers().size());
  1286. std::vector< std::vector<unsigned> > cbufUsage(cbSize);
  1287. Function *createHandle = hlslOP->GetOpFunc(DXIL::OpCode::CreateHandle, Type::getVoidTy(Ctx));
  1288. if (createHandle->user_empty()) {
  1289. createHandle->eraseFromParent();
  1290. return;
  1291. }
  1292. // Find all cb handles.
  1293. for (User *U : createHandle->users()) {
  1294. DxilInst_CreateHandle handle(cast<CallInst>(U));
  1295. Value *resClass = handle.get_resourceClass();
  1296. ConstantInt *immResClass = cast<ConstantInt>(resClass);
  1297. if (immResClass->getLimitedValue() == (unsigned)DXIL::ResourceClass::CBuffer) {
  1298. ConstantInt *cbID = cast<ConstantInt>(handle.get_rangeId());
  1299. CollectCBufUsage(U, cbufUsage[cbID->getLimitedValue()]);
  1300. }
  1301. }
  1302. for (unsigned i=0;i<cbSize;i++) {
  1303. SetCBufVarUsage(*m_CBs[i], cbufUsage[i]);
  1304. }
  1305. }
  1306. void DxilModuleReflection::CreateReflectionObjects() {
  1307. DXASSERT_NOMSG(m_pDxilModule != nullptr);
  1308. // Create constant buffers, resources and signatures.
  1309. for (auto && cb : m_pDxilModule->GetCBuffers()) {
  1310. std::unique_ptr<CShaderReflectionConstantBuffer> rcb(new CShaderReflectionConstantBuffer());
  1311. rcb->Initialize(*m_pDxilModule, *(cb.get()), m_Types);
  1312. m_CBs.emplace_back(std::move(rcb));
  1313. }
  1314. // TODO: add tbuffers into m_CBs
  1315. for (auto && uav : m_pDxilModule->GetUAVs()) {
  1316. if (uav->GetKind() != DxilResource::Kind::StructuredBuffer) {
  1317. continue;
  1318. }
  1319. std::unique_ptr<CShaderReflectionConstantBuffer> rcb(new CShaderReflectionConstantBuffer());
  1320. rcb->InitializeStructuredBuffer(*m_pDxilModule, *(uav.get()), m_Types);
  1321. m_CBs.emplace_back(std::move(rcb));
  1322. }
  1323. for (auto && srv : m_pDxilModule->GetSRVs()) {
  1324. if (srv->GetKind() != DxilResource::Kind::StructuredBuffer) {
  1325. continue;
  1326. }
  1327. std::unique_ptr<CShaderReflectionConstantBuffer> rcb(new CShaderReflectionConstantBuffer());
  1328. rcb->InitializeStructuredBuffer(*m_pDxilModule, *(srv.get()), m_Types);
  1329. m_CBs.emplace_back(std::move(rcb));
  1330. }
  1331. // Populate all resources.
  1332. for (auto && cbRes : m_pDxilModule->GetCBuffers()) {
  1333. CreateReflectionObjectForResource(cbRes.get());
  1334. }
  1335. for (auto && samplerRes : m_pDxilModule->GetSamplers()) {
  1336. CreateReflectionObjectForResource(samplerRes.get());
  1337. }
  1338. for (auto && srvRes : m_pDxilModule->GetSRVs()) {
  1339. CreateReflectionObjectForResource(srvRes.get());
  1340. }
  1341. for (auto && uavRes : m_pDxilModule->GetUAVs()) {
  1342. CreateReflectionObjectForResource(uavRes.get());
  1343. }
  1344. }
  1345. static D3D_REGISTER_COMPONENT_TYPE CompTypeToRegisterComponentType(CompType CT) {
  1346. switch (CT.GetKind()) {
  1347. case DXIL::ComponentType::F16:
  1348. case DXIL::ComponentType::F32:
  1349. return D3D_REGISTER_COMPONENT_FLOAT32;
  1350. case DXIL::ComponentType::I1:
  1351. case DXIL::ComponentType::U16:
  1352. case DXIL::ComponentType::U32:
  1353. return D3D_REGISTER_COMPONENT_UINT32;
  1354. case DXIL::ComponentType::I16:
  1355. case DXIL::ComponentType::I32:
  1356. return D3D_REGISTER_COMPONENT_SINT32;
  1357. default:
  1358. return D3D_REGISTER_COMPONENT_UNKNOWN;
  1359. }
  1360. }
  1361. static D3D_MIN_PRECISION CompTypeToMinPrecision(CompType CT) {
  1362. switch (CT.GetKind()) {
  1363. case DXIL::ComponentType::F16:
  1364. return D3D_MIN_PRECISION_FLOAT_16;
  1365. case DXIL::ComponentType::I16:
  1366. return D3D_MIN_PRECISION_SINT_16;
  1367. case DXIL::ComponentType::U16:
  1368. return D3D_MIN_PRECISION_UINT_16;
  1369. default:
  1370. return D3D_MIN_PRECISION_DEFAULT;
  1371. }
  1372. }
  1373. D3D_NAME SemanticToSystemValueType(const Semantic *S, DXIL::TessellatorDomain domain) {
  1374. switch (S->GetKind()) {
  1375. case Semantic::Kind::ClipDistance:
  1376. return D3D_NAME_CLIP_DISTANCE;
  1377. case Semantic::Kind::Arbitrary:
  1378. return D3D_NAME_UNDEFINED;
  1379. case Semantic::Kind::VertexID:
  1380. return D3D_NAME_VERTEX_ID;
  1381. case Semantic::Kind::InstanceID:
  1382. return D3D_NAME_INSTANCE_ID;
  1383. case Semantic::Kind::Position:
  1384. return D3D_NAME_POSITION;
  1385. case Semantic::Kind::Coverage:
  1386. return D3D_NAME_COVERAGE;
  1387. case Semantic::Kind::InnerCoverage:
  1388. return D3D_NAME_INNER_COVERAGE;
  1389. case Semantic::Kind::PrimitiveID:
  1390. return D3D_NAME_PRIMITIVE_ID;
  1391. case Semantic::Kind::SampleIndex:
  1392. return D3D_NAME_SAMPLE_INDEX;
  1393. case Semantic::Kind::IsFrontFace:
  1394. return D3D_NAME_IS_FRONT_FACE;
  1395. case Semantic::Kind::RenderTargetArrayIndex:
  1396. return D3D_NAME_RENDER_TARGET_ARRAY_INDEX;
  1397. case Semantic::Kind::ViewPortArrayIndex:
  1398. return D3D_NAME_VIEWPORT_ARRAY_INDEX;
  1399. case Semantic::Kind::CullDistance:
  1400. return D3D_NAME_CULL_DISTANCE;
  1401. case Semantic::Kind::Target:
  1402. return D3D_NAME_TARGET;
  1403. case Semantic::Kind::Depth:
  1404. return D3D_NAME_DEPTH;
  1405. case Semantic::Kind::DepthLessEqual:
  1406. return D3D_NAME_DEPTH_LESS_EQUAL;
  1407. case Semantic::Kind::DepthGreaterEqual:
  1408. return D3D_NAME_DEPTH_GREATER_EQUAL;
  1409. case Semantic::Kind::StencilRef:
  1410. return D3D_NAME_STENCIL_REF;
  1411. case Semantic::Kind::TessFactor: {
  1412. switch (domain) {
  1413. case DXIL::TessellatorDomain::IsoLine:
  1414. return D3D_NAME_FINAL_LINE_DETAIL_TESSFACTOR;
  1415. case DXIL::TessellatorDomain::Tri:
  1416. return D3D_NAME_FINAL_TRI_EDGE_TESSFACTOR;
  1417. case DXIL::TessellatorDomain::Quad:
  1418. return D3D_NAME_FINAL_QUAD_EDGE_TESSFACTOR;
  1419. default:
  1420. return D3D_NAME_UNDEFINED;
  1421. }
  1422. }
  1423. case Semantic::Kind::InsideTessFactor:
  1424. switch (domain) {
  1425. case DXIL::TessellatorDomain::Tri:
  1426. return D3D_NAME_FINAL_TRI_INSIDE_TESSFACTOR;
  1427. case DXIL::TessellatorDomain::Quad:
  1428. return D3D_NAME_FINAL_QUAD_INSIDE_TESSFACTOR;
  1429. default:
  1430. return D3D_NAME_UNDEFINED;
  1431. }
  1432. case Semantic::Kind::DispatchThreadID:
  1433. case Semantic::Kind::GroupID:
  1434. case Semantic::Kind::GroupIndex:
  1435. case Semantic::Kind::GroupThreadID:
  1436. case Semantic::Kind::DomainLocation:
  1437. case Semantic::Kind::OutputControlPointID:
  1438. case Semantic::Kind::GSInstanceID:
  1439. case Semantic::Kind::Invalid:
  1440. default:
  1441. return D3D_NAME_UNDEFINED;
  1442. }
  1443. }
  1444. static uint8_t NegMask(uint8_t V) {
  1445. V ^= 0xF;
  1446. return V & 0xF;
  1447. }
  1448. void DxilShaderReflection::CreateReflectionObjectsForSignature(
  1449. const DxilSignature &Sig,
  1450. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> &Descs) {
  1451. bool clipDistanceSeen = false;
  1452. for (auto && SigElem : Sig.GetElements()) {
  1453. D3D12_SIGNATURE_PARAMETER_DESC Desc;
  1454. // TODO: why do we have multiple SV_ClipDistance elements?
  1455. if (SigElem->GetSemantic()->GetKind() == DXIL::SemanticKind::ClipDistance) {
  1456. if (clipDistanceSeen) continue;
  1457. clipDistanceSeen = true;
  1458. }
  1459. Desc.ComponentType = CompTypeToRegisterComponentType(SigElem->GetCompType());
  1460. Desc.Mask = SigElem->GetColsAsMask();
  1461. // D3D11_43 does not have MinPrecison.
  1462. if (m_PublicAPI != PublicAPI::D3D11_43)
  1463. Desc.MinPrecision = CompTypeToMinPrecision(SigElem->GetCompType());
  1464. Desc.ReadWriteMask = Sig.IsInput() ? 0 : Desc.Mask; // Start with output-never-written/input-never-read.
  1465. Desc.Register = SigElem->GetStartRow();
  1466. Desc.Stream = SigElem->GetOutputStream();
  1467. Desc.SystemValueType = SemanticToSystemValueType(SigElem->GetSemantic(), m_pDxilModule->GetTessellatorDomain());
  1468. Desc.SemanticName = SigElem->GetName();
  1469. if (!SigElem->GetSemantic()->IsArbitrary())
  1470. Desc.SemanticName = CreateUpperCase(Desc.SemanticName);
  1471. const std::vector<unsigned> &indexVec = SigElem->GetSemanticIndexVec();
  1472. for (unsigned semIdx = 0; semIdx < indexVec.size(); ++semIdx) {
  1473. Desc.SemanticIndex = indexVec[semIdx];
  1474. if (Desc.SystemValueType == D3D_NAME_FINAL_LINE_DETAIL_TESSFACTOR &&
  1475. Desc.SemanticIndex == 1)
  1476. Desc.SystemValueType = D3D_NAME_FINAL_LINE_DETAIL_TESSFACTOR;
  1477. Descs.push_back(Desc);
  1478. }
  1479. }
  1480. }
  1481. LPCSTR DxilShaderReflection::CreateUpperCase(LPCSTR pValue) {
  1482. // Restricted only to [a-z] ASCII.
  1483. LPCSTR pCursor = pValue;
  1484. while (*pCursor != '\0') {
  1485. if ('a' <= *pCursor && *pCursor <= 'z') {
  1486. break;
  1487. }
  1488. ++pCursor;
  1489. }
  1490. if (*pCursor == '\0')
  1491. return pValue;
  1492. std::unique_ptr<char[]> pUpperStr = std::make_unique<char[]>(strlen(pValue) + 1);
  1493. char *pWrite = pUpperStr.get();
  1494. pCursor = pValue;
  1495. for (;;) {
  1496. *pWrite = *pCursor;
  1497. if ('a' <= *pWrite && *pWrite <= 'z') {
  1498. *pWrite += ('A' - 'a');
  1499. }
  1500. if (*pWrite == '\0') break;
  1501. ++pWrite;
  1502. ++pCursor;
  1503. }
  1504. m_UpperCaseNames.push_back(std::move(pUpperStr));
  1505. return m_UpperCaseNames.back().get();
  1506. }
  1507. HRESULT DxilModuleReflection::LoadModule(IDxcBlob *pBlob,
  1508. const DxilPartHeader *pPart) {
  1509. DXASSERT_NOMSG(pBlob != nullptr);
  1510. DXASSERT_NOMSG(pPart != nullptr);
  1511. m_pContainer = pBlob;
  1512. const char *pData = GetDxilPartData(pPart);
  1513. try {
  1514. const char *pBitcode;
  1515. uint32_t bitcodeLength;
  1516. GetDxilProgramBitcode((DxilProgramHeader *)pData, &pBitcode, &bitcodeLength);
  1517. std::unique_ptr<MemoryBuffer> pMemBuffer =
  1518. MemoryBuffer::getMemBufferCopy(StringRef(pBitcode, bitcodeLength));
  1519. #if 0 // We materialize eagerly, because we'll need to walk instructions to look for usage information.
  1520. ErrorOr<std::unique_ptr<Module>> module =
  1521. getLazyBitcodeModule(std::move(pMemBuffer), Context);
  1522. #else
  1523. ErrorOr<std::unique_ptr<Module>> module =
  1524. parseBitcodeFile(pMemBuffer->getMemBufferRef(), Context, nullptr);
  1525. #endif
  1526. if (!module) {
  1527. return E_INVALIDARG;
  1528. }
  1529. std::swap(m_pModule, module.get());
  1530. m_pDxilModule = &m_pModule->GetOrCreateDxilModule();
  1531. CreateReflectionObjects();
  1532. return S_OK;
  1533. }
  1534. CATCH_CPP_RETURN_HRESULT();
  1535. };
  1536. HRESULT DxilShaderReflection::Load(IDxcBlob *pBlob,
  1537. const DxilPartHeader *pPart) {
  1538. IFR(LoadModule(pBlob, pPart));
  1539. try {
  1540. // Set cbuf usage.
  1541. SetCBufferUsage();
  1542. // Populate input/output/patch constant signatures.
  1543. CreateReflectionObjectsForSignature(m_pDxilModule->GetInputSignature(), m_InputSignature);
  1544. CreateReflectionObjectsForSignature(m_pDxilModule->GetOutputSignature(), m_OutputSignature);
  1545. CreateReflectionObjectsForSignature(m_pDxilModule->GetPatchConstantSignature(), m_PatchConstantSignature);
  1546. MarkUsedSignatureElements();
  1547. return S_OK;
  1548. }
  1549. CATCH_CPP_RETURN_HRESULT();
  1550. }
  1551. _Use_decl_annotations_
  1552. HRESULT DxilShaderReflection::GetDesc(D3D12_SHADER_DESC *pDesc) {
  1553. IFR(ZeroMemoryToOut(pDesc));
  1554. const DxilModule &M = *m_pDxilModule;
  1555. const ShaderModel *pSM = M.GetShaderModel();
  1556. pDesc->Version = EncodeVersion(pSM->GetKind(), pSM->GetMajor(), pSM->GetMinor());
  1557. // Unset: LPCSTR Creator; // Creator string
  1558. // Unset: UINT Flags; // Shader compilation/parse flags
  1559. pDesc->ConstantBuffers = m_CBs.size();
  1560. pDesc->BoundResources = m_Resources.size();
  1561. pDesc->InputParameters = m_InputSignature.size();
  1562. pDesc->OutputParameters = m_OutputSignature.size();
  1563. pDesc->PatchConstantParameters = m_PatchConstantSignature.size();
  1564. // Unset: UINT InstructionCount; // Number of emitted instructions
  1565. // Unset: UINT TempRegisterCount; // Number of temporary registers used
  1566. // Unset: UINT TempArrayCount; // Number of temporary arrays used
  1567. // Unset: UINT DefCount; // Number of constant defines
  1568. // Unset: UINT DclCount; // Number of declarations (input + output)
  1569. // Unset: UINT TextureNormalInstructions; // Number of non-categorized texture instructions
  1570. // Unset: UINT TextureLoadInstructions; // Number of texture load instructions
  1571. // Unset: UINT TextureCompInstructions; // Number of texture comparison instructions
  1572. // Unset: UINT TextureBiasInstructions; // Number of texture bias instructions
  1573. // Unset: UINT TextureGradientInstructions; // Number of texture gradient instructions
  1574. // Unset: UINT FloatInstructionCount; // Number of floating point arithmetic instructions used
  1575. // Unset: UINT IntInstructionCount; // Number of signed integer arithmetic instructions used
  1576. // Unset: UINT UintInstructionCount; // Number of unsigned integer arithmetic instructions used
  1577. // Unset: UINT StaticFlowControlCount; // Number of static flow control instructions used
  1578. // Unset: UINT DynamicFlowControlCount; // Number of dynamic flow control instructions used
  1579. // Unset: UINT MacroInstructionCount; // Number of macro instructions used
  1580. // Unset: UINT ArrayInstructionCount; // Number of array instructions used
  1581. // Unset: UINT CutInstructionCount; // Number of cut instructions used
  1582. // Unset: UINT EmitInstructionCount; // Number of emit instructions used
  1583. // Unset: D3D_PRIMITIVE_TOPOLOGY GSOutputTopology; // Geometry shader output topology
  1584. // Unset: UINT GSMaxOutputVertexCount; // Geometry shader maximum output vertex count
  1585. // Unset: D3D_PRIMITIVE InputPrimitive; // GS/HS input primitive
  1586. // Unset: UINT cGSInstanceCount; // Number of Geometry shader instances
  1587. // Unset: UINT cControlPoints; // Number of control points in the HS->DS stage
  1588. // Unset: D3D_TESSELLATOR_OUTPUT_PRIMITIVE HSOutputPrimitive; // Primitive output by the tessellator
  1589. // Unset: D3D_TESSELLATOR_PARTITIONING HSPartitioning; // Partitioning mode of the tessellator
  1590. // Unset: D3D_TESSELLATOR_DOMAIN TessellatorDomain; // Domain of the tessellator (quad, tri, isoline)
  1591. // instruction counts
  1592. // Unset: UINT cBarrierInstructions; // Number of barrier instructions in a compute shader
  1593. // Unset: UINT cInterlockedInstructions; // Number of interlocked instructions
  1594. // Unset: UINT cTextureStoreInstructions; // Number of texture writes
  1595. return S_OK;
  1596. }
  1597. static bool GetUnsignedVal(Value *V, uint32_t *pValue) {
  1598. ConstantInt *CI = dyn_cast<ConstantInt>(V);
  1599. if (!CI) return false;
  1600. uint64_t u = CI->getZExtValue();
  1601. if (u > UINT32_MAX) return false;
  1602. *pValue = (uint32_t)u;
  1603. return true;
  1604. }
  1605. void DxilShaderReflection::MarkUsedSignatureElements() {
  1606. Function *F = m_pDxilModule->GetEntryFunction();
  1607. DXASSERT(F != nullptr, "else module load should have failed");
  1608. // For every loadInput/storeOutput, update the corresponding ReadWriteMask.
  1609. // F is a pointer to a Function instance
  1610. unsigned elementCount = m_InputSignature.size() + m_OutputSignature.size() +
  1611. m_PatchConstantSignature.size();
  1612. unsigned markedElementCount = 0;
  1613. for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
  1614. DxilInst_LoadInput LI(&*I);
  1615. DxilInst_StoreOutput SO(&*I);
  1616. DxilInst_LoadPatchConstant LPC(&*I);
  1617. DxilInst_StorePatchConstant SPC(&*I);
  1618. std::vector<D3D12_SIGNATURE_PARAMETER_DESC> *pDescs;
  1619. const DxilSignature *pSig;
  1620. uint32_t col, row, sigId;
  1621. if (LI) {
  1622. if (!GetUnsignedVal(LI.get_inputSigId(), &sigId)) continue;
  1623. if (!GetUnsignedVal(LI.get_colIndex(), &col)) continue;
  1624. if (!GetUnsignedVal(LI.get_rowIndex(), &row)) continue;
  1625. pDescs = &m_InputSignature;
  1626. pSig = &m_pDxilModule->GetInputSignature();
  1627. }
  1628. else if (SO) {
  1629. if (!GetUnsignedVal(SO.get_outputSigId(), &sigId)) continue;
  1630. if (!GetUnsignedVal(SO.get_colIndex(), &col)) continue;
  1631. if (!GetUnsignedVal(SO.get_rowIndex(), &row)) continue;
  1632. pDescs = &m_OutputSignature;
  1633. pSig = &m_pDxilModule->GetOutputSignature();
  1634. }
  1635. else if (SPC) {
  1636. if (!GetUnsignedVal(SPC.get_outputSigID(), &sigId)) continue;
  1637. if (!GetUnsignedVal(SPC.get_col(), &col)) continue;
  1638. if (!GetUnsignedVal(SPC.get_row(), &row)) continue;
  1639. pDescs = &m_PatchConstantSignature;
  1640. pSig = &m_pDxilModule->GetPatchConstantSignature();
  1641. }
  1642. else if (LPC) {
  1643. if (!GetUnsignedVal(LPC.get_inputSigId(), &sigId)) continue;
  1644. if (!GetUnsignedVal(LPC.get_col(), &col)) continue;
  1645. if (!GetUnsignedVal(LPC.get_row(), &row)) continue;
  1646. pDescs = &m_PatchConstantSignature;
  1647. pSig = &m_pDxilModule->GetPatchConstantSignature();
  1648. }
  1649. else {
  1650. continue;
  1651. }
  1652. if (sigId >= pDescs->size()) continue;
  1653. D3D12_SIGNATURE_PARAMETER_DESC *pDesc = &(*pDescs)[sigId];
  1654. // Consider being more fine-grained about masks.
  1655. // We report sometimes-read on input as always-read.
  1656. unsigned UsedMask = pSig->IsInput() ? pDesc->Mask : NegMask(pDesc->Mask);
  1657. if (pDesc->ReadWriteMask == UsedMask)
  1658. continue;
  1659. pDesc->ReadWriteMask = UsedMask;
  1660. ++markedElementCount;
  1661. if (markedElementCount == elementCount)
  1662. return;
  1663. }
  1664. }
  1665. _Use_decl_annotations_
  1666. ID3D12ShaderReflectionConstantBuffer* DxilShaderReflection::GetConstantBufferByIndex(UINT Index) {
  1667. return DxilModuleReflection::_GetConstantBufferByIndex(Index);
  1668. }
  1669. ID3D12ShaderReflectionConstantBuffer* DxilModuleReflection::_GetConstantBufferByIndex(UINT Index) {
  1670. if (Index >= m_CBs.size()) {
  1671. return &g_InvalidSRConstantBuffer;
  1672. }
  1673. return m_CBs[Index].get();
  1674. }
  1675. _Use_decl_annotations_
  1676. ID3D12ShaderReflectionConstantBuffer* DxilShaderReflection::GetConstantBufferByName(LPCSTR Name) {
  1677. return DxilModuleReflection::_GetConstantBufferByName(Name);
  1678. }
  1679. ID3D12ShaderReflectionConstantBuffer* DxilModuleReflection::_GetConstantBufferByName(LPCSTR Name) {
  1680. if (!Name) {
  1681. return &g_InvalidSRConstantBuffer;
  1682. }
  1683. for (UINT index = 0; index < m_CBs.size(); ++index) {
  1684. if (0 == strcmp(m_CBs[index]->GetName(), Name)) {
  1685. return m_CBs[index].get();
  1686. }
  1687. }
  1688. return &g_InvalidSRConstantBuffer;
  1689. }
  1690. _Use_decl_annotations_
  1691. HRESULT DxilShaderReflection::GetResourceBindingDesc(UINT ResourceIndex,
  1692. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc) {
  1693. return DxilModuleReflection::_GetResourceBindingDesc(ResourceIndex, pDesc, m_PublicAPI);
  1694. }
  1695. HRESULT DxilModuleReflection::_GetResourceBindingDesc(UINT ResourceIndex,
  1696. _Out_ D3D12_SHADER_INPUT_BIND_DESC *pDesc, PublicAPI api) {
  1697. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1698. IFRBOOL(ResourceIndex < m_Resources.size(), E_INVALIDARG);
  1699. if (api != PublicAPI::D3D12) {
  1700. memcpy(pDesc, &m_Resources[ResourceIndex], sizeof(D3D11_SHADER_INPUT_BIND_DESC));
  1701. }
  1702. else {
  1703. *pDesc = m_Resources[ResourceIndex];
  1704. }
  1705. return S_OK;
  1706. }
  1707. _Use_decl_annotations_
  1708. HRESULT DxilShaderReflection::GetInputParameterDesc(UINT ParameterIndex,
  1709. _Out_ D3D12_SIGNATURE_PARAMETER_DESC *pDesc) {
  1710. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1711. IFRBOOL(ParameterIndex < m_InputSignature.size(), E_INVALIDARG);
  1712. if (m_PublicAPI != PublicAPI::D3D11_43)
  1713. *pDesc = m_InputSignature[ParameterIndex];
  1714. else
  1715. memcpy(pDesc, &m_InputSignature[ParameterIndex],
  1716. // D3D11_43 does not have MinPrecison.
  1717. sizeof(D3D12_SIGNATURE_PARAMETER_DESC) - sizeof(D3D_MIN_PRECISION));
  1718. return S_OK;
  1719. }
  1720. _Use_decl_annotations_
  1721. HRESULT DxilShaderReflection::GetOutputParameterDesc(UINT ParameterIndex,
  1722. D3D12_SIGNATURE_PARAMETER_DESC *pDesc) {
  1723. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1724. IFRBOOL(ParameterIndex < m_OutputSignature.size(), E_INVALIDARG);
  1725. if (m_PublicAPI != PublicAPI::D3D11_43)
  1726. *pDesc = m_OutputSignature[ParameterIndex];
  1727. else
  1728. memcpy(pDesc, &m_OutputSignature[ParameterIndex],
  1729. // D3D11_43 does not have MinPrecison.
  1730. sizeof(D3D12_SIGNATURE_PARAMETER_DESC) - sizeof(D3D_MIN_PRECISION));
  1731. return S_OK;
  1732. }
  1733. _Use_decl_annotations_
  1734. HRESULT DxilShaderReflection::GetPatchConstantParameterDesc(UINT ParameterIndex,
  1735. D3D12_SIGNATURE_PARAMETER_DESC *pDesc) {
  1736. IFRBOOL(pDesc != nullptr, E_INVALIDARG);
  1737. IFRBOOL(ParameterIndex < m_PatchConstantSignature.size(), E_INVALIDARG);
  1738. if (m_PublicAPI != PublicAPI::D3D11_43)
  1739. *pDesc = m_PatchConstantSignature[ParameterIndex];
  1740. else
  1741. memcpy(pDesc, &m_PatchConstantSignature[ParameterIndex],
  1742. // D3D11_43 does not have MinPrecison.
  1743. sizeof(D3D12_SIGNATURE_PARAMETER_DESC) - sizeof(D3D_MIN_PRECISION));
  1744. return S_OK;
  1745. }
  1746. _Use_decl_annotations_
  1747. ID3D12ShaderReflectionVariable* DxilShaderReflection::GetVariableByName(LPCSTR Name) {
  1748. return DxilModuleReflection::_GetVariableByName(Name);
  1749. }
  1750. ID3D12ShaderReflectionVariable* DxilModuleReflection::_GetVariableByName(LPCSTR Name) {
  1751. if (Name != nullptr) {
  1752. // Iterate through all cbuffers to find the variable.
  1753. for (UINT i = 0; i < m_CBs.size(); i++) {
  1754. ID3D12ShaderReflectionVariable *pVar = m_CBs[i]->GetVariableByName(Name);
  1755. if (pVar != &g_InvalidSRVariable) {
  1756. return pVar;
  1757. }
  1758. }
  1759. }
  1760. return &g_InvalidSRVariable;
  1761. }
  1762. _Use_decl_annotations_
  1763. HRESULT DxilShaderReflection::GetResourceBindingDescByName(LPCSTR Name,
  1764. D3D12_SHADER_INPUT_BIND_DESC *pDesc) {
  1765. return DxilModuleReflection::_GetResourceBindingDescByName(Name, pDesc, m_PublicAPI);
  1766. }
  1767. HRESULT DxilModuleReflection::_GetResourceBindingDescByName(LPCSTR Name,
  1768. D3D12_SHADER_INPUT_BIND_DESC *pDesc, PublicAPI api) {
  1769. IFRBOOL(Name != nullptr, E_INVALIDARG);
  1770. for (UINT i = 0; i < m_Resources.size(); i++) {
  1771. if (strcmp(m_Resources[i].Name, Name) == 0) {
  1772. if (api != PublicAPI::D3D12) {
  1773. memcpy(pDesc, &m_Resources[i], sizeof(D3D11_SHADER_INPUT_BIND_DESC));
  1774. }
  1775. else {
  1776. *pDesc = m_Resources[i];
  1777. }
  1778. return S_OK;
  1779. }
  1780. }
  1781. return HRESULT_FROM_WIN32(ERROR_NOT_FOUND);
  1782. }
  1783. UINT DxilShaderReflection::GetMovInstructionCount() { return 0; }
  1784. UINT DxilShaderReflection::GetMovcInstructionCount() { return 0; }
  1785. UINT DxilShaderReflection::GetConversionInstructionCount() { return 0; }
  1786. UINT DxilShaderReflection::GetBitwiseInstructionCount() { return 0; }
  1787. D3D_PRIMITIVE DxilShaderReflection::GetGSInputPrimitive() {
  1788. if (!m_pDxilModule->GetShaderModel()->IsGS())
  1789. return D3D_PRIMITIVE::D3D10_PRIMITIVE_UNDEFINED;
  1790. return (D3D_PRIMITIVE)m_pDxilModule->GetInputPrimitive();
  1791. }
  1792. BOOL DxilShaderReflection::IsSampleFrequencyShader() {
  1793. // TODO: determine correct value
  1794. return FALSE;
  1795. }
  1796. UINT DxilShaderReflection::GetNumInterfaceSlots() { return 0; }
  1797. _Use_decl_annotations_
  1798. HRESULT DxilShaderReflection::GetMinFeatureLevel(enum D3D_FEATURE_LEVEL* pLevel) {
  1799. IFR(AssignToOut(D3D_FEATURE_LEVEL_12_0, pLevel));
  1800. return S_OK;
  1801. }
  1802. _Use_decl_annotations_
  1803. UINT DxilShaderReflection::GetThreadGroupSize(UINT *pSizeX, UINT *pSizeY, UINT *pSizeZ) {
  1804. if (!m_pDxilModule->GetShaderModel()->IsCS()) {
  1805. AssignToOutOpt((UINT)0, pSizeX);
  1806. AssignToOutOpt((UINT)0, pSizeY);
  1807. AssignToOutOpt((UINT)0, pSizeZ);
  1808. return 0;
  1809. }
  1810. unsigned x = m_pDxilModule->GetNumThreads(0);
  1811. unsigned y = m_pDxilModule->GetNumThreads(1);
  1812. unsigned z = m_pDxilModule->GetNumThreads(2);
  1813. AssignToOutOpt(x, pSizeX);
  1814. AssignToOutOpt(y, pSizeY);
  1815. AssignToOutOpt(z, pSizeZ);
  1816. return x * y * z;
  1817. }
  1818. UINT64 DxilShaderReflection::GetRequiresFlags() {
  1819. UINT64 result = 0;
  1820. uint64_t features = m_pDxilModule->m_ShaderFlags.GetFeatureInfo();
  1821. if (features & ShaderFeatureInfo_Doubles) result |= D3D_SHADER_REQUIRES_DOUBLES;
  1822. if (features & ShaderFeatureInfo_UAVsAtEveryStage) result |= D3D_SHADER_REQUIRES_UAVS_AT_EVERY_STAGE;
  1823. if (features & ShaderFeatureInfo_64UAVs) result |= D3D_SHADER_REQUIRES_64_UAVS;
  1824. if (features & ShaderFeatureInfo_MinimumPrecision) result |= D3D_SHADER_REQUIRES_MINIMUM_PRECISION;
  1825. if (features & ShaderFeatureInfo_11_1_DoubleExtensions) result |= D3D_SHADER_REQUIRES_11_1_DOUBLE_EXTENSIONS;
  1826. if (features & ShaderFeatureInfo_11_1_ShaderExtensions) result |= D3D_SHADER_REQUIRES_11_1_SHADER_EXTENSIONS;
  1827. if (features & ShaderFeatureInfo_LEVEL9ComparisonFiltering) result |= D3D_SHADER_REQUIRES_LEVEL_9_COMPARISON_FILTERING;
  1828. if (features & ShaderFeatureInfo_TiledResources) result |= D3D_SHADER_REQUIRES_TILED_RESOURCES;
  1829. if (features & ShaderFeatureInfo_StencilRef) result |= D3D_SHADER_REQUIRES_STENCIL_REF;
  1830. if (features & ShaderFeatureInfo_InnerCoverage) result |= D3D_SHADER_REQUIRES_INNER_COVERAGE;
  1831. if (features & ShaderFeatureInfo_TypedUAVLoadAdditionalFormats) result |= D3D_SHADER_REQUIRES_TYPED_UAV_LOAD_ADDITIONAL_FORMATS;
  1832. if (features & ShaderFeatureInfo_ROVs) result |= D3D_SHADER_REQUIRES_ROVS;
  1833. if (features & ShaderFeatureInfo_ViewportAndRTArrayIndexFromAnyShaderFeedingRasterizer) result |= D3D_SHADER_REQUIRES_VIEWPORT_AND_RT_ARRAY_INDEX_FROM_ANY_SHADER_FEEDING_RASTERIZER;
  1834. return result;
  1835. }
  1836. // ID3D12FunctionReflection
  1837. class CFunctionReflection : public ID3D12FunctionReflection {
  1838. protected:
  1839. DxilLibraryReflection * m_pLibraryReflection = nullptr;
  1840. const Function *m_pFunction;
  1841. const DxilFunctionProps *m_pProps; // nullptr if non-shader library function or patch constant function
  1842. std::string m_Name;
  1843. typedef SmallSetVector<UINT32, 8> ResourceUseSet;
  1844. ResourceUseSet m_UsedResources;
  1845. ResourceUseSet m_UsedCBs;
  1846. public:
  1847. void Initialize(DxilLibraryReflection* pLibraryReflection, Function *pFunction) {
  1848. DXASSERT_NOMSG(pLibraryReflection);
  1849. DXASSERT_NOMSG(pFunction);
  1850. m_pLibraryReflection = pLibraryReflection;
  1851. m_pFunction = pFunction;
  1852. const DxilModule &M = *m_pLibraryReflection->m_pDxilModule;
  1853. m_Name = m_pFunction->getName().str();
  1854. m_pProps = nullptr;
  1855. if (M.HasDxilFunctionProps(m_pFunction)) {
  1856. m_pProps = &M.GetDxilFunctionProps(m_pFunction);
  1857. }
  1858. }
  1859. void AddResourceReference(UINT resIndex) {
  1860. m_UsedResources.insert(resIndex);
  1861. }
  1862. void AddCBReference(UINT cbIndex) {
  1863. m_UsedCBs.insert(cbIndex);
  1864. }
  1865. // ID3D12FunctionReflection
  1866. STDMETHOD(GetDesc)(THIS_ _Out_ D3D12_FUNCTION_DESC * pDesc);
  1867. // BufferIndex relative to used constant buffers here
  1868. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByIndex)(THIS_ _In_ UINT BufferIndex);
  1869. STDMETHOD_(ID3D12ShaderReflectionConstantBuffer *, GetConstantBufferByName)(THIS_ _In_ LPCSTR Name);
  1870. STDMETHOD(GetResourceBindingDesc)(THIS_ _In_ UINT ResourceIndex,
  1871. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc);
  1872. STDMETHOD_(ID3D12ShaderReflectionVariable *, GetVariableByName)(THIS_ _In_ LPCSTR Name);
  1873. STDMETHOD(GetResourceBindingDescByName)(THIS_ _In_ LPCSTR Name,
  1874. _Out_ D3D12_SHADER_INPUT_BIND_DESC * pDesc);
  1875. // Use D3D_RETURN_PARAMETER_INDEX to get description of the return value.
  1876. STDMETHOD_(ID3D12FunctionParameterReflection *, GetFunctionParameter)(THIS_ _In_ INT ParameterIndex) {
  1877. return &g_InvalidFunctionParameter;
  1878. }
  1879. };
  1880. _Use_decl_annotations_
  1881. HRESULT CFunctionReflection::GetDesc(D3D12_FUNCTION_DESC *pDesc) {
  1882. DXASSERT_NOMSG(m_pLibraryReflection);
  1883. IFR(ZeroMemoryToOut(pDesc));
  1884. const ShaderModel* pSM = m_pLibraryReflection->m_pDxilModule->GetShaderModel();
  1885. DXIL::ShaderKind kind = DXIL::ShaderKind::Library;
  1886. if (m_pProps) {
  1887. kind = m_pProps->shaderKind;
  1888. }
  1889. pDesc->Version = EncodeVersion(kind, pSM->GetMajor(), pSM->GetMinor());
  1890. //Unset: LPCSTR Creator; // Creator string
  1891. //Unset: UINT Flags; // Shader compilation/parse flags
  1892. pDesc->ConstantBuffers = (UINT)m_UsedCBs.size();
  1893. pDesc->BoundResources = (UINT)m_UsedResources.size();
  1894. //Unset: UINT InstructionCount; // Number of emitted instructions
  1895. //Unset: UINT TempRegisterCount; // Number of temporary registers used
  1896. //Unset: UINT TempArrayCount; // Number of temporary arrays used
  1897. //Unset: UINT DefCount; // Number of constant defines
  1898. //Unset: UINT DclCount; // Number of declarations (input + output)
  1899. //Unset: UINT TextureNormalInstructions; // Number of non-categorized texture instructions
  1900. //Unset: UINT TextureLoadInstructions; // Number of texture load instructions
  1901. //Unset: UINT TextureCompInstructions; // Number of texture comparison instructions
  1902. //Unset: UINT TextureBiasInstructions; // Number of texture bias instructions
  1903. //Unset: UINT TextureGradientInstructions; // Number of texture gradient instructions
  1904. //Unset: UINT FloatInstructionCount; // Number of floating point arithmetic instructions used
  1905. //Unset: UINT IntInstructionCount; // Number of signed integer arithmetic instructions used
  1906. //Unset: UINT UintInstructionCount; // Number of unsigned integer arithmetic instructions used
  1907. //Unset: UINT StaticFlowControlCount; // Number of static flow control instructions used
  1908. //Unset: UINT DynamicFlowControlCount; // Number of dynamic flow control instructions used
  1909. //Unset: UINT MacroInstructionCount; // Number of macro instructions used
  1910. //Unset: UINT ArrayInstructionCount; // Number of array instructions used
  1911. //Unset: UINT MovInstructionCount; // Number of mov instructions used
  1912. //Unset: UINT MovcInstructionCount; // Number of movc instructions used
  1913. //Unset: UINT ConversionInstructionCount; // Number of type conversion instructions used
  1914. //Unset: UINT BitwiseInstructionCount; // Number of bitwise arithmetic instructions used
  1915. //Unset: D3D_FEATURE_LEVEL MinFeatureLevel; // Min target of the function byte code
  1916. //Unset: UINT64 RequiredFeatureFlags; // Required feature flags
  1917. pDesc->Name = m_Name.c_str();
  1918. //Unset: INT FunctionParameterCount; // Number of logical parameters in the function signature (not including return)
  1919. //Unset: BOOL HasReturn; // TRUE, if function returns a value, false - it is a subroutine
  1920. //Unset: BOOL Has10Level9VertexShader; // TRUE, if there is a 10L9 VS blob
  1921. //Unset: BOOL Has10Level9PixelShader; // TRUE, if there is a 10L9 PS blob
  1922. return S_OK;
  1923. }
  1924. // BufferIndex is relative to used constant buffers here
  1925. ID3D12ShaderReflectionConstantBuffer *CFunctionReflection::GetConstantBufferByIndex(UINT BufferIndex) {
  1926. DXASSERT_NOMSG(m_pLibraryReflection);
  1927. if (BufferIndex >= m_UsedCBs.size())
  1928. return &g_InvalidSRConstantBuffer;
  1929. return m_pLibraryReflection->_GetConstantBufferByIndex(m_UsedCBs[BufferIndex]);
  1930. }
  1931. ID3D12ShaderReflectionConstantBuffer *CFunctionReflection::GetConstantBufferByName(LPCSTR Name) {
  1932. DXASSERT_NOMSG(m_pLibraryReflection);
  1933. return m_pLibraryReflection->_GetConstantBufferByName(Name);
  1934. }
  1935. HRESULT CFunctionReflection::GetResourceBindingDesc(UINT ResourceIndex,
  1936. D3D12_SHADER_INPUT_BIND_DESC * pDesc) {
  1937. DXASSERT_NOMSG(m_pLibraryReflection);
  1938. if (ResourceIndex >= m_UsedResources.size())
  1939. return E_INVALIDARG;
  1940. return m_pLibraryReflection->_GetResourceBindingDesc(m_UsedResources[ResourceIndex], pDesc);
  1941. }
  1942. ID3D12ShaderReflectionVariable * CFunctionReflection::GetVariableByName(LPCSTR Name) {
  1943. DXASSERT_NOMSG(m_pLibraryReflection);
  1944. return m_pLibraryReflection->_GetVariableByName(Name);
  1945. }
  1946. HRESULT CFunctionReflection::GetResourceBindingDescByName(LPCSTR Name,
  1947. D3D12_SHADER_INPUT_BIND_DESC * pDesc) {
  1948. DXASSERT_NOMSG(m_pLibraryReflection);
  1949. return m_pLibraryReflection->_GetResourceBindingDescByName(Name, pDesc);
  1950. }
  1951. // DxilLibraryReflection
  1952. // From DxilContainerAssembler:
  1953. static llvm::Function *FindUsingFunction(llvm::Value *User) {
  1954. if (llvm::Instruction *I = dyn_cast<llvm::Instruction>(User)) {
  1955. // Instruction should be inside a basic block, which is in a function
  1956. return cast<llvm::Function>(I->getParent()->getParent());
  1957. }
  1958. // User can be either instruction, constant, or operator. But User is an
  1959. // operator only if constant is a scalar value, not resource pointer.
  1960. llvm::Constant *CU = cast<llvm::Constant>(User);
  1961. if (!CU->user_empty())
  1962. return FindUsingFunction(*CU->user_begin());
  1963. else
  1964. return nullptr;
  1965. }
  1966. void DxilLibraryReflection::AddResourceUseToFunctions(DxilResourceBase &resource, unsigned resIndex) {
  1967. Constant *var = resource.GetGlobalSymbol();
  1968. if (var) {
  1969. for (auto user : var->users()) {
  1970. // Find the function.
  1971. if (llvm::Function *F = FindUsingFunction(user)) {
  1972. auto funcReflector = m_FunctionsByPtr[F];
  1973. funcReflector->AddResourceReference(resIndex);
  1974. if (resource.GetClass() == DXIL::ResourceClass::CBuffer) {
  1975. funcReflector->AddCBReference(resource.GetID());
  1976. }
  1977. }
  1978. }
  1979. }
  1980. }
  1981. void DxilLibraryReflection::AddResourceDependencies() {
  1982. std::map<StringRef, CFunctionReflection*> orderedMap;
  1983. for (auto &F : m_pModule->functions()) {
  1984. if (F.isDeclaration())
  1985. continue;
  1986. auto &func = m_FunctionMap[F.getName()];
  1987. DXASSERT(!func.get(), "otherwise duplicate named functions");
  1988. func.reset(new CFunctionReflection());
  1989. func->Initialize(this, &F);
  1990. m_FunctionsByPtr[&F] = func.get();
  1991. orderedMap[F.getName()] = func.get();
  1992. }
  1993. // Fill in function vector sorted by name
  1994. m_FunctionVector.clear();
  1995. m_FunctionVector.reserve(orderedMap.size());
  1996. for (auto &it : orderedMap) {
  1997. m_FunctionVector.push_back(it.second);
  1998. }
  1999. UINT resIndex = 0;
  2000. for (auto &resource : m_Resources) {
  2001. switch ((UINT32)resource.Type) {
  2002. case D3D_SIT_CBUFFER:
  2003. AddResourceUseToFunctions(m_pDxilModule->GetCBuffer(resource.uID), resIndex);
  2004. break;
  2005. case D3D_SIT_TBUFFER: // TODO: Handle when TBuffers are added to CB list
  2006. case D3D_SIT_TEXTURE:
  2007. case D3D_SIT_STRUCTURED:
  2008. case D3D_SIT_BYTEADDRESS:
  2009. case D3D_SIT_RTACCELERATIONSTRUCTURE:
  2010. AddResourceUseToFunctions(m_pDxilModule->GetSRV(resource.uID), resIndex);
  2011. break;
  2012. case D3D_SIT_UAV_RWTYPED:
  2013. case D3D_SIT_UAV_RWSTRUCTURED:
  2014. case D3D_SIT_UAV_RWBYTEADDRESS:
  2015. case D3D_SIT_UAV_APPEND_STRUCTURED:
  2016. case D3D_SIT_UAV_CONSUME_STRUCTURED:
  2017. case D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER:
  2018. AddResourceUseToFunctions(m_pDxilModule->GetUAV(resource.uID), resIndex);
  2019. break;
  2020. case D3D_SIT_SAMPLER:
  2021. AddResourceUseToFunctions(m_pDxilModule->GetSampler(resource.uID), resIndex);
  2022. break;
  2023. }
  2024. resIndex++;
  2025. }
  2026. }
  2027. static void CollectCBufUsageForLib(Value *V, std::vector<unsigned> &cbufUsage) {
  2028. for (auto user : V->users()) {
  2029. Value *V = user;
  2030. if (auto *CI = dyn_cast<CallInst>(V)) {
  2031. if (hlsl::OP::IsDxilOpFuncCallInst(CI, hlsl::OP::OpCode::CreateHandleForLib)) {
  2032. CollectCBufUsage(CI, cbufUsage);
  2033. }
  2034. } else if (isa<GEPOperator>(V) ||
  2035. isa<LoadInst>(V)) {
  2036. CollectCBufUsageForLib(user, cbufUsage);
  2037. }
  2038. }
  2039. }
  2040. void DxilLibraryReflection::SetCBufferUsage() {
  2041. unsigned cbSize = std::min(m_CBs.size(), m_pDxilModule->GetCBuffers().size());
  2042. for (unsigned i=0;i<cbSize;i++) {
  2043. std::vector<unsigned> cbufUsage;
  2044. CollectCBufUsageForLib(m_pDxilModule->GetCBuffer(i).GetGlobalSymbol(), cbufUsage);
  2045. SetCBufVarUsage(*m_CBs[i], cbufUsage);
  2046. }
  2047. }
  2048. // ID3D12LibraryReflection
  2049. HRESULT DxilLibraryReflection::Load(IDxcBlob *pBlob,
  2050. const DxilPartHeader *pPart) {
  2051. IFR(LoadModule(pBlob, pPart));
  2052. try {
  2053. AddResourceDependencies();
  2054. SetCBufferUsage();
  2055. return S_OK;
  2056. }
  2057. CATCH_CPP_RETURN_HRESULT();
  2058. }
  2059. _Use_decl_annotations_
  2060. HRESULT DxilLibraryReflection::GetDesc(D3D12_LIBRARY_DESC * pDesc) {
  2061. IFR(ZeroMemoryToOut(pDesc));
  2062. //Unset: LPCSTR Creator; // The name of the originator of the library.
  2063. //Unset: UINT Flags; // Compilation flags.
  2064. //UINT FunctionCount; // Number of functions exported from the library.
  2065. pDesc->FunctionCount = (UINT)m_FunctionVector.size();
  2066. return S_OK;
  2067. }
  2068. _Use_decl_annotations_
  2069. ID3D12FunctionReflection *DxilLibraryReflection::GetFunctionByIndex(INT FunctionIndex) {
  2070. if (FunctionIndex >= m_FunctionVector.size())
  2071. return &g_InvalidFunction;
  2072. return m_FunctionVector[FunctionIndex];
  2073. }
// DxilRuntimeReflection implementation
#include "dxc/HLSL/DxilRuntimeReflection.inl"
#else
// Non-LLVM_ON_WIN32 builds: the reflection implementation above is compiled
// out, so the factory produces no object and writes nullptr to *ppResult.
// ppResult is dereferenced unconditionally and must be non-null.
void hlsl::CreateDxcContainerReflection(IDxcContainerReflection **ppResult) {
  *ppResult = nullptr;
}
DEFINE_CROSS_PLATFORM_UUIDOF(IDxcContainerReflection)
#endif // LLVM_ON_WIN32