浏览代码

Add DxilResourceFlags to DxilRuntimeReflection for UAV Counter/GC/ROV.

Tex Riddell 7 年之前
父节点
当前提交
ae9f78ff15

+ 10 - 2
include/dxc/HLSL/DxilRuntimeReflection.h

@@ -74,6 +74,14 @@ public:
   }
 };
 
+enum class DxilResourceFlag : uint32_t {
+  None                      = 0,
+  UAVGloballyCoherent       = 1 << 0,
+  UAVCounter                = 1 << 1,
+  UAVRasterizerOrderedView  = 1 << 2,
+  DynamicIndexing           = 1 << 3,
+};
+
 struct RuntimeDataResourceInfo {
   uint32_t Class; // hlsl::DXIL::ResourceClass
   uint32_t Kind;  // hlsl::DXIL::ResourceKind
@@ -82,7 +90,7 @@ struct RuntimeDataResourceInfo {
   uint32_t LowerBound;
   uint32_t UpperBound;
   uint32_t Name;  // resource name as an offset for string table
-  uint32_t Flags; // Not implemented yet
+  uint32_t Flags; // hlsl::RDAT::DxilResourceFlag
 };
 
 struct RuntimeDataFunctionInfo {
@@ -351,7 +359,7 @@ typedef struct DXIL_RESOURCE {
   uint32_t UpperBound;
   uint32_t LowerBound;
   LPCWSTR Name;
-  uint32_t Flags;
+  uint32_t Flags; // hlsl::RDAT::DxilResourceFlag
 } DXIL_RESOURCE;
 
 typedef struct DXIL_FUNCTION {

+ 10 - 0
lib/HLSL/DxilContainerAssembler.cpp

@@ -877,6 +877,16 @@ private:
     info.UpperBound = resource.GetUpperBound();
     info.Name = stringIndex;
     info.Flags = 0;
+    if (ResourceClass::UAV == resourceClass) {
+      DxilResource *pRes = static_cast<DxilResource*>(&resource);
+      if (pRes->HasCounter())
+        info.Flags |= static_cast<uint32_t>(DxilResourceFlag::UAVCounter);
+      if (pRes->IsGloballyCoherent())
+        info.Flags |= static_cast<uint32_t>(DxilResourceFlag::UAVGloballyCoherent);
+      if (pRes->IsROV())
+        info.Flags |= static_cast<uint32_t>(DxilResourceFlag::UAVRasterizerOrderedView);
+      // TODO: add dynamic index flag
+    }
     resourceTable.Insert(info);
   }
 

+ 86 - 26
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -674,23 +674,47 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
     "RWTexture1D<int4> tex : register(u5);"
     "Texture1D<float4> tex2 : register(t0);"
     "RWByteAddressBuffer b_buf;"
+    "struct Foo { float2 f2; int2 i2; };"
+    "AppendStructuredBuffer<Foo> append_buf;"
+    "ConsumeStructuredBuffer<Foo> consume_buf;"
+    "RasterizerOrderedByteAddressBuffer rov_buf;"
+    "globallycoherent RWByteAddressBuffer gc_buf;"
     "float function_import(float x);"
     "float function0(min16float x) { "
     "  return x + 1 + tex[0].x; }"
     "float function1(float x, min12int i) {"
     "  return x + c_buf + b_buf.Load(x) + tex2[i].x; }"
-    "float function2(float x) { return x + function_import(x); }";
+    "float function2(float x) { return x + function_import(x); }"
+    "float function3(int i) {"
+    "  Foo f = consume_buf.Consume();"
+    "  f.f2 += 0.5; append_buf.Append(f);"
+    "  rov_buf.Store(i, f.i2.x);"
+    "  gc_buf.Store(i, f.i2.y);"
+    "  b_buf.Store(i, f.i2.x + f.i2.y); }";
   CComPtr<IDxcCompiler> pCompiler;
   CComPtr<IDxcBlobEncoding> pSource;
   CComPtr<IDxcBlob> pProgram;
   CComPtr<IDxcBlobEncoding> pDisassembly;
   CComPtr<IDxcOperationResult> pResult;
 
+  struct CheckResFlagInfo { std::string name; hlsl::DXIL::ResourceKind kind; hlsl::RDAT::DxilResourceFlag flag; };
+  const unsigned numResFlagCheck = 5;
+  CheckResFlagInfo resFlags[numResFlagCheck] = {
+    { "b_buf", hlsl::DXIL::ResourceKind::RawBuffer, hlsl::RDAT::DxilResourceFlag::None },
+    { "append_buf", hlsl::DXIL::ResourceKind::StructuredBuffer, hlsl::RDAT::DxilResourceFlag::UAVCounter },
+    { "consume_buf", hlsl::DXIL::ResourceKind::StructuredBuffer, hlsl::RDAT::DxilResourceFlag::UAVCounter },
+    { "gc_buf", hlsl::DXIL::ResourceKind::RawBuffer, hlsl::RDAT::DxilResourceFlag::UAVGloballyCoherent },
+    { "rov_buf", hlsl::DXIL::ResourceKind::RawBuffer, hlsl::RDAT::DxilResourceFlag::UAVRasterizerOrderedView }
+  };
+
   VERIFY_SUCCEEDED(CreateCompiler(&pCompiler));
   CreateBlobFromText(shader, &pSource);
   VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"hlsl.hlsl", L"main",
                                       L"lib_6_3", nullptr, 0, nullptr, 0,
                                       nullptr, &pResult));
