DxilPipelineStateValidation.h 35 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPipelineStateValidation.h //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Defines data used by the D3D runtime for PSO validation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #ifndef __DXIL_PIPELINE_STATE_VALIDATION__H__
  12. #define __DXIL_PIPELINE_STATE_VALIDATION__H__
  13. #include <stdint.h>
  14. #include <string.h>
  15. #ifndef UINT_MAX
  16. #define UINT_MAX 0xffffffff
  17. #endif
  // Number of 32-bit words needed for a mask holding one bit per component,
  // with 4 components per vector (so 8 vectors fill one dword).
  inline uint32_t PSVComputeMaskDwordsFromVectors(uint32_t Vectors) {
    const uint32_t roundedUp = Vectors + 7;
    return roundedUp >> 3;
  }

  // Size in bytes of a dependency table that stores one output mask
  // (PSVComputeMaskDwordsFromVectors(OutputVectors) dwords) for each of the
  // 4 * InputVectors input components.
  inline uint32_t PSVComputeInputOutputTableSize(uint32_t InputVectors, uint32_t OutputVectors) {
    const uint32_t dwordsPerMask = PSVComputeMaskDwordsFromVectors(OutputVectors);
    return sizeof(uint32_t) * dwordsPerMask * InputVectors * 4;
  }
  // Round ptr up to the next multiple of (1 << alignbits).
  // Note: both macros evaluate their arguments more than once.
  #define PSVALIGN(ptr, alignbits) \
    (((ptr) + ((1 << (alignbits))-1)) & ~((1 << (alignbits))-1))
  // Round ptr up to the next multiple of 4 (dword boundary).
  #define PSVALIGN4(ptr) PSVALIGN(ptr, 2)
  // Vertex shader specific runtime info.
  struct VSInfo {
    char OutputPositionPresent; // non-zero when the shader outputs a position
  };
  // Hull shader specific runtime info.
  struct HSInfo {
    // Control point counts (max control points == 32)
    uint32_t InputControlPointCount;
    uint32_t OutputControlPointCount;
    // hlsl::DXIL::TessellatorDomain/D3D11_SB_TESSELLATOR_DOMAIN
    uint32_t TessellatorDomain;
    // hlsl::DXIL::TessellatorOutputPrimitive/D3D11_SB_TESSELLATOR_OUTPUT_PRIMITIVE
    uint32_t TessellatorOutputPrimitive;
  };
  // Domain shader specific runtime info.
  struct DSInfo {
    // max control points == 32
    uint32_t InputControlPointCount;
    char OutputPositionPresent;
    // hlsl::DXIL::TessellatorDomain/D3D11_SB_TESSELLATOR_DOMAIN
    uint32_t TessellatorDomain;
  };
  // Geometry shader specific runtime info.
  struct GSInfo {
    // hlsl::DXIL::InputPrimitive/D3D10_SB_PRIMITIVE
    uint32_t InputPrimitive;
    // hlsl::DXIL::PrimitiveTopology/D3D10_SB_PRIMITIVE_TOPOLOGY
    uint32_t OutputTopology;
    // max streams == 4
    uint32_t OutputStreamMask;
    char OutputPositionPresent;
  };
  // Pixel shader specific runtime info (two byte-sized flags).
  struct PSInfo {
    char DepthOutput;
    char SampleFrequency;
  };
  // Mesh shader specific runtime info.
  struct MSInfo {
    // Byte counts
    uint32_t GroupSharedBytesUsed;
    uint32_t GroupSharedBytesDependentOnViewID;
    uint32_t PayloadSizeInBytes;
    // Output limits
    uint16_t MaxOutputVertices;
    uint16_t MaxOutputPrimitives;
  };
  // Amplification shader specific runtime info.
  struct ASInfo {
    uint32_t PayloadSizeInBytes;
  };
  // Additional mesh shader info stored in the PSVRuntimeInfo1 union.
  struct MSInfo1 {
    // Primitive output for MS
    uint8_t SigPrimVectors;
    uint8_t MeshOutputTopology;
  };
  63. // Versioning is additive and based on size
  64. struct PSVRuntimeInfo0
  65. {
  66. union {
  67. VSInfo VS;
  68. HSInfo HS;
  69. DSInfo DS;
  70. GSInfo GS;
  71. PSInfo PS;
  72. MSInfo MS;
  73. ASInfo AS;
  74. };
  75. uint32_t MinimumExpectedWaveLaneCount; // minimum lane count required, 0 if unused
  76. uint32_t MaximumExpectedWaveLaneCount; // maximum lane count required, 0xffffffff if unused
  77. };
  // Shader stage identifier stored as one byte; mirrors DXIL::ShaderKind.
  enum class PSVShaderKind : uint8_t
  {
    Pixel = 0,
    Vertex = 1,
    Geometry = 2,
    Hull = 3,
    Domain = 4,
    Compute = 5,
    Library = 6,
    RayGeneration = 7,
    Intersection = 8,
    AnyHit = 9,
    ClosestHit = 10,
    Miss = 11,
    Callable = 12,
    Mesh = 13,
    Amplification = 14,
    Invalid = 15,
  };
  97. struct PSVRuntimeInfo1 : public PSVRuntimeInfo0
  98. {
  99. uint8_t ShaderStage; // PSVShaderKind
  100. uint8_t UsesViewID;
  101. union {
  102. uint16_t MaxVertexCount; // MaxVertexCount for GS only (max 1024)
  103. uint8_t SigPatchConstOrPrimVectors; // Output for HS; Input for DS; Primitive output for MS (overlaps MS1::SigPrimVectors)
  104. struct MSInfo1 MS1;
  105. };
  106. // PSVSignatureElement counts
  107. uint8_t SigInputElements;
  108. uint8_t SigOutputElements;
  109. uint8_t SigPatchConstOrPrimElements;
  110. // Number of packed vectors per signature
  111. uint8_t SigInputVectors;
  112. uint8_t SigOutputVectors[4]; // Array for GS Stream Out Index
  113. };
  // Resource binding class stored in PSVResourceBindInfo0::ResType.
  enum class PSVResourceType
  {
    Invalid = 0,
    Sampler = 1,
    CBV = 2,
    SRVTyped = 3,
    SRVRaw = 4,
    SRVStructured = 5,
    UAVTyped = 6,
    UAVRaw = 7,
    UAVStructured = 8,
    UAVStructuredWithCounter = 9,
    NumEntries = 10
  };
  // Resource dimension / kind enumeration.
  enum class PSVResourceKind
  {
    Invalid = 0,
    Texture1D = 1,
    Texture2D = 2,
    Texture2DMS = 3,
    Texture3D = 4,
    TextureCube = 5,
    Texture1DArray = 6,
    Texture2DArray = 7,
    Texture2DMSArray = 8,
    TextureCubeArray = 9,
    TypedBuffer = 10,
    RawBuffer = 11,
    StructuredBuffer = 12,
    CBuffer = 13,
    Sampler = 14,
    TBuffer = 15,
    RTAccelerationStructure = 16,
    NumEntries = 17
  };
  149. // Table of null-terminated strings, overall size aligned to dword boundary, last byte must be null
  150. struct PSVStringTable {
  151. const char *Table;
  152. uint32_t Size;
  153. PSVStringTable() : Table(nullptr), Size(0) {}
  154. PSVStringTable(const char *table, uint32_t size) : Table(table), Size(size) {}
  155. const char *Get(uint32_t offset) const {
  156. _Analysis_assume_(offset < Size && Table && Table[Size-1] == '\0');
  157. return Table + offset;
  158. }
  159. };
  // Version 0 of a resource binding record.
  // Versioning is additive and based on size.
  struct PSVResourceBindInfo0
  {
    uint32_t ResType;     // PSVResourceType
    uint32_t Space;
    // Bound range of the binding
    uint32_t LowerBound;
    uint32_t UpperBound;
  };
  168. // Helpers for output dependencies (ViewID and Input-Output tables)
  169. struct PSVComponentMask {
  170. uint32_t *Mask;
  171. uint32_t NumVectors;
  172. PSVComponentMask() : Mask(nullptr), NumVectors(0) {}
  173. PSVComponentMask(const PSVComponentMask &other) : Mask(other.Mask), NumVectors(other.NumVectors) {}
  174. PSVComponentMask(uint32_t *pMask, uint32_t outputVectors)
  175. : Mask(pMask),
  176. NumVectors(outputVectors)
  177. {}
  178. const PSVComponentMask &operator|=(const PSVComponentMask &other) {
  179. uint32_t dwords = PSVComputeMaskDwordsFromVectors(NumVectors < other.NumVectors ? NumVectors : other.NumVectors);
  180. for (uint32_t i = 0; i < dwords; ++i) {
  181. Mask[i] |= other.Mask[i];
  182. }
  183. return *this;
  184. }
  185. bool Get(uint32_t ComponentIndex) const {
  186. if(ComponentIndex < NumVectors * 4)
  187. return (bool)(Mask[ComponentIndex >> 5] & (1 << (ComponentIndex & 0x1F)));
  188. return false;
  189. }
  190. void Set(uint32_t ComponentIndex) {
  191. if (ComponentIndex < NumVectors * 4)
  192. Mask[ComponentIndex >> 5] |= (1 << (ComponentIndex & 0x1F));
  193. }
  194. void Clear(uint32_t ComponentIndex) {
  195. if (ComponentIndex < NumVectors * 4)
  196. Mask[ComponentIndex >> 5] &= ~(1 << (ComponentIndex & 0x1F));
  197. }
  198. bool IsValid() { return Mask != nullptr; }
  199. };
  200. struct PSVDependencyTable {
  201. uint32_t *Table;
  202. uint32_t InputVectors;
  203. uint32_t OutputVectors;
  204. PSVDependencyTable() : Table(nullptr), InputVectors(0), OutputVectors(0) {}
  205. PSVDependencyTable(const PSVDependencyTable &other) : Table(other.Table), InputVectors(other.InputVectors), OutputVectors(other.OutputVectors) {}
  206. PSVDependencyTable(uint32_t *pTable, uint32_t inputVectors, uint32_t outputVectors)
  207. : Table(pTable),
  208. InputVectors(inputVectors),
  209. OutputVectors(outputVectors)
  210. {}
  211. PSVComponentMask GetMaskForInput(uint32_t inputComponentIndex) {
  212. if (!Table || !InputVectors || !OutputVectors)
  213. return PSVComponentMask();
  214. return PSVComponentMask(Table + (PSVComputeMaskDwordsFromVectors(OutputVectors) * inputComponentIndex), OutputVectors);
  215. }
  216. bool IsValid() { return Table != nullptr; }
  217. };
  218. struct PSVString {
  219. uint32_t Offset;
  220. PSVString() : Offset(0) {}
  221. PSVString(uint32_t offset) : Offset(offset) {}
  222. const char *Get(const PSVStringTable &table) const { return table.Get(Offset); }
  223. };
  224. struct PSVSemanticIndexTable {
  225. const uint32_t *Table;
  226. uint32_t Entries;
  227. PSVSemanticIndexTable() : Table(nullptr), Entries(0) {}
  228. PSVSemanticIndexTable(const uint32_t *table, uint32_t entries) : Table(table), Entries(entries) {}
  229. const uint32_t *Get(uint32_t offset) const {
  230. _Analysis_assume_(offset < Entries && Table);
  231. return Table + offset;
  232. }
  233. };
  234. struct PSVSemanticIndexes {
  235. uint32_t Offset;
  236. PSVSemanticIndexes() : Offset(0) {}
  237. PSVSemanticIndexes(uint32_t offset) : Offset(offset) {}
  238. const uint32_t *Get(const PSVSemanticIndexTable &table) const { return table.Get(Offset); }
  239. };
  // Semantic kind stored as one byte; mirrors DXIL::SemanticKind.
  enum class PSVSemanticKind : uint8_t
  {
    Arbitrary = 0,
    VertexID = 1,
    InstanceID = 2,
    Position = 3,
    RenderTargetArrayIndex = 4,
    ViewPortArrayIndex = 5,
    ClipDistance = 6,
    CullDistance = 7,
    OutputControlPointID = 8,
    DomainLocation = 9,
    PrimitiveID = 10,
    GSInstanceID = 11,
    SampleIndex = 12,
    IsFrontFace = 13,
    Coverage = 14,
    InnerCoverage = 15,
    Target = 16,
    Depth = 17,
    DepthLessEqual = 18,
    DepthGreaterEqual = 19,
    StencilRef = 20,
    DispatchThreadID = 21,
    GroupID = 22,
    GroupIndex = 23,
    GroupThreadID = 24,
    TessFactor = 25,
    InsideTessFactor = 26,
    ViewID = 27,
    Barycentrics = 28,
    ShadingRate = 29,
    CullPrimitive = 30,
    Invalid = 31,
  };
  // Version 0 of a packed signature element record.
  struct PSVSignatureElement0
  {
    // Table references
    uint32_t SemanticName;          // Offset into StringTable
    uint32_t SemanticIndexes;       // Offset into PSVSemanticIndexTable, count == Rows

    // Packing location
    uint8_t Rows;                   // Number of rows this element occupies
    uint8_t StartRow;               // Starting row of packing location if allocated
    uint8_t ColsAndStart;           // 0:4 = Cols, 4:6 = StartCol, 6:7 == Allocated

    // Element description
    uint8_t SemanticKind;           // PSVSemanticKind
    uint8_t ComponentType;          // DxilProgramSigCompType
    uint8_t InterpolationMode;      // DXIL::InterpolationMode or D3D10_SB_INTERPOLATION_MODE
    uint8_t DynamicMaskAndStream;   // 0:4 = DynamicIndexMask, 4:6 = OutputStream (0-3)
    uint8_t Reserved;
  };
  288. // Provides convenient access to packed PSVSignatureElementN structure
  289. class PSVSignatureElement
  290. {
  291. private:
  292. const PSVStringTable &m_StringTable;
  293. const PSVSemanticIndexTable &m_SemanticIndexTable;
  294. const PSVSignatureElement0 *m_pElement0;
  295. public:
  296. PSVSignatureElement(const PSVStringTable &stringTable, const PSVSemanticIndexTable &semanticIndexTable, const PSVSignatureElement0 *pElement0)
  297. : m_StringTable(stringTable), m_SemanticIndexTable(semanticIndexTable), m_pElement0(pElement0) {}
  298. const char *GetSemanticName() const { return !m_pElement0 ? "" : m_StringTable.Get(m_pElement0->SemanticName); }
  299. const uint32_t *GetSemanticIndexes() const { return !m_pElement0 ? nullptr: m_SemanticIndexTable.Get(m_pElement0->SemanticIndexes); }
  300. uint32_t GetRows() const { return !m_pElement0 ? 0 : ((uint32_t)m_pElement0->Rows); }
  301. uint32_t GetCols() const { return !m_pElement0 ? 0 : ((uint32_t)m_pElement0->ColsAndStart & 0xF); }
  302. bool IsAllocated() const { return !m_pElement0 ? false : !!(m_pElement0->ColsAndStart & 0x40); }
  303. int32_t GetStartRow() const { return !m_pElement0 ? 0 : !IsAllocated() ? -1 : (int32_t)m_pElement0->StartRow; }
  304. int32_t GetStartCol() const { return !m_pElement0 ? 0 : !IsAllocated() ? -1 : (int32_t)((m_pElement0->ColsAndStart >> 4) & 0x3); }
  305. PSVSemanticKind GetSemanticKind() const { return !m_pElement0 ? (PSVSemanticKind)0 : (PSVSemanticKind)m_pElement0->SemanticKind; }
  306. uint32_t GetComponentType() const { return !m_pElement0 ? 0 : (uint32_t)m_pElement0->ComponentType; }
  307. uint32_t GetInterpolationMode() const { return !m_pElement0 ? 0 : (uint32_t)m_pElement0->InterpolationMode; }
  308. uint32_t GetOutputStream() const { return !m_pElement0 ? 0 : (uint32_t)(m_pElement0->DynamicMaskAndStream >> 4) & 0x3; }
  309. uint32_t GetDynamicIndexMask() const { return !m_pElement0 ? 0 : (uint32_t)m_pElement0->DynamicMaskAndStream & 0xF; }
  310. };
  311. #define MAX_PSV_VERSION 1
  312. struct PSVInitInfo
  313. {
  314. PSVInitInfo(uint32_t psvVersion)
  315. : PSVVersion(psvVersion),
  316. ResourceCount(0),
  317. ShaderStage(PSVShaderKind::Invalid),
  318. StringTable(),
  319. SemanticIndexTable(),
  320. UsesViewID(0),
  321. SigInputElements(0),
  322. SigOutputElements(0),
  323. SigPatchConstOrPrimElements(0),
  324. SigInputVectors(0),
  325. SigPatchConstOrPrimVectors(0)
  326. {}
  327. uint32_t PSVVersion;
  328. uint32_t ResourceCount;
  329. PSVShaderKind ShaderStage;
  330. PSVStringTable StringTable;
  331. PSVSemanticIndexTable SemanticIndexTable;
  332. uint8_t UsesViewID;
  333. uint8_t SigInputElements;
  334. uint8_t SigOutputElements;
  335. uint8_t SigPatchConstOrPrimElements;
  336. uint8_t SigInputVectors;
  337. uint8_t SigPatchConstOrPrimVectors;
  338. uint8_t SigOutputVectors[4] = {0, 0, 0, 0};
  339. };
  340. class DxilPipelineStateValidation
  341. {
  342. uint32_t m_uPSVRuntimeInfoSize;
  343. PSVRuntimeInfo0* m_pPSVRuntimeInfo0;
  344. PSVRuntimeInfo1* m_pPSVRuntimeInfo1;
  345. uint32_t m_uResourceCount;
  346. uint32_t m_uPSVResourceBindInfoSize;
  347. void* m_pPSVResourceBindInfo;
  348. PSVStringTable m_StringTable;
  349. PSVSemanticIndexTable m_SemanticIndexTable;
  350. uint32_t m_uPSVSignatureElementSize;
  351. void* m_pSigInputElements;
  352. void* m_pSigOutputElements;
  353. void* m_pSigPatchConstOrPrimElements;
  354. uint32_t* m_pViewIDOutputMask;
  355. uint32_t* m_pViewIDPCOrPrimOutputMask;
  356. uint32_t* m_pInputToOutputTable;
  357. uint32_t* m_pInputToPCOutputTable;
  358. uint32_t* m_pPCInputToOutputTable;
  359. public:
  360. DxilPipelineStateValidation() :
  361. m_uPSVRuntimeInfoSize(0),
  362. m_pPSVRuntimeInfo0(nullptr),
  363. m_pPSVRuntimeInfo1(nullptr),
  364. m_uResourceCount(0),
  365. m_uPSVResourceBindInfoSize(0),
  366. m_pPSVResourceBindInfo(nullptr),
  367. m_StringTable(),
  368. m_SemanticIndexTable(),
  369. m_uPSVSignatureElementSize(0),
  370. m_pSigInputElements(nullptr),
  371. m_pSigOutputElements(nullptr),
  372. m_pSigPatchConstOrPrimElements(nullptr),
  373. m_pViewIDOutputMask(nullptr),
  374. m_pViewIDPCOrPrimOutputMask(nullptr),
  375. m_pInputToOutputTable(nullptr),
  376. m_pInputToPCOutputTable(nullptr),
  377. m_pPCInputToOutputTable(nullptr)
  378. {
  379. }
  380. // Init() from PSV0 blob part that looks like:
  381. // uint32_t PSVRuntimeInfo_size
  382. // { PSVRuntimeInfoN structure }
  383. // uint32_t ResourceCount
  384. // If ResourceCount > 0:
  385. // uint32_t PSVResourceBindInfo_size
  386. // { PSVResourceBindInfoN structure } * ResourceCount
  387. // If PSVRuntimeInfo1:
  388. // uint32_t StringTableSize (dword aligned)
  389. // If StringTableSize:
  390. // { StringTableSize bytes, null terminated }
  391. // uint32_t SemanticIndexTableEntries (number of dwords)
  392. // If SemanticIndexTableEntries:
  393. // { semantic index } * SemanticIndexTableEntries
  394. // If SigInputElements || SigOutputElements || SigPatchConstOrPrimElements:
  395. // uint32_t PSVSignatureElement_size
  396. // { PSVSignatureElementN structure } * SigInputElements
  397. // { PSVSignatureElementN structure } * SigOutputElements
  398. // { PSVSignatureElementN structure } * SigPatchConstOrPrimElements
  399. // If (UsesViewID):
  400. // For (i : each stream index 0-3):
  401. // If (SigOutputVectors[i] non-zero):
  402. // { uint32_t * PSVComputeMaskDwordsFromVectors(SigOutputVectors[i]) }
  403. // - Outputs affected by ViewID as a bitmask
  404. // If (HS and SigPatchConstOrPrimVectors non-zero):
  405. // { uint32_t * PSVComputeMaskDwordsFromVectors(SigPatchConstOrPrimVectors) }
  406. // - PCOutputs affected by ViewID as a bitmask
  407. // For (i : each stream index 0-3):
  408. // If (SigInputVectors and SigOutputVectors[i] non-zero):
  409. // { PSVComputeInputOutputTableSize(SigInputVectors, SigOutputVectors[i]) }
  410. // - Outputs affected by inputs as a table of bitmasks
  411. // If (HS and SigPatchConstOrPrimVectors and SigInputVectors non-zero):
  412. // { PSVComputeInputOutputTableSize(SigInputVectors, SigPatchConstOrPrimVectors) }
  413. // - Patch constant outputs affected by inputs as a table of bitmasks
  414. // If (DS and SigOutputVectors[0] and SigPatchConstOrPrimVectors non-zero):
  415. // { PSVComputeInputOutputTableSize(SigPatchConstOrPrimVectors, SigOutputVectors[0]) }
  416. // - Outputs affected by patch constant inputs as a table of bitmasks
  417. // returns true if no errors occurred.
  418. bool InitFromPSV0(const void* pBits, uint32_t size) {
  419. if(!(pBits != nullptr)) return false;
  420. uint8_t* pCurBits = (uint8_t*)const_cast<void*>(pBits);
  421. uint32_t minsize = sizeof(PSVRuntimeInfo0) + sizeof(uint32_t) * 2;
  422. if(!(size >= minsize)) return false;
  423. m_uPSVRuntimeInfoSize = *((const uint32_t*)pCurBits);
  424. if(!(m_uPSVRuntimeInfoSize >= sizeof(PSVRuntimeInfo0))) return false;
  425. pCurBits += sizeof(uint32_t);
  426. minsize = m_uPSVRuntimeInfoSize + sizeof(uint32_t) * 2;
  427. if(!(size >= minsize)) return false;
  428. m_pPSVRuntimeInfo0 = const_cast<PSVRuntimeInfo0*>((const PSVRuntimeInfo0*)pCurBits);
  429. if(m_uPSVRuntimeInfoSize >= sizeof(PSVRuntimeInfo1))
  430. m_pPSVRuntimeInfo1 = const_cast<PSVRuntimeInfo1*>((const PSVRuntimeInfo1*)pCurBits);
  431. pCurBits += m_uPSVRuntimeInfoSize;
  432. m_uResourceCount = *(const uint32_t*)pCurBits;
  433. pCurBits += sizeof(uint32_t);
  434. if (m_uResourceCount > 0) {
  435. minsize += sizeof(uint32_t);
  436. if(!(size >= minsize)) return false;
  437. m_uPSVResourceBindInfoSize = *(const uint32_t*)pCurBits;
  438. pCurBits += sizeof(uint32_t);
  439. minsize += m_uPSVResourceBindInfoSize * m_uResourceCount;
  440. if(!(m_uPSVResourceBindInfoSize >= sizeof(PSVResourceBindInfo0))) return false;
  441. if(!(size >= minsize)) return false;
  442. m_pPSVResourceBindInfo = static_cast<void*>(const_cast<uint8_t*>(pCurBits));
  443. }
  444. pCurBits += m_uPSVResourceBindInfoSize * m_uResourceCount;
  445. if (m_pPSVRuntimeInfo1) {
  446. minsize += sizeof(uint32_t) * 2; // m_StringTable.Size and m_SemanticIndexTable.Entries
  447. if (!(size >= minsize)) return false;
  448. m_StringTable.Size = PSVALIGN4(*(uint32_t*)pCurBits);
  449. if (m_StringTable.Size != *(uint32_t*)pCurBits)
  450. return false; // Illegal: Size not aligned
  451. minsize += m_StringTable.Size;
  452. if (!(size >= minsize)) return false;
  453. pCurBits += sizeof(uint32_t);
  454. m_StringTable.Table = (const char *)pCurBits;
  455. pCurBits += m_StringTable.Size;
  456. m_SemanticIndexTable.Entries = *(uint32_t*)pCurBits;
  457. minsize += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
  458. if (!(size >= minsize)) return false;
  459. pCurBits += sizeof(uint32_t);
  460. m_SemanticIndexTable.Table = (uint32_t*)pCurBits;
  461. pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
  462. // Dxil Signature Elements
  463. if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
  464. minsize += sizeof(uint32_t);
  465. if (!(size >= minsize)) return false;
  466. m_uPSVSignatureElementSize = *(uint32_t*)pCurBits;
  467. if (m_uPSVSignatureElementSize < sizeof(PSVSignatureElement0))
  468. return false; // Illegal: Size smaller than first version
  469. pCurBits += sizeof(uint32_t);
  470. minsize += m_uPSVSignatureElementSize * (m_pPSVRuntimeInfo1->SigInputElements + m_pPSVRuntimeInfo1->SigOutputElements + m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements);
  471. if (!(size >= minsize)) return false;
  472. }
  473. if (m_pPSVRuntimeInfo1->SigInputElements) {
  474. m_pSigInputElements = (PSVSignatureElement0*)pCurBits;
  475. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigInputElements;
  476. }
  477. if (m_pPSVRuntimeInfo1->SigOutputElements) {
  478. m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
  479. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
  480. }
  481. if (m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
  482. m_pSigPatchConstOrPrimElements = (PSVSignatureElement0*)pCurBits;
  483. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
  484. }
  485. // ViewID dependencies
  486. if (m_pPSVRuntimeInfo1->UsesViewID) {
  487. for (unsigned i = 0; i < 4; i++) {
  488. if (m_pPSVRuntimeInfo1->SigOutputVectors[i]) {
  489. minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  490. if (!(size >= minsize)) return false;
  491. m_pViewIDOutputMask = (uint32_t*)pCurBits;
  492. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  493. }
  494. if (!IsGS())
  495. break;
  496. }
  497. if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors) {
  498. minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  499. if (!(size >= minsize)) return false;
  500. m_pViewIDPCOrPrimOutputMask = (uint32_t*)pCurBits;
  501. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  502. }
  503. }
  504. // Input to Output dependencies
  505. for (unsigned i = 0; i < 4; i++) {
  506. if (!IsMS() && m_pPSVRuntimeInfo1->SigOutputVectors[i] > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  507. minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  508. if (!(size >= minsize)) return false;
  509. m_pInputToOutputTable = (uint32_t*)pCurBits;
  510. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  511. }
  512. if (!IsGS())
  513. break;
  514. }
  515. if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  516. minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  517. if (!(size >= minsize)) return false;
  518. m_pInputToPCOutputTable = (uint32_t*)pCurBits;
  519. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  520. }
  521. if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0) {
  522. minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  523. if (!(size >= minsize)) return false;
  524. m_pPCInputToOutputTable = (uint32_t*)pCurBits;
  525. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  526. }
  527. }
  528. return true;
  529. }
  530. // Initialize a new buffer
  531. // call with null pBuffer to get required size
  532. bool InitNew(uint32_t ResourceCount, void *pBuffer, uint32_t *pSize) {
  533. PSVInitInfo initInfo(0);
  534. initInfo.ResourceCount = ResourceCount;
  535. return InitNew(initInfo, pBuffer, pSize);
  536. }
  537. bool InitNew(const PSVInitInfo &initInfo, void *pBuffer, uint32_t *pSize) {
  538. if(!(pSize)) return false;
  539. if (initInfo.PSVVersion > MAX_PSV_VERSION) return false;
  540. // Versioned structure sizes
  541. m_uPSVRuntimeInfoSize = sizeof(PSVRuntimeInfo0);
  542. m_uPSVResourceBindInfoSize = sizeof(PSVResourceBindInfo0);
  543. m_uPSVSignatureElementSize = sizeof(PSVSignatureElement0);
  544. if (initInfo.PSVVersion > 0) {
  545. m_uPSVRuntimeInfoSize = sizeof(PSVRuntimeInfo1);
  546. }
  547. // PSVVersion 0
  548. uint32_t size = m_uPSVRuntimeInfoSize + sizeof(uint32_t) * 2;
  549. if (initInfo.ResourceCount) {
  550. size += sizeof(uint32_t) + (m_uPSVResourceBindInfoSize * initInfo.ResourceCount);
  551. }
  552. // PSVVersion 1
  553. if (initInfo.PSVVersion > 0) {
  554. size += sizeof(uint32_t) + PSVALIGN4(initInfo.StringTable.Size);
  555. size += sizeof(uint32_t) + sizeof(uint32_t) * initInfo.SemanticIndexTable.Entries;
  556. if (initInfo.SigInputElements || initInfo.SigOutputElements || initInfo.SigPatchConstOrPrimElements) {
  557. size += sizeof(uint32_t); // PSVSignatureElement_size
  558. }
  559. size += m_uPSVSignatureElementSize * initInfo.SigInputElements;
  560. size += m_uPSVSignatureElementSize * initInfo.SigOutputElements;
  561. size += m_uPSVSignatureElementSize * initInfo.SigPatchConstOrPrimElements;
  562. if (initInfo.UsesViewID) {
  563. for (unsigned i = 0; i < 4; i++) {
  564. size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigOutputVectors[i]);
  565. if (initInfo.ShaderStage != PSVShaderKind::Geometry)
  566. break;
  567. }
  568. if (initInfo.ShaderStage == PSVShaderKind::Hull || initInfo.ShaderStage == PSVShaderKind::Mesh)
  569. size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigPatchConstOrPrimVectors);
  570. }
  571. if (initInfo.SigInputVectors > 0) {
  572. for (unsigned i = 0; i < 4; i++) {
  573. if (initInfo.SigOutputVectors[i] > 0) {
  574. size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigOutputVectors[i]);
  575. if (initInfo.ShaderStage != PSVShaderKind::Geometry)
  576. break;
  577. }
  578. }
  579. if (initInfo.ShaderStage == PSVShaderKind::Hull && initInfo.SigPatchConstOrPrimVectors > 0 && initInfo.SigInputVectors > 0) {
  580. size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigPatchConstOrPrimVectors);
  581. }
  582. }
  583. if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstOrPrimVectors > 0) {
  584. size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstOrPrimVectors, initInfo.SigOutputVectors[0]);
  585. }
  586. }
  587. // Validate or return required size
  588. if (pBuffer) {
  589. if(!(*pSize >= size)) return false;
  590. } else {
  591. *pSize = size;
  592. return true;
  593. }
  594. // PSVVersion 0
  595. memset(pBuffer, 0, size);
  596. uint8_t* pCurBits = (uint8_t*)pBuffer;
  597. *(uint32_t*)pCurBits = m_uPSVRuntimeInfoSize;
  598. pCurBits += sizeof(uint32_t);
  599. m_pPSVRuntimeInfo0 = (PSVRuntimeInfo0*)pCurBits;
  600. if (initInfo.PSVVersion > 0) {
  601. m_pPSVRuntimeInfo1 = (PSVRuntimeInfo1*)pCurBits;
  602. }
  603. pCurBits += m_uPSVRuntimeInfoSize;
  604. // Set resource info:
  605. m_uResourceCount = initInfo.ResourceCount;
  606. *(uint32_t*)pCurBits = m_uResourceCount;
  607. pCurBits += sizeof(uint32_t);
  608. if (m_uResourceCount > 0) {
  609. *(uint32_t*)pCurBits = m_uPSVResourceBindInfoSize;
  610. pCurBits += sizeof(uint32_t);
  611. m_pPSVResourceBindInfo = pCurBits;
  612. }
  613. pCurBits += m_uPSVResourceBindInfoSize * m_uResourceCount;
  614. // PSVVersion 1
  615. if (initInfo.PSVVersion) {
  616. m_pPSVRuntimeInfo1->ShaderStage = (uint8_t)initInfo.ShaderStage;
  617. m_pPSVRuntimeInfo1->UsesViewID = initInfo.UsesViewID;
  618. m_pPSVRuntimeInfo1->SigInputElements = initInfo.SigInputElements;
  619. m_pPSVRuntimeInfo1->SigOutputElements = initInfo.SigOutputElements;
  620. m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements = initInfo.SigPatchConstOrPrimElements;
  621. m_pPSVRuntimeInfo1->SigInputVectors = initInfo.SigInputVectors;
  622. memcpy(m_pPSVRuntimeInfo1->SigOutputVectors, initInfo.SigOutputVectors, 4);
  623. if (IsHS() || IsDS() || IsMS()) {
  624. m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors = initInfo.SigPatchConstOrPrimVectors;
  625. }
  626. // Note: if original size was unaligned, padding has already been zero initialized
  627. m_StringTable.Size = PSVALIGN4(initInfo.StringTable.Size);
  628. *(uint32_t*)pCurBits = m_StringTable.Size;
  629. pCurBits += sizeof(uint32_t);
  630. memcpy(pCurBits, initInfo.StringTable.Table, initInfo.StringTable.Size);
  631. m_StringTable.Table = (const char *)pCurBits;
  632. pCurBits += m_StringTable.Size;
  633. m_SemanticIndexTable.Entries = initInfo.SemanticIndexTable.Entries;
  634. *(uint32_t*)pCurBits = m_SemanticIndexTable.Entries;
  635. pCurBits += sizeof(uint32_t);
  636. memcpy(pCurBits, initInfo.SemanticIndexTable.Table, sizeof(uint32_t) * initInfo.SemanticIndexTable.Entries);
  637. m_SemanticIndexTable.Table = (const uint32_t*)pCurBits;
  638. pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
  639. // Dxil Signature Elements
  640. if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
  641. *(uint32_t*)pCurBits = m_uPSVSignatureElementSize;
  642. pCurBits += sizeof(uint32_t);
  643. }
  644. if (m_pPSVRuntimeInfo1->SigInputElements) {
  645. m_pSigInputElements = (PSVSignatureElement0*)pCurBits;
  646. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigInputElements;
  647. }
  648. if (m_pPSVRuntimeInfo1->SigOutputElements) {
  649. m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
  650. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
  651. }
  652. if (m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements) {
  653. m_pSigPatchConstOrPrimElements = (PSVSignatureElement0*)pCurBits;
  654. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
  655. }
  656. // ViewID dependencies
  657. if (m_pPSVRuntimeInfo1->UsesViewID) {
  658. for (unsigned i = 0; i < 4; i++) {
  659. if (m_pPSVRuntimeInfo1->SigOutputVectors[i]) {
  660. m_pViewIDOutputMask = (uint32_t*)pCurBits;
  661. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  662. }
  663. if (!IsGS())
  664. break;
  665. }
  666. if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors) {
  667. m_pViewIDPCOrPrimOutputMask = (uint32_t*)pCurBits;
  668. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  669. }
  670. }
  671. // Input to Output dependencies
  672. if (m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  673. for (unsigned i = 0; i < 4; i++) {
  674. if (m_pPSVRuntimeInfo1->SigOutputVectors[i] > 0) {
  675. m_pInputToOutputTable = (uint32_t*)pCurBits;
  676. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  677. }
  678. if (!IsGS())
  679. break;
  680. }
  681. if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  682. m_pInputToPCOutputTable = (uint32_t*)pCurBits;
  683. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  684. }
  685. }
  686. if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0) {
  687. m_pPCInputToOutputTable = (uint32_t*)pCurBits;
  688. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  689. }
  690. }
  691. return true;
  692. }
  693. PSVRuntimeInfo0* GetPSVRuntimeInfo0() const {
  694. return m_pPSVRuntimeInfo0;
  695. }
  696. PSVRuntimeInfo1* GetPSVRuntimeInfo1() const {
  697. return m_pPSVRuntimeInfo1;
  698. }
  699. uint32_t GetBindCount() const {
  700. return m_uResourceCount;
  701. }
  702. PSVResourceBindInfo0* GetPSVResourceBindInfo0(uint32_t index) const {
  703. if (index < m_uResourceCount && m_pPSVResourceBindInfo &&
  704. sizeof(PSVResourceBindInfo0) <= m_uPSVResourceBindInfoSize) {
  705. return (PSVResourceBindInfo0*)((uint8_t*)m_pPSVResourceBindInfo +
  706. (index * m_uPSVResourceBindInfoSize));
  707. }
  708. return nullptr;
  709. }
  710. const PSVStringTable &GetStringTable() const { return m_StringTable; }
  711. const PSVSemanticIndexTable &GetSemanticIndexTable() const { return m_SemanticIndexTable; }
  712. // Signature element access
  713. uint32_t GetSigInputElements() const {
  714. if (m_pPSVRuntimeInfo1)
  715. return m_pPSVRuntimeInfo1->SigInputElements;
  716. return 0;
  717. }
  718. uint32_t GetSigOutputElements() const {
  719. if (m_pPSVRuntimeInfo1)
  720. return m_pPSVRuntimeInfo1->SigOutputElements;
  721. return 0;
  722. }
  723. uint32_t GetSigPatchConstOrPrimElements() const {
  724. if (m_pPSVRuntimeInfo1)
  725. return m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements;
  726. return 0;
  727. }
  728. PSVSignatureElement0* GetInputElement0(uint32_t index) const {
  729. if (m_pPSVRuntimeInfo1 && m_pSigInputElements &&
  730. index < m_pPSVRuntimeInfo1->SigInputElements &&
  731. sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
  732. return (PSVSignatureElement0*)((uint8_t*)m_pSigInputElements +
  733. (index * m_uPSVSignatureElementSize));
  734. }
  735. return nullptr;
  736. }
  737. PSVSignatureElement0* GetOutputElement0(uint32_t index) const {
  738. if (m_pPSVRuntimeInfo1 && m_pSigOutputElements &&
  739. index < m_pPSVRuntimeInfo1->SigOutputElements &&
  740. sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
  741. return (PSVSignatureElement0*)((uint8_t*)m_pSigOutputElements +
  742. (index * m_uPSVSignatureElementSize));
  743. }
  744. return nullptr;
  745. }
  746. PSVSignatureElement0* GetPatchConstOrPrimElement0(uint32_t index) const {
  747. if (m_pPSVRuntimeInfo1 && m_pSigPatchConstOrPrimElements &&
  748. index < m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements &&
  749. sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
  750. return (PSVSignatureElement0*)((uint8_t*)m_pSigPatchConstOrPrimElements +
  751. (index * m_uPSVSignatureElementSize));
  752. }
  753. return nullptr;
  754. }
  755. // More convenient wrapper:
  756. PSVSignatureElement GetSignatureElement(PSVSignatureElement0* pElement0) const {
  757. return PSVSignatureElement(m_StringTable, m_SemanticIndexTable, pElement0);
  758. }
  759. PSVShaderKind GetShaderKind() const {
  760. if (m_pPSVRuntimeInfo1 && m_pPSVRuntimeInfo1->ShaderStage < (uint8_t)PSVShaderKind::Invalid)
  761. return (PSVShaderKind)m_pPSVRuntimeInfo1->ShaderStage;
  762. return PSVShaderKind::Invalid;
  763. }
  764. bool IsVS() const { return GetShaderKind() == PSVShaderKind::Vertex; }
  765. bool IsHS() const { return GetShaderKind() == PSVShaderKind::Hull; }
  766. bool IsDS() const { return GetShaderKind() == PSVShaderKind::Domain; }
  767. bool IsGS() const { return GetShaderKind() == PSVShaderKind::Geometry; }
  768. bool IsPS() const { return GetShaderKind() == PSVShaderKind::Pixel; }
  769. bool IsCS() const { return GetShaderKind() == PSVShaderKind::Compute; }
  770. bool IsMS() const { return GetShaderKind() == PSVShaderKind::Mesh; }
  771. bool IsAS() const { return GetShaderKind() == PSVShaderKind::Amplification; }
  772. // ViewID dependencies
  773. PSVComponentMask GetViewIDOutputMask(unsigned streamIndex = 0) const {
  774. if (!m_pViewIDOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex])
  775. return PSVComponentMask();
  776. return PSVComponentMask(m_pViewIDOutputMask, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
  777. }
  778. PSVComponentMask GetViewIDPCOutputMask() const {
  779. if ((!IsHS() && !IsMS()) || !m_pViewIDPCOrPrimOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors)
  780. return PSVComponentMask();
  781. return PSVComponentMask(m_pViewIDPCOrPrimOutputMask, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  782. }
  783. // Input to Output dependencies
  784. PSVDependencyTable GetInputToOutputTable(unsigned streamIndex = 0) const {
  785. if (m_pInputToOutputTable && m_pPSVRuntimeInfo1) {
  786. return PSVDependencyTable(m_pInputToOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
  787. }
  788. return PSVDependencyTable();
  789. }
  790. PSVDependencyTable GetInputToPCOutputTable() const {
  791. if ((IsHS() || IsMS()) && m_pInputToPCOutputTable && m_pPSVRuntimeInfo1) {
  792. return PSVDependencyTable(m_pInputToPCOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
  793. }
  794. return PSVDependencyTable();
  795. }
  796. PSVDependencyTable GetPCInputToOutputTable() const {
  797. if (IsDS() && m_pPCInputToOutputTable && m_pPSVRuntimeInfo1) {
  798. return PSVDependencyTable(m_pPCInputToOutputTable, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  799. }
  800. return PSVDependencyTable();
  801. }
  802. };
  803. namespace hlsl {
  804. class ViewIDValidator {
  805. public:
  806. enum class Result {
  807. Success = 0,
  808. SuccessWithViewIDDependentTessFactor,
  809. InsufficientInputSpace,
  810. InsufficientOutputSpace,
  811. InsufficientPCSpace,
  812. MismatchedSignatures,
  813. MismatchedPCSignatures,
  814. InvalidUsage,
  815. InvalidPSVVersion,
  816. InvalidPSV,
  817. };
  818. virtual ~ViewIDValidator() {}
  819. virtual Result ValidateStage(const DxilPipelineStateValidation &PSV,
  820. bool bFinalStage,
  821. bool bExpandInputOnly,
  822. unsigned &mismatchElementId) = 0;
  823. };
  824. ViewIDValidator* NewViewIDValidator(unsigned viewIDCount, unsigned gsRastStreamIndex);
  825. }
  826. #endif // __DXIL_PIPELINE_STATE_VALIDATION__H__