Browse Source

DerivativesTest: quad2D layout and refactor (#3506)

Refactor Derivatives test to be clearer, more correct, and more in
keeping with the original intent. Center index is now nearer the center.
Conversion of 2D index to 1D is now correct for the quad layout.
Dispatch sizes are separated by CS, Mesh, and undefined results to make
testing easier and more complete.

Also adds the 2D quad formats while adapting 1D variants to be correct
with the latest spec
Greg Roth 4 years ago
parent
commit
6482ac9413
2 changed files with 157 additions and 88 deletions
  1. 58 20
      tools/clang/test/HLSL/ShaderOpArith.xml
  2. 99 68
      tools/clang/unittests/HLSL/ExecutionTest.cpp

+ 58 - 20
tools/clang/test/HLSL/ShaderOpArith.xml

@@ -109,7 +109,7 @@
       {32.0f, 64.0f, 128.0f, 256.0f},
       {256.0f, 512.0f, 1024.0f, 2048.0f}
     </Resource>
-    <Resource Name="RTarget" Dimension="TEXTURE2D" Width="32" Height="32" Format="R32G32B32A32_FLOAT" Flags="ALLOW_RENDER_TARGET" InitialResourceState="COPY_DEST" ReadBack="true" />
+    <Resource Name="RTarget" Dimension="TEXTURE2D" Width="64" Height="64" Format="R32G32B32A32_FLOAT" Flags="ALLOW_RENDER_TARGET" InitialResourceState="COPY_DEST" ReadBack="true" />
     <Resource Name="U0" Dimension="BUFFER" Width="16384"
               Flags="ALLOW_UNORDERED_ACCESS" InitialResourceState="COPY_DEST"
               Init="Zero" ReadBack="true" />
@@ -203,17 +203,66 @@
           { 1.0f, 0.0f },
           { 1.0f, 1.0f }};
 
