DxilPipelineStateValidation.h 34 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPipelineStateValidation.h //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Defines data used by the D3D runtime for PSO validation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #ifndef __DXIL_PIPELINE_STATE_VALIDATION__H__
  12. #define __DXIL_PIPELINE_STATE_VALIDATION__H__
  13. #include <stdint.h>
  14. #include <string.h>
  // Number of 32-bit dwords needed for a bitmask that carries one bit per
  // component, with 4 components per vector: ceil(Vectors * 4 / 32).
  inline uint32_t PSVComputeMaskDwordsFromVectors(uint32_t Vectors) {
    const uint32_t vectorsPerDword = 8; // 32 mask bits / 4 components per vector
    return (Vectors + (vectorsPerDword - 1)) / vectorsPerDword;
  }
  17. inline uint32_t PSVComputeInputOutputTableSize(uint32_t InputVectors, uint32_t OutputVectors) {
  18. return sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(OutputVectors) * InputVectors * 4;
  19. }
  // Rounds `ptr` (an integer offset or size) up to the next multiple of
  // 2^alignbits.
  #define PSVALIGN(ptr, alignbits) \
    (((ptr) + ((1 << (alignbits)) - 1)) & ~((1 << (alignbits)) - 1))
  // Rounds `ptr` up to the next multiple of 4 (dword alignment).
  #define PSVALIGN4(ptr) PSVALIGN(ptr, 2)
  // Versioning is additive and based on size
  //
  // First version of the per-shader runtime info record. The stage-specific
  // payloads share storage through the anonymous union below; which member is
  // meaningful depends on the shader stage.
  struct PSVRuntimeInfo0
  {
    // The payload types are declared at struct scope rather than inside the
    // anonymous union: ISO C++ does not allow an anonymous union to declare
    // nested types (GCC rejects it without -fpermissive). The qualified names
    // (PSVRuntimeInfo0::VSInfo, ...) and the memory layout are unchanged.
    struct VSInfo {
      char OutputPositionPresent;
    };
    struct HSInfo {
      uint32_t InputControlPointCount;      // max control points == 32
      uint32_t OutputControlPointCount;     // max control points == 32
      uint32_t TessellatorDomain;           // hlsl::DXIL::TessellatorDomain/D3D11_SB_TESSELLATOR_DOMAIN
      uint32_t TessellatorOutputPrimitive;  // hlsl::DXIL::TessellatorOutputPrimitive/D3D11_SB_TESSELLATOR_OUTPUT_PRIMITIVE
    };
    struct DSInfo {
      uint32_t InputControlPointCount;      // max control points == 32
      char OutputPositionPresent;
      uint32_t TessellatorDomain;           // hlsl::DXIL::TessellatorDomain/D3D11_SB_TESSELLATOR_DOMAIN
    };
    struct GSInfo {
      uint32_t InputPrimitive;              // hlsl::DXIL::InputPrimitive/D3D10_SB_PRIMITIVE
      uint32_t OutputTopology;              // hlsl::DXIL::PrimitiveTopology/D3D10_SB_PRIMITIVE_TOPOLOGY
      uint32_t OutputStreamMask;            // max streams == 4
      char OutputPositionPresent;
    };
    struct PSInfo {
      char DepthOutput;
      char SampleFrequency;
    };

    union {
      VSInfo VS;
      HSInfo HS;
      DSInfo DS;
      GSInfo GS;
      PSInfo PS;
    };
    uint32_t MinimumExpectedWaveLaneCount;  // minimum lane count required, 0 if unused
    uint32_t MaximumExpectedWaveLaneCount;  // maximum lane count required, 0xffffffff if unused
  };
  // Shader stage identifier, stored as a single byte (see the
  // "uint8_t ShaderStage; // PSVShaderKind" field of PSVRuntimeInfo1).
  // Values are spelled out because they are part of the serialized format.
  enum class PSVShaderKind : uint8_t // DXIL::ShaderKind
  {
    Pixel    = 0,
    Vertex   = 1,
    Geometry = 2,
    Hull     = 3,
    Domain   = 4,
    Compute  = 5,
    Invalid  = 6,
  };
  64. struct PSVRuntimeInfo1 : public PSVRuntimeInfo0
  65. {
  66. uint8_t ShaderStage; // PSVShaderKind
  67. uint8_t UsesViewID;
  68. union {
  69. uint16_t MaxVertexCount; // MaxVertexCount for GS only (max 1024)
  70. uint8_t SigPatchConstantVectors; // Output for HS; Input for DS
  71. };
  72. // PSVSignatureElement counts
  73. uint8_t SigInputElements;
  74. uint8_t SigOutputElements;
  75. uint8_t SigPatchConstantElements;
  76. // Number of packed vectors per signature
  77. uint8_t SigInputVectors;
  78. uint8_t SigOutputVectors[4]; // Array for GS Stream Out Index
  79. };
  // Kind of a bound resource, as recorded in PSVResourceBindInfo0::ResType.
  // Values are spelled out because they are part of the serialized format.
  enum class PSVResourceType
  {
    Invalid                  = 0,
    Sampler                  = 1,
    CBV                      = 2,
    SRVTyped                 = 3,
    SRVRaw                   = 4,
    SRVStructured            = 5,
    UAVTyped                 = 6,
    UAVRaw                   = 7,
    UAVStructured            = 8,
    UAVStructuredWithCounter = 9,
    NumEntries               = 10
  };
  // Versioning is additive and based on size
  //
  // One resource binding: the resource kind, its register space and the
  // LowerBound/UpperBound pair describing the bound range.
  struct PSVResourceBindInfo0
  {
    uint32_t ResType;     // PSVResourceType
    uint32_t Space;
    uint32_t LowerBound;
    uint32_t UpperBound;
  };
  // PSVResourceBindInfo1 would derive and extend
  103. // Helpers for output dependencies (ViewID and Input-Output tables)
  104. struct PSVComponentMask {
  105. uint32_t *Mask;
  106. uint32_t NumVectors;
  107. PSVComponentMask() : Mask(nullptr), NumVectors(0) {}
  108. PSVComponentMask(const PSVComponentMask &other) : Mask(other.Mask), NumVectors(other.NumVectors) {}
  109. PSVComponentMask(uint32_t *pMask, uint32_t outputVectors)
  110. : Mask(pMask),
  111. NumVectors(outputVectors)
  112. {}
  113. const PSVComponentMask &operator|=(const PSVComponentMask &other) {
  114. uint32_t dwords = PSVComputeMaskDwordsFromVectors(NumVectors < other.NumVectors ? NumVectors : other.NumVectors);
  115. for (uint32_t i = 0; i < dwords; ++i) {
  116. Mask[i] |= other.Mask[i];
  117. }
  118. return *this;
  119. }
  120. bool Get(uint32_t ComponentIndex) const {
  121. if(ComponentIndex < NumVectors * 4)
  122. return (bool)(Mask[ComponentIndex >> 5] & (1 << (ComponentIndex & 0x1F)));
  123. return false;
  124. }
  125. void Set(uint32_t ComponentIndex) {
  126. if (ComponentIndex < NumVectors * 4)
  127. Mask[ComponentIndex >> 5] |= (1 << (ComponentIndex & 0x1F));
  128. }
  129. void Clear(uint32_t ComponentIndex) {
  130. if (ComponentIndex < NumVectors * 4)
  131. Mask[ComponentIndex >> 5] &= ~(1 << (ComponentIndex & 0x1F));
  132. }
  133. bool IsValid() { return Mask != nullptr; }
  134. };
  135. struct PSVDependencyTable {
  136. uint32_t *Table;
  137. uint32_t InputVectors;
  138. uint32_t OutputVectors;
  139. PSVDependencyTable() : Table(nullptr), InputVectors(0), OutputVectors(0) {}
  140. PSVDependencyTable(const PSVDependencyTable &other) : Table(other.Table), InputVectors(other.InputVectors), OutputVectors(other.OutputVectors) {}
  141. PSVDependencyTable(uint32_t *pTable, uint32_t inputVectors, uint32_t outputVectors)
  142. : Table(pTable),
  143. InputVectors(inputVectors),
  144. OutputVectors(outputVectors)
  145. {}
  146. PSVComponentMask GetMaskForInput(uint32_t inputComponentIndex) {
  147. if (!Table || !InputVectors || !OutputVectors)
  148. return PSVComponentMask();
  149. return PSVComponentMask(Table + (PSVComputeMaskDwordsFromVectors(OutputVectors) * inputComponentIndex), OutputVectors);
  150. }
  151. bool IsValid() { return Table != nullptr; }
  152. };
  // Table of null-terminated strings, overall size aligned to dword boundary, last byte must be null
  //
  // Non-owning view; the table memory must outlive this object.
  struct PSVStringTable {
    const char *Table;
    uint32_t Size;

    PSVStringTable() : Table(nullptr), Size(0) {}
    PSVStringTable(const char *table, uint32_t size) : Table(table), Size(size) {}

    // Returns the string that starts at `offset`.
    // The preconditions (table present, offset inside it, last byte null) are
    // checked at run time rather than only asserted to a static analyzer.
    // When any of them fails an empty string is returned, so the result can
    // always be read as a C string without running past the table.
    const char *Get(uint32_t offset) const {
      if (!Table || offset >= Size || Table[Size - 1] != '\0')
        return "";
      return Table + offset;
    }
  };
  164. struct PSVString {
  165. uint32_t Offset;
  166. PSVString() : Offset(0) {}
  167. PSVString(uint32_t offset) : Offset(offset) {}
  168. const char *Get(const PSVStringTable &table) const { return table.Get(Offset); }
  169. };
  // Non-owning view over a table of semantic indexes (one dword per entry);
  // the table memory must outlive this object.
  struct PSVSemanticIndexTable {
    const uint32_t *Table;
    uint32_t Entries;

    PSVSemanticIndexTable() : Table(nullptr), Entries(0) {}
    PSVSemanticIndexTable(const uint32_t *table, uint32_t entries) : Table(table), Entries(entries) {}

    // Returns a pointer to the entry at `offset`.
    // The preconditions (table present, offset inside it) are checked at run
    // time rather than only asserted to a static analyzer; nullptr is returned
    // when they do not hold instead of a pointer outside the table.
    const uint32_t *Get(uint32_t offset) const {
      if (!Table || offset >= Entries)
        return nullptr;
      return Table + offset;
    }
  };
  180. struct PSVSemanticIndexes {
  181. uint32_t Offset;
  182. PSVSemanticIndexes() : Offset(0) {}
  183. PSVSemanticIndexes(uint32_t offset) : Offset(offset) {}
  184. uint32_t *Get(const PSVSemanticIndexTable &table) const { table.Get(Offset); }
  185. };
  // Semantic kind of a signature element, stored as a single byte (see
  // PSVSignatureElement0::SemanticKind). Values are spelled out because they
  // are part of the serialized format.
  enum class PSVSemanticKind : uint8_t // DXIL::SemanticKind
  {
    Arbitrary              = 0,
    VertexID               = 1,
    InstanceID             = 2,
    Position               = 3,
    RenderTargetArrayIndex = 4,
    ViewPortArrayIndex     = 5,
    ClipDistance           = 6,
    CullDistance           = 7,
    OutputControlPointID   = 8,
    DomainLocation         = 9,
    PrimitiveID            = 10,
    GSInstanceID           = 11,
    SampleIndex            = 12,
    IsFrontFace            = 13,
    Coverage               = 14,
    InnerCoverage          = 15,
    Target                 = 16,
    Depth                  = 17,
    DepthLessEqual         = 18,
    DepthGreaterEqual      = 19,
    StencilRef             = 20,
    DispatchThreadID       = 21,
    GroupID                = 22,
    GroupIndex             = 23,
    GroupThreadID          = 24,
    TessFactor             = 25,
    InsideTessFactor       = 26,
    ViewID                 = 27,
    Barycentrics           = 28,
    Invalid                = 29,
  };
  // Packed description of one signature element (first version).
  // Use PSVSignatureElement for decoded access to the bit-packed fields.
  struct PSVSignatureElement0
  {
    uint32_t SemanticName;          // Offset into PSVStringTable
    uint32_t SemanticIndexes;       // Offset into PSVSemanticIndexTable, count == Rows
    uint8_t  Rows;                  // Number of rows this element occupies
    uint8_t  StartRow;              // Starting row of packing location if allocated
    uint8_t  ColsAndStart;          // 0:4 = Cols, 4:6 = StartCol, 6:7 == Allocated
    uint8_t  SemanticKind;          // PSVSemanticKind
    uint8_t  ComponentType;         // DxilProgramSigCompType
    uint8_t  InterpolationMode;     // DXIL::InterpolationMode or D3D10_SB_INTERPOLATION_MODE
    uint8_t  DynamicMaskAndStream;  // 0:4 = DynamicIndexMask, 4:6 = OutputStream (0-3)
    uint8_t  Reserved;
  };
  232. // Provides convenient access to packed PSVSignatureElementN structure
  233. class PSVSignatureElement
  234. {
  235. private:
  236. const PSVStringTable &m_StringTable;
  237. const PSVSemanticIndexTable &m_SemanticIndexTable;
  238. const PSVSignatureElement0 *m_pElement0;
  239. public:
  240. PSVSignatureElement(const PSVStringTable &stringTable, const PSVSemanticIndexTable &semanticIndexTable, const PSVSignatureElement0 *pElement0)
  241. : m_StringTable(stringTable), m_SemanticIndexTable(semanticIndexTable), m_pElement0(pElement0) {}
  242. const char *GetSemanticName() const { return !m_pElement0 ? "" : m_StringTable.Get(m_pElement0->SemanticName); }
  243. const uint32_t *GetSemanticIndexes() const { return !m_pElement0 ? nullptr: m_SemanticIndexTable.Get(m_pElement0->SemanticIndexes); }
  244. uint32_t GetRows() const { return !m_pElement0 ? 0 : ((uint32_t)m_pElement0->Rows); }
  245. uint32_t GetCols() const { return !m_pElement0 ? 0 : ((uint32_t)m_pElement0->ColsAndStart & 0xF); }
  246. bool IsAllocated() const { return !m_pElement0 ? false : !!(m_pElement0->ColsAndStart & 0x40); }
  247. int32_t GetStartRow() const { return !m_pElement0 ? 0 : !IsAllocated() ? -1 : (int32_t)m_pElement0->StartRow; }
  248. int32_t GetStartCol() const { return !m_pElement0 ? 0 : !IsAllocated() ? -1 : (int32_t)((m_pElement0->ColsAndStart >> 4) & 0x3); }
  249. PSVSemanticKind GetSemanticKind() const { return !m_pElement0 ? (PSVSemanticKind)0 : (PSVSemanticKind)m_pElement0->SemanticKind; }
  250. uint32_t GetComponentType() const { return !m_pElement0 ? 0 : (uint32_t)m_pElement0->ComponentType; }
  251. uint32_t GetInterpolationMode() const { return !m_pElement0 ? 0 : (uint32_t)m_pElement0->InterpolationMode; }
  252. uint32_t GetOutputStream() const { return !m_pElement0 ? 0 : (uint32_t)(m_pElement0->DynamicMaskAndStream >> 4) & 0x3; }
  253. uint32_t GetDynamicIndexMask() const { return !m_pElement0 ? 0 : (uint32_t)m_pElement0->DynamicMaskAndStream & 0xF; }
  254. };
  255. struct PSVInitInfo
  256. {
  257. PSVInitInfo(uint32_t psvVersion)
  258. : PSVVersion(psvVersion),
  259. ResourceCount(0),
  260. ShaderStage(PSVShaderKind::Invalid),
  261. StringTable(),
  262. SemanticIndexTable(),
  263. UsesViewID(0),
  264. SigInputElements(0),
  265. SigOutputElements(0),
  266. SigPatchConstantElements(0),
  267. SigInputVectors(0),
  268. SigPatchConstantVectors(0)
  269. {}
  270. uint32_t PSVVersion;
  271. uint32_t ResourceCount;
  272. PSVShaderKind ShaderStage;
  273. PSVStringTable StringTable;
  274. PSVSemanticIndexTable SemanticIndexTable;
  275. uint8_t UsesViewID;
  276. uint8_t SigInputElements;
  277. uint8_t SigOutputElements;
  278. uint8_t SigPatchConstantElements;
  279. uint8_t SigInputVectors;
  280. uint8_t SigPatchConstantVectors;
  281. uint8_t SigOutputVectors[4] = {0, 0, 0, 0};
  282. };
  283. class DxilPipelineStateValidation
  284. {
  285. uint32_t m_uPSVRuntimeInfoSize;
  286. PSVRuntimeInfo0* m_pPSVRuntimeInfo0;
  287. PSVRuntimeInfo1* m_pPSVRuntimeInfo1;
  288. uint32_t m_uResourceCount;
  289. uint32_t m_uPSVResourceBindInfoSize;
  290. void* m_pPSVResourceBindInfo;
  291. PSVStringTable m_StringTable;
  292. PSVSemanticIndexTable m_SemanticIndexTable;
  293. uint32_t m_uPSVSignatureElementSize;
  294. void* m_pSigInputElements;
  295. void* m_pSigOutputElements;
  296. void* m_pSigPatchConstantElements;
  297. uint32_t* m_pViewIDOutputMask;
  298. uint32_t* m_pViewIDPCOutputMask;
  299. uint32_t* m_pInputToOutputTable;
  300. uint32_t* m_pInputToPCOutputTable;
  301. uint32_t* m_pPCInputToOutputTable;
  302. public:
  303. DxilPipelineStateValidation() :
  304. m_uPSVRuntimeInfoSize(0),
  305. m_pPSVRuntimeInfo0(nullptr),
  306. m_pPSVRuntimeInfo1(nullptr),
  307. m_uResourceCount(0),
  308. m_uPSVResourceBindInfoSize(0),
  309. m_pPSVResourceBindInfo(nullptr),
  310. m_StringTable(),
  311. m_SemanticIndexTable(),
  312. m_uPSVSignatureElementSize(0),
  313. m_pSigInputElements(nullptr),
  314. m_pSigOutputElements(nullptr),
  315. m_pSigPatchConstantElements(nullptr),
  316. m_pViewIDOutputMask(nullptr),
  317. m_pViewIDPCOutputMask(nullptr),
  318. m_pInputToOutputTable(nullptr),
  319. m_pInputToPCOutputTable(nullptr),
  320. m_pPCInputToOutputTable(nullptr)
  321. {
  322. }
  323. // Init() from PSV0 blob part that looks like:
  324. // uint32_t PSVRuntimeInfo_size
  325. // { PSVRuntimeInfoN structure }
  326. // uint32_t ResourceCount
  327. // If ResourceCount > 0:
  328. // uint32_t PSVResourceBindInfo_size
  329. // { PSVResourceBindInfoN structure } * ResourceCount
  330. // If PSVRuntimeInfo1:
  331. // uint32_t StringTableSize (dword aligned)
  332. // If StringTableSize:
  333. // { StringTableSize bytes, null terminated }
  334. // uint32_t SemanticIndexTableEntries (number of dwords)
  335. // If SemanticIndexTableEntries:
  336. // { semantic index } * SemanticIndexTableEntries
  337. // If SigInputElements || SigOutputElements || SigPatchConstantElements:
  338. // uint32_t PSVSignatureElement_size
  339. // { PSVSignatureElementN structure } * SigInputElements
  340. // { PSVSignatureElementN structure } * SigOutputElements
  341. // { PSVSignatureElementN structure } * SigPatchConstantElements
  342. // If (UsesViewID):
  343. // For (i : each stream index 0-3):
  344. // If (SigOutputVectors[i] non-zero):
  345. // { uint32_t * PSVComputeMaskDwordsFromVectors(SigOutputVectors[i]) }
  346. // - Outputs affected by ViewID as a bitmask
  347. // If (HS and SigPatchConstantVectors non-zero):
  348. // { uint32_t * PSVComputeMaskDwordsFromVectors(SigPatchConstantVectors) }
  349. // - PCOutputs affected by ViewID as a bitmask
  350. // For (i : each stream index 0-3):
  351. // If (SigInputVectors and SigOutputVectors[i] non-zero):
  352. // { PSVComputeInputOutputTableSize(SigInputVectors, SigOutputVectors[i]) }
  353. // - Outputs affected by inputs as a table of bitmasks
  354. // If (HS and SigPatchConstantVectors and SigInputVectors non-zero):
  355. // { PSVComputeInputOutputTableSize(SigInputVectors, SigPatchConstantVectors) }
  356. // - Patch constant outputs affected by inputs as a table of bitmasks
  357. // If (DS and SigOutputVectors[0] and SigPatchConstantVectors non-zero):
  358. // { PSVComputeInputOutputTableSize(SigPatchConstantVectors, SigOutputVectors[0]) }
  359. // - Outputs affected by patch constant inputs as a table of bitmasks
  360. // returns true if no errors occurred.
  361. bool InitFromPSV0(const void* pBits, uint32_t size) {
  362. if(!(pBits != nullptr)) return false;
  363. const uint8_t* pCurBits = (uint8_t*)pBits;
  364. uint32_t minsize = sizeof(PSVRuntimeInfo0) + sizeof(uint32_t) * 2;
  365. if(!(size >= minsize)) return false;
  366. m_uPSVRuntimeInfoSize = *((const uint32_t*)pCurBits);
  367. if(!(m_uPSVRuntimeInfoSize >= sizeof(PSVRuntimeInfo0))) return false;
  368. pCurBits += sizeof(uint32_t);
  369. minsize = m_uPSVRuntimeInfoSize + sizeof(uint32_t) * 2;
  370. if(!(size >= minsize)) return false;
  371. m_pPSVRuntimeInfo0 = const_cast<PSVRuntimeInfo0*>((const PSVRuntimeInfo0*)pCurBits);
  372. if(m_uPSVRuntimeInfoSize >= sizeof(PSVRuntimeInfo1))
  373. m_pPSVRuntimeInfo1 = const_cast<PSVRuntimeInfo1*>((const PSVRuntimeInfo1*)pCurBits);
  374. pCurBits += m_uPSVRuntimeInfoSize;
  375. m_uResourceCount = *(const uint32_t*)pCurBits;
  376. pCurBits += sizeof(uint32_t);
  377. if (m_uResourceCount > 0) {
  378. minsize += sizeof(uint32_t);
  379. if(!(size >= minsize)) return false;
  380. m_uPSVResourceBindInfoSize = *(const uint32_t*)pCurBits;
  381. pCurBits += sizeof(uint32_t);
  382. minsize += m_uPSVResourceBindInfoSize * m_uResourceCount;
  383. if(!(m_uPSVResourceBindInfoSize >= sizeof(PSVResourceBindInfo0))) return false;
  384. if(!(size >= minsize)) return false;
  385. m_pPSVResourceBindInfo = static_cast<void*>(const_cast<uint8_t*>(pCurBits));
  386. }
  387. pCurBits += m_uPSVResourceBindInfoSize * m_uResourceCount;
  388. if (m_pPSVRuntimeInfo1) {
  389. minsize += sizeof(uint32_t) * 2; // m_StringTable.Size and m_SemanticIndexTable.Entries
  390. if (!(size >= minsize)) return false;
  391. m_StringTable.Size = PSVALIGN4(*(uint32_t*)pCurBits);
  392. if (m_StringTable.Size != *(uint32_t*)pCurBits)
  393. return false; // Illegal: Size not aligned
  394. minsize += m_StringTable.Size;
  395. if (!(size >= minsize)) return false;
  396. pCurBits += sizeof(uint32_t);
  397. m_StringTable.Table = (const char *)pCurBits;
  398. pCurBits += m_StringTable.Size;
  399. m_SemanticIndexTable.Entries = *(uint32_t*)pCurBits;
  400. minsize += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
  401. if (!(size >= minsize)) return false;
  402. pCurBits += sizeof(uint32_t);
  403. m_SemanticIndexTable.Table = (uint32_t*)pCurBits;
  404. pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
  405. // Dxil Signature Elements
  406. if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstantElements) {
  407. minsize += sizeof(uint32_t);
  408. if (!(size >= minsize)) return false;
  409. m_uPSVSignatureElementSize = *(uint32_t*)pCurBits;
  410. if (m_uPSVSignatureElementSize < sizeof(PSVSignatureElement0))
  411. return false; // Illegal: Size smaller than first version
  412. pCurBits += sizeof(uint32_t);
  413. minsize += m_uPSVSignatureElementSize * (m_pPSVRuntimeInfo1->SigInputElements + m_pPSVRuntimeInfo1->SigOutputElements + m_pPSVRuntimeInfo1->SigPatchConstantElements);
  414. if (!(size >= minsize)) return false;
  415. }
  416. if (m_pPSVRuntimeInfo1->SigInputElements) {
  417. m_pSigInputElements = (PSVSignatureElement0*)pCurBits;
  418. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigInputElements;
  419. }
  420. if (m_pPSVRuntimeInfo1->SigOutputElements) {
  421. m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
  422. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
  423. }
  424. if (m_pPSVRuntimeInfo1->SigPatchConstantElements) {
  425. m_pSigPatchConstantElements = (PSVSignatureElement0*)pCurBits;
  426. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstantElements;
  427. }
  428. // ViewID dependencies
  429. if (m_pPSVRuntimeInfo1->UsesViewID) {
  430. for (unsigned i = 0; i < 4; i++) {
  431. if (m_pPSVRuntimeInfo1->SigOutputVectors[i]) {
  432. minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  433. if (!(size >= minsize)) return false;
  434. m_pViewIDOutputMask = (uint32_t*)pCurBits;
  435. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  436. }
  437. if (!IsGS())
  438. break;
  439. }
  440. if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors) {
  441. minsize += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  442. if (!(size >= minsize)) return false;
  443. m_pViewIDPCOutputMask = (uint32_t*)pCurBits;
  444. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  445. }
  446. }
  447. // Input to Output dependencies
  448. for (unsigned i = 0; i < 4; i++) {
  449. if (m_pPSVRuntimeInfo1->SigOutputVectors[i] > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  450. minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  451. if (!(size >= minsize)) return false;
  452. m_pInputToOutputTable = (uint32_t*)pCurBits;
  453. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  454. }
  455. if (!IsGS())
  456. break;
  457. }
  458. if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  459. minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  460. if (!(size >= minsize)) return false;
  461. m_pInputToPCOutputTable = (uint32_t*)pCurBits;
  462. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  463. }
  464. if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0) {
  465. minsize += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  466. if (!(size >= minsize)) return false;
  467. m_pPCInputToOutputTable = (uint32_t*)pCurBits;
  468. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  469. }
  470. }
  471. return true;
  472. }
  473. // Initialize a new buffer
  474. // call with null pBuffer to get required size
  475. bool InitNew(uint32_t ResourceCount, void *pBuffer, uint32_t *pSize) {
  476. PSVInitInfo initInfo(0);
  477. initInfo.ResourceCount = ResourceCount;
  478. return InitNew(initInfo, pBuffer, pSize);
  479. }
  480. bool InitNew(const PSVInitInfo &initInfo, void *pBuffer, uint32_t *pSize) {
  481. if(!(pSize)) return false;
  482. if (initInfo.PSVVersion > 1) return false;
  483. // Versioned structure sizes
  484. m_uPSVRuntimeInfoSize = sizeof(PSVRuntimeInfo0);
  485. m_uPSVResourceBindInfoSize = sizeof(PSVResourceBindInfo0);
  486. m_uPSVSignatureElementSize = sizeof(PSVSignatureElement0);
  487. if (initInfo.PSVVersion > 0) {
  488. m_uPSVRuntimeInfoSize = sizeof(PSVRuntimeInfo1);
  489. }
  490. // PSVVersion 0
  491. uint32_t size = m_uPSVRuntimeInfoSize + sizeof(uint32_t) * 2;
  492. if (initInfo.ResourceCount) {
  493. size += sizeof(uint32_t) + (m_uPSVResourceBindInfoSize * initInfo.ResourceCount);
  494. }
  495. // PSVVersion 1
  496. if (initInfo.PSVVersion > 0) {
  497. size += sizeof(uint32_t) + PSVALIGN4(initInfo.StringTable.Size);
  498. size += sizeof(uint32_t) + sizeof(uint32_t) * initInfo.SemanticIndexTable.Entries;
  499. if (initInfo.SigInputElements || initInfo.SigOutputElements || initInfo.SigPatchConstantElements) {
  500. size += sizeof(uint32_t); // PSVSignatureElement_size
  501. }
  502. size += m_uPSVSignatureElementSize * initInfo.SigInputElements;
  503. size += m_uPSVSignatureElementSize * initInfo.SigOutputElements;
  504. size += m_uPSVSignatureElementSize * initInfo.SigPatchConstantElements;
  505. if (initInfo.UsesViewID) {
  506. for (unsigned i = 0; i < 4; i++) {
  507. size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigOutputVectors[i]);
  508. if (initInfo.ShaderStage != PSVShaderKind::Geometry)
  509. break;
  510. }
  511. if (initInfo.ShaderStage == PSVShaderKind::Hull)
  512. size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigPatchConstantVectors);
  513. }
  514. for (unsigned i = 0; i < 4; i++) {
  515. if (initInfo.SigOutputVectors[i] > 0 && initInfo.SigInputVectors > 0) {
  516. size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigOutputVectors[i]);
  517. if (initInfo.ShaderStage != PSVShaderKind::Geometry)
  518. break;
  519. }
  520. }
  521. if (initInfo.ShaderStage == PSVShaderKind::Hull && initInfo.SigPatchConstantVectors > 0 && initInfo.SigInputVectors > 0) {
  522. size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigPatchConstantVectors);
  523. }
  524. if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstantVectors > 0) {
  525. size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstantVectors, initInfo.SigOutputVectors[0]);
  526. }
  527. }
  528. // Validate or return required size
  529. if (pBuffer) {
  530. if(!(*pSize >= size)) return false;
  531. } else {
  532. *pSize = size;
  533. return true;
  534. }
  535. // PSVVersion 0
  536. memset(pBuffer, 0, size);
  537. uint8_t* pCurBits = (uint8_t*)pBuffer;
  538. *(uint32_t*)pCurBits = m_uPSVRuntimeInfoSize;
  539. pCurBits += sizeof(uint32_t);
  540. m_pPSVRuntimeInfo0 = (PSVRuntimeInfo0*)pCurBits;
  541. if (initInfo.PSVVersion > 0) {
  542. m_pPSVRuntimeInfo1 = (PSVRuntimeInfo1*)pCurBits;
  543. }
  544. pCurBits += m_uPSVRuntimeInfoSize;
  545. // Set resource info:
  546. m_uResourceCount = initInfo.ResourceCount;
  547. *(uint32_t*)pCurBits = m_uResourceCount;
  548. pCurBits += sizeof(uint32_t);
  549. if (m_uResourceCount > 0) {
  550. *(uint32_t*)pCurBits = m_uPSVResourceBindInfoSize;
  551. pCurBits += sizeof(uint32_t);
  552. m_pPSVResourceBindInfo = pCurBits;
  553. }
  554. pCurBits += m_uPSVResourceBindInfoSize * m_uResourceCount;
  555. // PSVVersion 1
  556. if (initInfo.PSVVersion) {
  557. m_pPSVRuntimeInfo1->ShaderStage = (uint8_t)initInfo.ShaderStage;
  558. m_pPSVRuntimeInfo1->UsesViewID = initInfo.UsesViewID;
  559. m_pPSVRuntimeInfo1->SigInputElements = initInfo.SigInputElements;
  560. m_pPSVRuntimeInfo1->SigOutputElements = initInfo.SigOutputElements;
  561. m_pPSVRuntimeInfo1->SigPatchConstantElements = initInfo.SigPatchConstantElements;
  562. m_pPSVRuntimeInfo1->SigInputVectors = initInfo.SigInputVectors;
  563. memcpy(m_pPSVRuntimeInfo1->SigOutputVectors, initInfo.SigOutputVectors, 4);
  564. if (IsHS() || IsDS()) {
  565. m_pPSVRuntimeInfo1->SigPatchConstantVectors = initInfo.SigPatchConstantVectors;
  566. }
  567. // Note: if original size was unaligned, padding has already been zero initialized
  568. m_StringTable.Size = PSVALIGN4(initInfo.StringTable.Size);
  569. *(uint32_t*)pCurBits = m_StringTable.Size;
  570. pCurBits += sizeof(uint32_t);
  571. memcpy(pCurBits, initInfo.StringTable.Table, initInfo.StringTable.Size);
  572. m_StringTable.Table = (const char *)pCurBits;
  573. pCurBits += m_StringTable.Size;
  574. m_SemanticIndexTable.Entries = initInfo.SemanticIndexTable.Entries;
  575. *(uint32_t*)pCurBits = m_SemanticIndexTable.Entries;
  576. pCurBits += sizeof(uint32_t);
  577. memcpy(pCurBits, initInfo.SemanticIndexTable.Table, sizeof(uint32_t) * initInfo.SemanticIndexTable.Entries);
  578. m_SemanticIndexTable.Table = (const uint32_t*)pCurBits;
  579. pCurBits += sizeof(uint32_t) * m_SemanticIndexTable.Entries;
  580. // Dxil Signature Elements
  581. if (m_pPSVRuntimeInfo1->SigInputElements || m_pPSVRuntimeInfo1->SigOutputElements || m_pPSVRuntimeInfo1->SigPatchConstantElements) {
  582. *(uint32_t*)pCurBits = m_uPSVSignatureElementSize;
  583. pCurBits += sizeof(uint32_t);
  584. }
  585. if (m_pPSVRuntimeInfo1->SigInputElements) {
  586. m_pSigInputElements = (PSVSignatureElement0*)pCurBits;
  587. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigInputElements;
  588. }
  589. if (m_pPSVRuntimeInfo1->SigOutputElements) {
  590. m_pSigOutputElements = (PSVSignatureElement0*)pCurBits;
  591. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigOutputElements;
  592. }
  593. if (m_pPSVRuntimeInfo1->SigPatchConstantElements) {
  594. m_pSigPatchConstantElements = (PSVSignatureElement0*)pCurBits;
  595. pCurBits += m_uPSVSignatureElementSize * m_pPSVRuntimeInfo1->SigPatchConstantElements;
  596. }
  597. // ViewID dependencies
  598. if (m_pPSVRuntimeInfo1->UsesViewID) {
  599. for (unsigned i = 0; i < 4; i++) {
  600. if (m_pPSVRuntimeInfo1->SigOutputVectors[i]) {
  601. m_pViewIDOutputMask = (uint32_t*)pCurBits;
  602. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  603. }
  604. if (!IsGS())
  605. break;
  606. }
  607. if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors) {
  608. m_pViewIDPCOutputMask = (uint32_t*)pCurBits;
  609. pCurBits += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  610. }
  611. }
  612. // Input to Output dependencies
  613. for (unsigned i = 0; i < 4; i++) {
  614. if (m_pPSVRuntimeInfo1->SigOutputVectors[i] > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  615. m_pInputToOutputTable = (uint32_t*)pCurBits;
  616. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
  617. }
  618. if (!IsGS())
  619. break;
  620. }
  621. if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
  622. m_pInputToPCOutputTable = (uint32_t*)pCurBits;
  623. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  624. }
  625. if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstantVectors > 0) {
  626. m_pPCInputToOutputTable = (uint32_t*)pCurBits;
  627. pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  628. }
  629. }
  630. return true;
  631. }
  632. PSVRuntimeInfo0* GetPSVRuntimeInfo0() const {
  633. return m_pPSVRuntimeInfo0;
  634. }
  635. PSVRuntimeInfo1* GetPSVRuntimeInfo1() const {
  636. return m_pPSVRuntimeInfo1;
  637. }
  638. uint32_t GetBindCount() const {
  639. return m_uResourceCount;
  640. }
  641. PSVResourceBindInfo0* GetPSVResourceBindInfo0(uint32_t index) const {
  642. if (index < m_uResourceCount && m_pPSVResourceBindInfo &&
  643. sizeof(PSVResourceBindInfo0) <= m_uPSVResourceBindInfoSize) {
  644. return (PSVResourceBindInfo0*)((uint8_t*)m_pPSVResourceBindInfo +
  645. (index * m_uPSVResourceBindInfoSize));
  646. }
  647. return nullptr;
  648. }
  649. const PSVStringTable &GetStringTable() const { return m_StringTable; }
  650. const PSVSemanticIndexTable &GetSemanticIndexTable() const { return m_SemanticIndexTable; }
  651. // Signature element access
  652. uint32_t GetSigInputElements() const {
  653. if (m_pPSVRuntimeInfo1)
  654. return m_pPSVRuntimeInfo1->SigInputElements;
  655. return 0;
  656. }
  657. uint32_t GetSigOutputElements() const {
  658. if (m_pPSVRuntimeInfo1)
  659. return m_pPSVRuntimeInfo1->SigOutputElements;
  660. return 0;
  661. }
  662. uint32_t GetSigPatchConstantElements() const {
  663. if (m_pPSVRuntimeInfo1)
  664. return m_pPSVRuntimeInfo1->SigPatchConstantElements;
  665. return 0;
  666. }
  667. PSVSignatureElement0* GetInputElement0(uint32_t index) const {
  668. if (m_pPSVRuntimeInfo1 && m_pSigInputElements &&
  669. index < m_pPSVRuntimeInfo1->SigInputElements &&
  670. sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
  671. return (PSVSignatureElement0*)((uint8_t*)m_pSigInputElements +
  672. (index * m_uPSVSignatureElementSize));
  673. }
  674. return nullptr;
  675. }
  676. PSVSignatureElement0* GetOutputElement0(uint32_t index) const {
  677. if (m_pPSVRuntimeInfo1 && m_pSigOutputElements &&
  678. index < m_pPSVRuntimeInfo1->SigOutputElements &&
  679. sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
  680. return (PSVSignatureElement0*)((uint8_t*)m_pSigOutputElements +
  681. (index * m_uPSVSignatureElementSize));
  682. }
  683. return nullptr;
  684. }
  685. PSVSignatureElement0* GetPatchConstantElement0(uint32_t index) const {
  686. if (m_pPSVRuntimeInfo1 && m_pSigPatchConstantElements &&
  687. index < m_pPSVRuntimeInfo1->SigPatchConstantElements &&
  688. sizeof(PSVSignatureElement0) <= m_uPSVSignatureElementSize) {
  689. return (PSVSignatureElement0*)((uint8_t*)m_pSigPatchConstantElements +
  690. (index * m_uPSVSignatureElementSize));
  691. }
  692. return nullptr;
  693. }
  694. // More convenient wrapper:
  695. PSVSignatureElement GetSignatureElement(PSVSignatureElement0* pElement0) const {
  696. return PSVSignatureElement(m_StringTable, m_SemanticIndexTable, pElement0);
  697. }
  698. PSVShaderKind GetShaderKind() const {
  699. if (m_pPSVRuntimeInfo1 && m_pPSVRuntimeInfo1->ShaderStage < (uint8_t)PSVShaderKind::Invalid)
  700. return (PSVShaderKind)m_pPSVRuntimeInfo1->ShaderStage;
  701. return PSVShaderKind::Invalid;
  702. }
  703. bool IsVS() const { return GetShaderKind() == PSVShaderKind::Vertex; }
  704. bool IsHS() const { return GetShaderKind() == PSVShaderKind::Hull; }
  705. bool IsDS() const { return GetShaderKind() == PSVShaderKind::Domain; }
  706. bool IsGS() const { return GetShaderKind() == PSVShaderKind::Geometry; }
  707. bool IsPS() const { return GetShaderKind() == PSVShaderKind::Pixel; }
  708. bool IsCS() const { return GetShaderKind() == PSVShaderKind::Compute; }
  709. // ViewID dependencies
  710. PSVComponentMask GetViewIDOutputMask(unsigned streamIndex = 0) const {
  711. if (!m_pViewIDOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex])
  712. return PSVComponentMask();
  713. return PSVComponentMask(m_pViewIDOutputMask, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
  714. }
  715. PSVComponentMask GetViewIDPCOutputMask() const {
  716. if (!IsHS() || !m_pViewIDPCOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigPatchConstantVectors)
  717. return PSVComponentMask();
  718. return PSVComponentMask(m_pViewIDPCOutputMask, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  719. }
  720. // Input to Output dependencies
  721. PSVDependencyTable GetInputToOutputTable(unsigned streamIndex = 0) const {
  722. if (m_pInputToOutputTable && m_pPSVRuntimeInfo1) {
  723. return PSVDependencyTable(m_pInputToOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
  724. }
  725. return PSVDependencyTable();
  726. }
  727. PSVDependencyTable GetInputToPCOutputTable() const {
  728. if (IsHS() && m_pInputToPCOutputTable && m_pPSVRuntimeInfo1) {
  729. return PSVDependencyTable(m_pInputToPCOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstantVectors);
  730. }
  731. return PSVDependencyTable();
  732. }
  733. PSVDependencyTable GetPCInputToOutputTable() const {
  734. if (IsDS() && m_pPCInputToOutputTable && m_pPSVRuntimeInfo1) {
  735. return PSVDependencyTable(m_pPCInputToOutputTable, m_pPSVRuntimeInfo1->SigPatchConstantVectors, m_pPSVRuntimeInfo1->SigOutputVectors[0]);
  736. }
  737. return PSVDependencyTable();
  738. }
  739. };
  740. namespace hlsl {
  741. class ViewIDValidator {
  742. public:
  743. enum class Result {
  744. Success = 0,
  745. SuccessWithViewIDDependentTessFactor,
  746. InsufficientInputSpace,
  747. InsufficientOutputSpace,
  748. InsufficientPCSpace,
  749. MismatchedSignatures,
  750. MismatchedPCSignatures,
  751. InvalidUsage,
  752. InvalidPSVVersion,
  753. InvalidPSV,
  754. };
  755. virtual ~ViewIDValidator() {}
  756. virtual Result ValidateStage(const DxilPipelineStateValidation &PSV,
  757. bool bFinalStage,
  758. bool bExpandInputOnly,
  759. unsigned &mismatchElementId) = 0;
  760. };
  761. ViewIDValidator* NewViewIDValidator(unsigned viewIDCount, unsigned gsRastStreamIndex);
  762. }
  763. #endif // __DXIL_PIPELINE_STATE_VALIDATION__H__