Bladeren bron

Allow [Get/Set]NumThreads on Mesh/Amplification shaders (#2393)

Tex Riddell 6 jaren geleden
bovenliggende
commit
c8f7a6c970
3 gewijzigde bestanden met toevoegingen van 78 en 10 verwijderingen
  1. 1 1
      include/dxc/DXIL/DxilModule.h
  2. 15 9
      lib/DXIL/DxilModule.cpp
  3. 62 0
      tools/clang/unittests/HLSL/DxilModuleTest.cpp

+ 1 - 1
include/dxc/DXIL/DxilModule.h

@@ -225,7 +225,7 @@ public:
   // This funciton must be called after unused resources are removed from DxilModule
   bool ModuleHasMulticomponentUAVLoads();
 
-  // Compute shader.
+  // Compute/Mesh/Amplification shader.
   void SetNumThreads(unsigned x, unsigned y, unsigned z);
   unsigned GetNumThreads(unsigned idx) const;
 

+ 15 - 9
lib/DXIL/DxilModule.cpp

@@ -360,24 +360,30 @@ void DxilModule::CollectShaderFlagsForModule() {
 }
 
 void DxilModule::SetNumThreads(unsigned x, unsigned y, unsigned z) {
-  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsCS(),
-           "only works for CS profile");
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 &&
+           (m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()),
+           "only works for CS/MS/AS profiles");
   DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
-  DXASSERT(props.IsCS(), "Must be CS profile");
-  unsigned *numThreads = props.ShaderProps.CS.numThreads;
+  DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
+  unsigned *numThreads = props.IsCS() ? props.ShaderProps.CS.numThreads :
+    props.IsMS() ? props.ShaderProps.MS.numThreads : props.ShaderProps.AS.numThreads;
   numThreads[0] = x;
   numThreads[1] = y;
   numThreads[2] = z;
 }
 unsigned DxilModule::GetNumThreads(unsigned idx) const {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 &&
+           (m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()),
+           "only works for CS/MS/AS profiles");
   DXASSERT(idx < 3, "Thread dimension index must be 0-2");
-  if (!m_pSM->IsCS())
-    return 0;
-  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
   __analysis_assume(idx < 3);
+  if (!(m_pSM->IsCS() || m_pSM->IsMS() || m_pSM->IsAS()))
+    return 0;
   const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
-  DXASSERT(props.IsCS(), "Must be CS profile");
-  return props.ShaderProps.CS.numThreads[idx];
+  DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
+  const unsigned *numThreads = props.IsCS() ? props.ShaderProps.CS.numThreads :
+    props.IsMS() ? props.ShaderProps.MS.numThreads : props.ShaderProps.AS.numThreads;
+  return numThreads[idx];
 }
 
 DXIL::InputPrimitive DxilModule::GetInputPrimitive() const {

+ 62 - 0
tools/clang/unittests/HLSL/DxilModuleTest.cpp

@@ -61,6 +61,10 @@ public:
   TEST_METHOD(Precise5)
   TEST_METHOD(Precise6)
   TEST_METHOD(Precise7)
+
+  TEST_METHOD(CSGetNumThreads)
+  TEST_METHOD(MSGetNumThreads)
+  TEST_METHOD(ASGetNumThreads)
 };
 
 bool DxilModuleTest::InitSupport() {
@@ -425,3 +429,61 @@ TEST_F(DxilModuleTest, Precise7) {
   }
   VERIFY_ARE_EQUAL(numChecks, 4);
 }
+
+TEST_F(DxilModuleTest, CSGetNumThreads) {
+  Compiler c(m_dllSupport);
+  c.Compile(
+    "[numthreads(8, 4, 2)]\n"
+    "void main() {\n"
+    "}\n"
+    ,
+    L"cs_6_0"
+  );
+
+  DxilModule &DM = c.GetDxilModule();
+  VERIFY_ARE_EQUAL(8, DM.GetNumThreads(0));
+  VERIFY_ARE_EQUAL(4, DM.GetNumThreads(1));
+  VERIFY_ARE_EQUAL(2, DM.GetNumThreads(2));
+}
+
+TEST_F(DxilModuleTest, MSGetNumThreads) {
+  Compiler c(m_dllSupport);
+  if (c.SkipDxil_Test(1,5)) return;
+  c.Compile(
+    "struct MeshPerVertex { float4 pos : SV_Position; };\n"
+    "[numthreads(8, 4, 2)]\n"
+    "[outputtopology(\"triangle\")]\n"
+    "void main(\n"
+    "          out indices uint3 primIndices[1]\n"
+    ") {\n"
+    "    SetMeshOutputCounts(0, 0);\n"
+    "}\n"
+    ,
+    L"ms_6_5"
+  );
+
+  DxilModule &DM = c.GetDxilModule();
+  VERIFY_ARE_EQUAL(8, DM.GetNumThreads(0));
+  VERIFY_ARE_EQUAL(4, DM.GetNumThreads(1));
+  VERIFY_ARE_EQUAL(2, DM.GetNumThreads(2));
+}
+
+TEST_F(DxilModuleTest, ASGetNumThreads) {
+  Compiler c(m_dllSupport);
+  if (c.SkipDxil_Test(1,5)) return;
+  c.Compile(
+    "struct Payload { uint i; };\n"
+    "[numthreads(8, 4, 2)]\n"
+    "void main() {\n"
+    "  Payload pld = {0};\n"
+    "    DispatchMesh(1, 1, 1, pld);\n"
+    "}\n"
+    ,
+    L"as_6_5"
+  );
+
+  DxilModule &DM = c.GetDxilModule();
+  VERIFY_ARE_EQUAL(8, DM.GetNumThreads(0));
+  VERIFY_ARE_EQUAL(4, DM.GetNumThreads(1));
+  VERIFY_ARE_EQUAL(2, DM.GetNumThreads(2));
+}