+  HRESULT hrStatus;
+  VERIFY_SUCCEEDED(pResult->GetStatus(&hrStatus));
+  VERIFY_SUCCEEDED(hrStatus);
   VERIFY_SUCCEEDED(pResult->GetResult(&pProgram));
   CComPtr<IDxcContainerReflection> containerReflection;
   uint32_t partCount;
@@ -711,7 +735,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
       context.InitFromRDAT((char *)pBlob->GetBufferPointer());
       FunctionTableReader *funcTableReader = context.GetFunctionTableReader();
       ResourceTableReader *resTableReader = context.GetResourceTableReader();
-      VERIFY_IS_TRUE(funcTableReader->GetNumFunctions() == 3);
+      VERIFY_ARE_EQUAL(funcTableReader->GetNumFunctions(), 4);
       std::string str("function");
       for (uint32_t j = 0; j < funcTableReader->GetNumFunctions(); ++j) {
         FunctionReader funcReader = funcTableReader->GetItem(j);
@@ -720,39 +744,55 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
         std::string cur_str = str;
         cur_str.push_back('0' + j);
         if (cur_str.compare("function0") == 0) {
-          VERIFY_IS_TRUE(funcReader.GetNumResources() == 1);
+          VERIFY_ARE_EQUAL(funcReader.GetNumResources(), 1);
           hlsl::ShaderFlags flag;
           flag.SetUAVLoadAdditionalFormats(true);
           flag.SetLowPrecisionPresent(true);
           uint64_t rawFlag = flag.GetShaderFlagsRaw();
-          VERIFY_IS_TRUE(funcReader.GetFeatureFlag() == rawFlag);
+          VERIFY_ARE_EQUAL(funcReader.GetFeatureFlag(), rawFlag);
           ResourceReader resReader = funcReader.GetResource(0);
-          VERIFY_IS_TRUE(resReader.GetResourceClass() == hlsl::DXIL::ResourceClass::UAV);
-          VERIFY_IS_TRUE(resReader.GetResourceKind() == hlsl::DXIL::ResourceKind::Texture1D);
+          VERIFY_ARE_EQUAL(resReader.GetResourceClass(), hlsl::DXIL::ResourceClass::UAV);
+          VERIFY_ARE_EQUAL(resReader.GetResourceKind(), hlsl::DXIL::ResourceKind::Texture1D);
         }
         else if (cur_str.compare("function1") == 0) {
           hlsl::ShaderFlags flag;
           flag.SetLowPrecisionPresent(true);
           uint64_t rawFlag = flag.GetShaderFlagsRaw();
-          VERIFY_IS_TRUE(funcReader.GetFeatureFlag() == rawFlag);
-          VERIFY_IS_TRUE(funcReader.GetNumResources() == 3);
+          VERIFY_ARE_EQUAL(funcReader.GetFeatureFlag(), rawFlag);
+          VERIFY_ARE_EQUAL(funcReader.GetNumResources(), 3);
         }
         else if (cur_str.compare("function2") == 0) {
-          VERIFY_IS_TRUE((funcReader.GetFeatureFlag() & 0xffffffffffffffff) == 0);
-          VERIFY_IS_TRUE(funcReader.GetNumResources() == 0);
+          VERIFY_ARE_EQUAL(funcReader.GetFeatureFlag() & 0xffffffffffffffff, 0);
+          VERIFY_ARE_EQUAL(funcReader.GetNumResources(), 0);
           std::string dependency = funcReader.GetDependency(0);
           VERIFY_IS_TRUE(dependency.find("function_import") != std::string::npos);
         }
+        else if (cur_str.compare("function3") == 0) {
+          VERIFY_ARE_EQUAL(funcReader.GetFeatureFlag() & 0xffffffffffffffff, 0);
+          VERIFY_ARE_EQUAL(funcReader.GetNumResources(), numResFlagCheck);
+          for (unsigned i = 0; i < funcReader.GetNumResources(); ++i) {
+            ResourceReader resReader = funcReader.GetResource(0);
+            VERIFY_ARE_EQUAL(resReader.GetResourceClass(), hlsl::DXIL::ResourceClass::UAV);
+            unsigned j = 0;
+            for (; j < numResFlagCheck; ++j) {
+              if (resFlags[j].name.compare(resReader.GetName()) == 0)
+                break;
+            }
+            VERIFY_IS_LESS_THAN(j, numResFlagCheck);
+            VERIFY_ARE_EQUAL(resReader.GetResourceKind(), resFlags[j].kind);
+            VERIFY_ARE_EQUAL(resReader.GetFlags(), static_cast<uint32_t>(resFlags[j].flag));
+          }
+        }
         else {
           IFTBOOLMSG(false, E_FAIL, "unknown function name");
         }
       }
