Переглянути джерело

PSV0: Fix GS multi stream out ViewID dependency serialization (#5204)

GetViewIDOutputMask and GetInputToOutputTable take an output index to support multiple stream outputs in geometry shaders.
Previously, these functions would map the same local pointer to a location for serializing or deserializing the dependency bit vectors.
This causes data to overlap and size to be computed based only on the last used stream.

This change fixes this by using an array of pointers for each of these.
Then, a new test verifies output for two streams, which fails without this fix.
Tex Riddell 2 роки тому
батько
коміт
ccd97acc02

+ 15 - 14
include/dxc/DxilContainer/DxilPipelineStateValidation.h

@@ -31,6 +31,7 @@ inline uint32_t PSVComputeInputOutputTableDwords(uint32_t InputVectors, uint32_t
 }
 #define PSVALIGN(ptr, alignbits) (((ptr) + ((1 << (alignbits))-1)) & ~((1 << (alignbits))-1))
 #define PSVALIGN4(ptr) (((ptr) + 3) & ~3)
+#define PSV_GS_MAX_STREAMS 4
 
 #ifndef NDEBUG
 #define PSV_RETB(exp) do { if(!(exp)) { assert(false && #exp); return false; } } while(0)
@@ -55,7 +56,7 @@ struct DSInfo {
 struct GSInfo {
   uint32_t InputPrimitive;              // hlsl::DXIL::InputPrimitive/D3D10_SB_PRIMITIVE
   uint32_t OutputTopology;              // hlsl::DXIL::PrimitiveTopology/D3D10_SB_PRIMITIVE_TOPOLOGY
-  uint32_t OutputStreamMask;            // max streams == 4
+  uint32_t OutputStreamMask;            // max streams == 4 (PSV_GS_MAX_STREAMS)
   char OutputPositionPresent;
 };
 struct PSInfo {
@@ -130,7 +131,7 @@ struct PSVRuntimeInfo1 : public PSVRuntimeInfo0
 
   // Number of packed vectors per signature
   uint8_t SigInputVectors;
-  uint8_t SigOutputVectors[4];      // Array for GS Stream Out Index
+  uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS];      // Array for GS Stream Out Index
 };
 
 struct PSVRuntimeInfo2 : public PSVRuntimeInfo1
@@ -382,7 +383,7 @@ struct PSVInitInfo
   uint8_t SigPatchConstOrPrimElements = 0;
   uint8_t SigInputVectors = 0;
   uint8_t SigPatchConstOrPrimVectors = 0;
-  uint8_t SigOutputVectors[4] = {0, 0, 0, 0};
+  uint8_t SigOutputVectors[PSV_GS_MAX_STREAMS] = {0, 0, 0, 0};
 
   static_assert(MAX_PSV_VERSION == 2, "otherwise this needs updating.");
   uint32_t RuntimeInfoSize() const {
@@ -418,9 +419,9 @@ class DxilPipelineStateValidation
   void *m_pSigInputElements = nullptr;
   void *m_pSigOutputElements = nullptr;
   void *m_pSigPatchConstOrPrimElements = nullptr;
-  uint32_t *m_pViewIDOutputMask = nullptr;
+  uint32_t *m_pViewIDOutputMask[PSV_GS_MAX_STREAMS] = {nullptr, nullptr, nullptr, nullptr};
   uint32_t *m_pViewIDPCOrPrimOutputMask = nullptr;
-  uint32_t *m_pInputToOutputTable = nullptr;
+  uint32_t *m_pInputToOutputTable[PSV_GS_MAX_STREAMS] = {nullptr, nullptr, nullptr, nullptr};
   uint32_t *m_pInputToPCOutputTable = nullptr;
   uint32_t *m_pPCInputToOutputTable = nullptr;
 
@@ -599,9 +600,9 @@ public:
 
   // ViewID dependencies
   PSVComponentMask GetViewIDOutputMask(unsigned streamIndex = 0) const {
-    if (!m_pViewIDOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex])
+    if (streamIndex >= PSV_GS_MAX_STREAMS || !m_pViewIDOutputMask[streamIndex] || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex])
       return PSVComponentMask();
-    return PSVComponentMask(m_pViewIDOutputMask, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
+    return PSVComponentMask(m_pViewIDOutputMask[streamIndex], m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
   }
   PSVComponentMask GetViewIDPCOutputMask() const {
     if ((!IsHS() && !IsMS()) || !m_pViewIDPCOrPrimOutputMask || !m_pPSVRuntimeInfo1 || !m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors)
@@ -611,8 +612,8 @@ public:
 
   // Input to Output dependencies
   PSVDependencyTable GetInputToOutputTable(unsigned streamIndex = 0) const {
-    if (m_pInputToOutputTable && m_pPSVRuntimeInfo1) {
-      return PSVDependencyTable(m_pInputToOutputTable, m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
+    if (streamIndex < PSV_GS_MAX_STREAMS && m_pInputToOutputTable[streamIndex] && m_pPSVRuntimeInfo1) {
+      return PSVDependencyTable(m_pInputToOutputTable[streamIndex], m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[streamIndex]);
     }
     return PSVDependencyTable();
   }
@@ -802,7 +803,7 @@ inline bool DxilPipelineStateValidation::ReadOrWrite(
       m_pPSVRuntimeInfo1->SigOutputElements = initInfo.SigOutputElements;
       m_pPSVRuntimeInfo1->SigPatchConstOrPrimElements = initInfo.SigPatchConstOrPrimElements;
       m_pPSVRuntimeInfo1->UsesViewID = initInfo.UsesViewID;
-      for (unsigned i = 0; i < 4; i++) {
+      for (unsigned i = 0; i < PSV_GS_MAX_STREAMS; i++) {
         m_pPSVRuntimeInfo1->SigOutputVectors[i] = initInfo.SigOutputVectors[i];
       }
       if (IsHS() || IsDS() || IsMS()) {
@@ -845,9 +846,9 @@ inline bool DxilPipelineStateValidation::ReadOrWrite(
 
     // ViewID dependencies
     if (m_pPSVRuntimeInfo1->UsesViewID) {
-      for (unsigned i = 0; i < 4; i++) {
+      for (unsigned i = 0; i < PSV_GS_MAX_STREAMS; i++) {
         if (m_pPSVRuntimeInfo1->SigOutputVectors[i]) {
-          PSV_RETB(rw.MapArray(&m_pViewIDOutputMask,
+          PSV_RETB(rw.MapArray(&m_pViewIDOutputMask[i],
             PSVComputeMaskDwordsFromVectors(m_pPSVRuntimeInfo1->SigOutputVectors[i])));
         }
         if (!IsGS())
@@ -860,9 +861,9 @@ inline bool DxilPipelineStateValidation::ReadOrWrite(
     }
 
     // Input to Output dependencies
-    for (unsigned i = 0; i < 4; i++) {
+    for (unsigned i = 0; i < PSV_GS_MAX_STREAMS; i++) {
       if (!IsMS() && m_pPSVRuntimeInfo1->SigOutputVectors[i] > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
-        PSV_RETB(rw.MapArray(&m_pInputToOutputTable,
+        PSV_RETB(rw.MapArray(&m_pInputToOutputTable[i],
           PSVComputeInputOutputTableDwords(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i])));
       }
       if (!IsGS())

+ 111 - 0
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -97,6 +97,7 @@ public:
 
   TEST_METHOD(CompileWhenDebugSourceThenSourceMatters)
   TEST_METHOD(CompileAS_CheckPSV0)
+  TEST_METHOD(CompileGS_CheckPSV0_ViewID)
   TEST_METHOD(CompileWhenOkThenCheckRDAT)
   TEST_METHOD(CompileWhenOkThenCheckRDAT2)
   TEST_METHOD(CompileWhenOkThenCheckReflection1)
@@ -940,6 +941,116 @@ TEST_F(DxilContainerTest, CompileAS_CheckPSV0) {
   VERIFY_IS_TRUE(blobFound);
 }
 
+TEST_F(DxilContainerTest, CompileGS_CheckPSV0_ViewID) {
+  if (m_ver.SkipDxilVersion(1, 7)) return;
+
+  // Verify that ViewID and Input to Output masks are correct for
+  // geometry shader with multiple stream outputs.
+  // This acts as a regression test for issue #5199, where a single pointer was
+  // used for the ViewID dependency mask and a single pointer for input to
+  // output dependencies, when each of these needed a separate pointer per
+  // stream.  The effect was an overlapping and clobbering of data for earlier
+  // streams.
+  // Skip validator versions < 1.7 since they lack the fix.
+
+  const unsigned ARRAY_SIZE = 8;
+  const char gsSource[] =
+    "#define ARRAY_SIZE 8\n"
+    "struct GSOut0 { float4 pos : SV_Position; };\n"
+    "struct GSOut1 { float4 arr[ARRAY_SIZE] : Array; };\n"
+    "[shader(\"geometry\")]\n"
+    "[maxvertexcount(1)]\n"
+    "void main(point float4 input[1] : COORD,\n"
+    "          inout PointStream<GSOut0> out0,\n"
+    "          inout PointStream<GSOut1> out1,\n"
+    "          uint vid : SV_ViewID) {\n"
+    " GSOut0 o0 = (GSOut0)0;\n"
+    " GSOut1 o1 = (GSOut1)0;\n"
+    " o0.pos = input[0];\n"
+    " out0.Append(o0);\n"
+    " out0.RestartStrip();\n"
+    " [unroll]\n"
+    " for (uint i = 0; i < ARRAY_SIZE; i++)\n"
+    "   o1.arr[i] = input[0][i%4] + vid;\n"
+    " out1.Append(o1);\n"
+    " out1.RestartStrip();\n"
+    "}";
+
+  CComPtr<IDxcCompiler> pCompiler;
+  CComPtr<IDxcBlobEncoding> pSource;
+  CComPtr<IDxcBlob> pProgram;
+  CComPtr<IDxcOperationResult> pResult;
+
+  // Compile the shader
+  VERIFY_SUCCEEDED(CreateCompiler(&pCompiler));
+  CreateBlobFromText(gsSource, &pSource);
+  VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"hlsl.hlsl", L"main",
+                                      L"gs_6_3", nullptr, 0, nullptr, 0,
+                                      nullptr, &pResult));
+  HRESULT hrStatus;
+  VERIFY_SUCCEEDED(pResult->GetStatus(&hrStatus));
+  VERIFY_SUCCEEDED(hrStatus);
+  VERIFY_SUCCEEDED(pResult->GetResult(&pProgram));
+
+  // Get PSV0 part
+  CComPtr<IDxcContainerReflection> containerReflection;
+  VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcContainerReflection, &containerReflection));
+  VERIFY_SUCCEEDED(containerReflection->Load(pProgram));
+  uint32_t partIdx = 0;
+  VERIFY_SUCCEEDED(containerReflection->FindFirstPartKind((uint32_t)hlsl::DxilFourCC::DFCC_PipelineStateValidation, &partIdx));
+  CComPtr<IDxcBlob> pBlob;
+  VERIFY_SUCCEEDED(containerReflection->GetPartContent(partIdx, &pBlob));
+
+  // Init PSV and verify ViewID masks and Input to Output Masks
+  DxilPipelineStateValidation PSV;
+  PSV.InitFromPSV0(pBlob->GetBufferPointer(), pBlob->GetBufferSize());
+  PSVShaderKind kind = PSV.GetShaderKind();
+  VERIFY_ARE_EQUAL(PSVShaderKind::Geometry, kind);
+  PSVRuntimeInfo0* pInfo = PSV.GetPSVRuntimeInfo0();
+  VERIFY_IS_NOT_NULL(pInfo);
+
+  // Stream 0 should have no direct ViewID dependency:
+  PSVComponentMask viewIDMask0 = PSV.GetViewIDOutputMask(0);
+  VERIFY_IS_TRUE(viewIDMask0.IsValid());
+  VERIFY_ARE_EQUAL(1U, viewIDMask0.NumVectors);
+  for (unsigned i = 0; i < 4; i++) {
+    VERIFY_IS_FALSE(viewIDMask0.Get(i));
+  }
+
+  // Everything in stream 1 should be dependent on ViewID:
+  PSVComponentMask viewIDMask1 = PSV.GetViewIDOutputMask(1);
+  VERIFY_IS_TRUE(viewIDMask1.IsValid());
+  VERIFY_ARE_EQUAL(ARRAY_SIZE, viewIDMask1.NumVectors);
+  for (unsigned i = 0; i < ARRAY_SIZE * 4; i++) {
+    VERIFY_IS_TRUE(viewIDMask1.Get(i));
+  }
+
+  // Stream 0 is simple assignment of input vector:
+  PSVDependencyTable ioTable0 = PSV.GetInputToOutputTable(0);
+  VERIFY_IS_TRUE(ioTable0.IsValid());
+  VERIFY_ARE_EQUAL(1U, ioTable0.OutputVectors);
+  for (unsigned i = 0; i < 4; i++) {
+    PSVComponentMask ioMask0_i = ioTable0.GetMaskForInput(i);
+    // input 0123 -> output 0123
+    for (unsigned j = 0; j < 4; j++) {
+      VERIFY_ARE_EQUAL(i == j, ioMask0_i.Get(j));
+    }
+  }
+
+  // Each vector in Stream 1 combines one component of input with ViewID:
+  PSVDependencyTable ioTable1 = PSV.GetInputToOutputTable(1);
+  VERIFY_IS_TRUE(ioTable1.IsValid());
+  VERIFY_ARE_EQUAL(ARRAY_SIZE, ioTable1.OutputVectors);
+  for (unsigned i = 0; i < 4; i++) {
+    PSVComponentMask ioMask1_i = ioTable1.GetMaskForInput(i);
+    for (unsigned j = 0; j < ARRAY_SIZE * 4; j++) {
+      // Vector output vector/component dependency by input component:
+      // 0000, 1111, 2222, 3333, 0000, 1111, 2222, 3333, ...
+      VERIFY_ARE_EQUAL(i == (j / 4) % 4, ioMask1_i.Get(j));
+    }
+  }
+}
+
 TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
   if (m_ver.SkipDxilVersion(1, 3)) return;
   const char *shader = "float c_buf;"