-        [NumThreads(MESHDISPATCHX, MESHDISPATCHY, MESHDISPATCHZ)]
-        void ASMain(uint ix : SV_GroupIndex) {
+        uint convert2Dto1D(uint x, uint y, uint width) {
+          // Convert 2D coords to 1D for testing
+          // All completed rows of quads
+          uint prevRows = (y/2)*2*width;
+          // All previous full quads on this quad row
+          uint prevQuads = (x/2)*4;
+          // index into current quad
+          uint quadIx = (y&1)*2 + (x&1);
+          return prevRows + prevQuads + quadIx;
+        }
+
+        float4 PSMain(PSInput input) : SV_TARGET {
+          // Convert from texcoords into a groupIndex equivalent
+          int width = 64;
+          int height = 64;
+          int2 uv = int2(input.uv.x*width, input.uv.y*height);
+
+          uint ix = convert2Dto1D(uv.x, uv.y, DISPATCHX);
+
+          float4 res = 0.0;
+          if (uv.x < DISPATCHX && uv.y < DISPATCHY) {
+            res = DerivTest(uv);
+            g_bufMain[ix] = res;
+          }
+          return res;
+        }
+
+        [NumThreads(DISPATCHX, DISPATCHY, DISPATCHZ)]
+        void CSMain(uint3 id : SV_GroupThreadID, uint ix : SV_GroupIndex) {
+          if (DISPATCHY == 1 && DISPATCHZ == 1)
+            g_bufMain[ix] = DerivTest(ix);
+          else
+            g_bufMain[convert2Dto1D(id.x, id.y, DISPATCHX)] = DerivTest(id.xy);
+        }
+
+#if DISPATCHX * DISPATCHY * DISPATCHZ > 128
+#undef DISPATCHX
+#undef DISPATCHY
+#undef DISPATCHZ
+
+#define DISPATCHX 1
+#define DISPATCHY 1
+#define DISPATCHZ 1
+#endif
+
+        [NumThreads(DISPATCHX, DISPATCHY, DISPATCHZ)]
+        void ASMain(uint3 id : SV_GroupThreadID, uint ix : SV_GroupIndex) {
           Payload payload;
-          g_bufAmp[ix] = DerivTest(ix);
+          if (DISPATCHY == 1 && DISPATCHZ == 1)
+            g_bufAmp[ix] = DerivTest(ix);
+          else
+            g_bufAmp[convert2Dto1D(id.x, id.y, DISPATCHX)] = DerivTest(id.xy);
           payload.nothing = 0;
           DispatchMesh(1, 1, 1, payload);
         }
 
-        [NumThreads(MESHDISPATCHX, MESHDISPATCHY, MESHDISPATCHZ)]
+        [NumThreads(DISPATCHX, DISPATCHY, DISPATCHZ)]
         [OutputTopology("triangle")]
         void MSMain(
+          uint3 id : SV_GroupThreadID,
           uint ix : SV_GroupIndex,
           in payload Payload payload,
           out vertices PSInput verts[6],
@@ -224,23 +273,12 @@
             verts[ix%6].uv = g_UV[ix%6];
             tris[ix&1] = uint3((ix&1)*3, (ix&1)*3 + 1, (ix&1)*3 + 2);
             g_bufMesh[ix] = DerivTest(ix);
-        }
-        float4 PSMain(PSInput input) : SV_TARGET {
-          // Convert from texcoords into a groupIndex equivalent
-          int width = DISPATCHX;
-          int height = DISPATCHY;
-          int2 uv = int2(input.uv.x*width, input.uv.y*height);
-          uint ix = ((uv.y/4)*(width/4))*16 + (uv.x/4)*16 + (((uv.x & 0x2) << 1) | (uv.x & 0x1) | ((uv.y & 0x2) << 2) | ((uv.y & 0x1) << 1));
-
-          float4 res = DerivTest(ix);
-          g_bufMain[ix] = res;
-          return res;
+            if (DISPATCHY == 1 && DISPATCHZ == 1)
+              g_bufMesh[ix] = DerivTest(ix);
+            else
+              g_bufMesh[convert2Dto1D(id.x, id.y, DISPATCHX)] = DerivTest(id.xy);
         }
 
-        [NumThreads(DISPATCHX, DISPATCHY, DISPATCHZ)]
-        void CSMain(uint ix : SV_GroupIndex) {
-          g_bufMain[ix] = DerivTest(ix);
-        }
       ]]>
     </Shader>
   </ShaderOp>

+ 99 - 68
tools/clang/unittests/HLSL/ExecutionTest.cpp

@@ -3059,7 +3059,40 @@ TEST_F(ExecutionTest, PartialDerivTest) {
   VerifyDerivResults(pPixels, offsetCenter);
 }
 