-      VERIFY_IS_TRUE(resTableReader->GetNumResources() == 4);
+      VERIFY_ARE_EQUAL(resTableReader->GetNumResources(), 8);
       // This is validation test for DxilRuntimeReflection implemented on DxilRuntimeReflection.inl
       unique_ptr<DxilRuntimeReflection> pReflection(CreateDxilRuntimeReflection());
       VERIFY_IS_TRUE(pReflection->InitFromRDAT(pBlob->GetBufferPointer()));
       DXIL_LIBRARY_DESC lib_reflection = pReflection->GetLibraryReflection();
-      VERIFY_IS_TRUE(lib_reflection.NumFunctions == 3);
+      VERIFY_ARE_EQUAL(lib_reflection.NumFunctions, 4);
       for (uint32_t j = 0; j < 3; ++j) {
         DXIL_FUNCTION function = lib_reflection.pFunction[j];
         std::string cur_str = str;
@@ -764,14 +804,14 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           uint64_t rawFlag = flag.GetShaderFlagsRaw();
           uint64_t featureFlag = static_cast<uint64_t>(function.FeatureInfo2) << 32;
           featureFlag |= static_cast<uint64_t>(function.FeatureInfo1);
-          VERIFY_IS_TRUE(featureFlag == rawFlag);
-          VERIFY_IS_TRUE(function.NumResources == 1);
-          VERIFY_IS_TRUE(function.NumFunctionDependencies == 0);
+          VERIFY_ARE_EQUAL(featureFlag, rawFlag);
+          VERIFY_ARE_EQUAL(function.NumResources, 1);
+          VERIFY_ARE_EQUAL(function.NumFunctionDependencies, 0);
           const DXIL_RESOURCE &resource = *function.Resources[0];
-          VERIFY_IS_TRUE(resource.Class == (uint32_t)hlsl::DXIL::ResourceClass::UAV);
-          VERIFY_IS_TRUE(resource.Kind == (uint32_t)hlsl::DXIL::ResourceKind::Texture1D);
+          VERIFY_ARE_EQUAL(resource.Class, (uint32_t)hlsl::DXIL::ResourceClass::UAV);
+          VERIFY_ARE_EQUAL(resource.Kind, (uint32_t)hlsl::DXIL::ResourceKind::Texture1D);
           std::wstring wName = resource.Name;
-          VERIFY_IS_TRUE(wName.compare(L"tex") == 0);
+          VERIFY_ARE_EQUAL(wName.compare(L"tex"), 0);
         }
         else if (cur_str.compare("function1") == 0) {
           hlsl::ShaderFlags flag;
@@ -779,9 +819,9 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           uint64_t rawFlag = flag.GetShaderFlagsRaw();
           uint64_t featureFlag = static_cast<uint64_t>(function.FeatureInfo2) << 32;
           featureFlag |= static_cast<uint64_t>(function.FeatureInfo1);
-          VERIFY_IS_TRUE(featureFlag == rawFlag);
-          VERIFY_IS_TRUE(function.NumResources == 3);
-          VERIFY_IS_TRUE(function.NumFunctionDependencies == 0);
+          VERIFY_ARE_EQUAL(featureFlag, rawFlag);
+          VERIFY_ARE_EQUAL(function.NumResources, 3);
+          VERIFY_ARE_EQUAL(function.NumFunctionDependencies, 0);
           std::unordered_set<std::wstring> stringSet = { L"$Globals", L"b_buf", L"tex2" };
           for (uint32_t j = 0; j < 3; ++j) {
             const DXIL_RESOURCE &resource = *function.Resources[j];
@@ -790,18 +830,38 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           }
         }
         else if (cur_str.compare("function2") == 0) {
-          VERIFY_IS_TRUE(function.FeatureInfo1 == 0);
-          VERIFY_IS_TRUE(function.FeatureInfo2 == 0);
-          VERIFY_IS_TRUE(function.NumResources == 0);
-          VERIFY_IS_TRUE(function.NumFunctionDependencies == 1);
+          VERIFY_ARE_EQUAL(function.FeatureInfo1, 0);
+          VERIFY_ARE_EQUAL(function.FeatureInfo2, 0);
+          VERIFY_ARE_EQUAL(function.NumResources, 0);
+          VERIFY_ARE_EQUAL(function.NumFunctionDependencies, 1);
           std::wstring dependency = function.FunctionDependencies[0];
           VERIFY_IS_TRUE(dependency.find(L"function_import") != std::wstring::npos);
         }
+        else if (cur_str.compare("function3") == 0) {
+          VERIFY_ARE_EQUAL(function.FeatureInfo1, 0);
+          VERIFY_ARE_EQUAL(function.FeatureInfo2, 0);
+          VERIFY_ARE_EQUAL(function.NumResources, numResFlagCheck);
+          VERIFY_ARE_EQUAL(function.NumFunctionDependencies, 0);
+          for (unsigned i = 0; i < function.NumResources; ++i) {
+            const DXIL_RESOURCE *res = function.Resources[i];
+            VERIFY_ARE_EQUAL(res->Class, static_cast<uint32_t>(hlsl::DXIL::ResourceClass::UAV));
+            unsigned j = 0;
+            for (; j < numResFlagCheck; ++j) {
+              CA2W WName(resFlags[j].name.c_str());
+              std::wstring compareName(WName);
+              if (compareName.compare(res->Name) == 0)
+                break;
+            }
+            VERIFY_IS_LESS_THAN(j, numResFlagCheck);
+            VERIFY_ARE_EQUAL(res->Kind, static_cast<uint32_t>(resFlags[j].kind));
+            VERIFY_ARE_EQUAL(res->Flags, static_cast<uint32_t>(resFlags[j].flag));
+          }
+        }
         else {
           IFTBOOLMSG(false, E_FAIL, "unknown function name");
         }
       }
-      VERIFY_IS_TRUE(lib_reflection.NumResources == 4);
+      VERIFY_IS_TRUE(lib_reflection.NumResources == 8);
     }
   }
   IFTBOOLMSG(blobFound, E_FAIL, "failed to find RDAT blob after compiling");