+struct Dispatch {
+  int width, height, depth;
+};
+
+std::shared_ptr<st::ShaderOpTest>
+RunDispatch(ID3D12Device *pDevice, dxc::DxcDllSupport &support,
+            st::ShaderOp *pShaderOp, const Dispatch D) {
+  char compilerOptions[256];
+
+  std::shared_ptr<st::ShaderOpTest> test = std::make_shared<st::ShaderOpTest>();
+  test->SetDxcSupport(&support);
+  test->SetInitCallback(nullptr);
+  test->SetDevice(pDevice);
+
+  // format compiler args
+  VERIFY_IS_TRUE(sprintf_s(compilerOptions, sizeof(compilerOptions),
+                           "-D DISPATCHX=%d -D DISPATCHY=%d -D DISPATCHZ=%d ",
+                           D.width, D.height, D.depth));
+
+  for (st::ShaderOpShader &S : pShaderOp->Shaders)
+    S.Arguments = compilerOptions;
+
+  pShaderOp->DispatchX = D.width;
+  pShaderOp->DispatchY = D.height;
+  pShaderOp->DispatchZ = D.depth;
+
+  test->RunShaderOp(pShaderOp);
+
+  return test;
+}
+
 TEST_F(ExecutionTest, DerivativesTest) {
+  const UINT pixelSize = 4; // always float4
+
   WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
   CComPtr<IStream> pStream;
   ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
@@ -3074,99 +3107,97 @@ TEST_F(ExecutionTest, DerivativesTest) {
 
   st::ShaderOp *pShaderOp = ShaderOpSet->GetShaderOp("Derivatives");
 
-  LPCSTR CS = pShaderOp->CS;
-
-  struct Dispatch {
-    int x, y, z;
-    int mx, my, mz;
-  };
   std::vector<Dispatch> dispatches =
   {
-   {32, 32, 1, 8, 8, 1},
-   {64, 4, 1, 64, 2, 1},
-   {1, 4, 64, 1, 4, 32},
-   {64, 1, 1, 64, 1, 1},
-   {1, 64, 1, 1, 64, 1},
-   {1, 1, 64, 1, 1, 64},
-   {16, 16, 3, 4, 4, 3},
-   {32, 3, 8, 8, 3, 2},
-   {3, 1, 64, 3, 1, 32}
+   {40, 1, 1},
+   {1000, 1, 1},
+   {32, 32, 1},
+   {16, 64, 1},
+   {4, 12, 4},
+   {4, 64, 1},
+   {16, 16, 3},
+   {32, 8, 2}
   };
 
-  char compilerOptions[256];
-  for (Dispatch &D : dispatches) {
-
-    UINT width = D.x;
-    UINT height = D.y;
-    UINT depth = D.z;
-
-    UINT mwidth = D.mx;
-    UINT mheight = D.my;
-    UINT mdepth = D.mz;
-    UINT pixelSize = 4; // always float4
+  std::vector<Dispatch> meshDispatches =
+  {
+   {60, 1, 1},
+   {128, 1, 1},
+   {8, 8, 1},
+   {32, 8, 1},
+   {8, 16, 4},
+   {8, 64, 1},
+   {8, 8, 3},
+  };
 
-    // format compiler args
-    VERIFY_IS_TRUE(sprintf_s(compilerOptions, sizeof(compilerOptions),
-                             "-D DISPATCHX=%d -D DISPATCHY=%d -D DISPATCHZ=%d "
-                             "-D MESHDISPATCHX=%d -D MESHDISPATCHY=%d -D MESHDISPATCHZ=%d",
-                             width, height, depth, mwidth, mheight, mdepth));
+  std::vector<Dispatch> badDispatches =
+  {
+   {16, 3, 1},
+   {2, 16, 1},
+   {33, 1, 1}
+  };
 
-    for (st::ShaderOpShader &S : pShaderOp->Shaders)
-      S.Arguments = compilerOptions;
+  pShaderOp->UseWarpDevice = GetTestParamUseWARP(true);
+  LPCSTR CS = pShaderOp->CS;
 
-    pShaderOp->DispatchX = width;
-    pShaderOp->DispatchY = height;
-    pShaderOp->DispatchZ = depth;
+  MappedData data;
 
+  for (Dispatch &D : dispatches) {
     // Test Compute Shader
-    pShaderOp->CS = CS;
-    std::shared_ptr<ShaderOpTestResult> test = RunShaderOpTestAfterParse(pDevice, m_support, "Derivatives", nullptr, ShaderOpSet);
-    MappedData data;
+    std::shared_ptr<st::ShaderOpTest> test = RunDispatch(pDevice, m_support, pShaderOp, D);
 
-    test->Test->GetReadBackData("U0", &data);
-    const float *pPixels = (float *)data.data();
+    test->GetReadBackData("U0", &data);
+
+    float *pPixels = (float *)data.data();;
 
-    // To find roughly the center for compute, divide the pixel count in half,
-    // truncate to next lowest power of 16 (4x4), which is the repeating period
-    // and then add 10 to reach the point the test expects
-    UINT centerIndex = (((UINT64)(width * height * depth)/2) & ~0xF) + 10;
+    UINT centerIndex = 0;
+    if (D.height == 1) {
+      centerIndex = (((UINT64)(D.width * D.height * D.depth) / 2) & ~0xF) + 10;
+    } else {
+      // To find roughly the center for compute, divide the height and width in half,
+      // truncate to the previous multiple of 4 to get to the start of the repeating pattern
+      // and then add 2 rows to get to the second row of quads and 2 to get to the first texel
+      // of the second row of that quad row
+      UINT centerRow = ((D.height/2UL) & ~0x3) + 2;
+      UINT centerCol = ((D.width/2UL) & ~0x3) + 2;
+      centerIndex = centerRow * D.width + centerCol;
+    }
     UINT offsetCenter = centerIndex * pixelSize;
     LogCommentFmt(L"Verifying derivatives in compute shader results");
     VerifyDerivResults(pPixels, offsetCenter);
+  }
 
-    if (DoesDeviceSupportMeshAmpDerivatives(pDevice)) {
-      // Disable CS so mesh goes forward
-      pShaderOp->CS = nullptr;
-      test = RunShaderOpTestAfterParse(pDevice, m_support, "Derivatives", nullptr, ShaderOpSet);
-      test->Test->GetReadBackData("U1", &data);
-      pPixels = (float *)data.data();
-      centerIndex = (((UINT64)(mwidth * mheight * mdepth)/2) & ~0xF) + 10;
-      offsetCenter = centerIndex * pixelSize;
+  if (DoesDeviceSupportMeshAmpDerivatives(pDevice)) {
+    // Disable CS so mesh goes forward
+    pShaderOp->CS = nullptr;
+
+    for (Dispatch &D : meshDispatches) {
+      std::shared_ptr<st::ShaderOpTest> test = RunDispatch(pDevice, m_support, pShaderOp, D);
+
+      test->GetReadBackData("U1", &data);
+      const float *pPixels = (float *)data.data();
+      UINT centerIndex = (((UINT64)(D.width * D.height * D.depth)/2) & ~0xF) + 10;
+      UINT offsetCenter = centerIndex * pixelSize;
       LogCommentFmt(L"Verifying derivatives in mesh shader results");
       VerifyDerivResults(pPixels, offsetCenter);
 
-      test->Test->GetReadBackData("U2", &data);
+      test->GetReadBackData("U2", &data);
       pPixels = (float *)data.data();
       LogCommentFmt(L"Verifying derivatives in amplification shader results");
       VerifyDerivResults(pPixels, offsetCenter);
     }
   }
 
-  // Final test with not divisible by 4 dispatch size just to make sure it runs
-  for (st::ShaderOpShader &S : pShaderOp->Shaders)
-    S.Arguments = "-D DISPATCHX=3 -D DISPATCHY=3 -D DISPATCHZ=3 "
-                  "-D MESHDISPATCHX=3 -D MESHDISPATCHY=3 -D MESHDISPATCHZ=3";
-
-  pShaderOp->DispatchX = 3;
-  pShaderOp->DispatchY = 3;
-  pShaderOp->DispatchZ = 3;
+  // Final tests with invalid dispatch size just to make sure they run
+  for (Dispatch &D : badDispatches) {
+    // Test Compute Shader
+    pShaderOp->CS = CS;
+    std::shared_ptr<st::ShaderOpTest> test = RunDispatch(pDevice, m_support, pShaderOp, D);
 
-  // Test Compute Shader
-  pShaderOp->CS = CS;
-  std::shared_ptr<ShaderOpTestResult> test = RunShaderOpTestAfterParse(pDevice, m_support, "Derivatives", nullptr, ShaderOpSet);
-  if (DoesDeviceSupportMeshAmpDerivatives(pDevice)) {
-    pShaderOp->CS = nullptr;
-    test = RunShaderOpTestAfterParse(pDevice, m_support, "Derivatives", nullptr, ShaderOpSet);
+    if (DoesDeviceSupportMeshAmpDerivatives(pDevice)) {
+      pShaderOp->CS = nullptr;
+      test = RunDispatch(pDevice, m_support, pShaderOp, D);
+    }
   }
 }