Ver código fonte

Fix RootSignature Validation bugs.

Tex Riddell 8 anos atrás
pai
commit
6de3365aa2
2 arquivos alterados com 31 adições e 48 exclusões
  1. 23 40
      lib/HLSL/DxilRootSignature.cpp
  2. 8 8
      lib/HLSL/DxilValidation.cpp

+ 23 - 40
lib/HLSL/DxilRootSignature.cpp

@@ -24,6 +24,7 @@
 #include <algorithm>
 #include <utility>
 #include <vector>
+#include <set>
 
 using namespace llvm;
 using std::string;
@@ -242,41 +243,20 @@ unsigned SimpleSerializer::GetSize() {
 template <typename T>
 class CIntervalCollection {
 private:
-  std::vector<T> m_set;
+  std::set<T> m_set;
 public:
-  T* FindIntersectingInterval(const T &I) {
-    DXASSERT(m_set.size() < INT_MAX,
-             "else too many interval entries, and min<max check can undeflow");
-    int mid, min = 0, max = (int)m_set.size();
-    while (min < max) {
-      mid = (min + max) / 2;
-      T &R = m_set[mid];
-      int order = I.overlap(R);
-      if (order == 0) return &R;
-      if (order < 0)
-        max = mid - 1;
-      else
-        min = mid + 1;
-    }
+  const T* FindIntersectingInterval(const T &I) {
+    auto it = m_set.find(I);
+    if (it != m_set.end())
+      return &*it;
     return nullptr;
   }
   void Insert(const T& value) {
-    // Find the first element that is greater or equal to value.
-    auto it = std::lower_bound(m_set.begin(), m_set.end(), value);
-    if (it == m_set.end()) {
-      m_set.push_back(value);
-    }
-    else {
-      m_set.insert(it, value);
+    auto result = m_set.insert(value);
+    UNREFERENCED_PARAMETER(result);
 #if DBG
-      // Verify that the insertion didn't violate disjoint range assumptions.
-      for (size_t i = 1; i < m_set.size(); ++i) {
-        DXASSERT_NOMSG(m_set[i - 1].overlap(m_set[i]));
-        DXASSERT_NOMSG(m_set[i - 1].space < m_set[i].space ||
-                       m_set[i - 1].ub < m_set[i].lb);
-      }
+    DXASSERT(result.second, "otherwise interval collides with existing in collection");
 #endif
-    }
   }
 };
 
@@ -333,7 +313,7 @@ private:
     // Sort by space, then lower bound.
     bool operator<(const RegisterRange& other) const {
       return space < other.space ||
-        (space == other.space && lb < other.lb);
+        (space == other.space && ub < other.lb);
     }
     // Like a regular -1,0,1 comparison, but 0 indicates overlap.
     int overlap(const RegisterRange& other) const {
@@ -352,11 +332,11 @@ private:
                         unsigned NumRegisters, unsigned BaseRegister,
                         unsigned RegisterSpace, DiagnosticPrinter &DiagPrinter);
 
-  RegisterRange *FindCoveringInterval(DxilDescriptorRangeType RangeType,
-                                      DxilShaderVisibility VisType,
-                                      unsigned Num,
-                                      unsigned LB,
-                                      unsigned Space);
+  const RegisterRange *FindCoveringInterval(DxilDescriptorRangeType RangeType,
+                                            DxilShaderVisibility VisType,
+                                            unsigned Num,
+                                            unsigned LB,
+                                            unsigned Space);
 
   RegisterRanges &
   GetRanges(DxilShaderVisibility VisType, DxilDescriptorRangeType DescType) {
@@ -510,7 +490,7 @@ void RootSignatureVerifier::AddRegisterRange(unsigned iRP,
     }
   }
 
-  RegisterRange *pNode = nullptr;
+  const RegisterRange *pNode = nullptr;
   DxilShaderVisibility NodeVis = VisType;
   if (VisType == DxilShaderVisibility::All) {
     // Check for overlap with each visibility type.
@@ -581,7 +561,7 @@ void RootSignatureVerifier::AddRegisterRange(unsigned iRP,
   GetRanges(VisType, DescType).Insert(interval);
 }
 
-RootSignatureVerifier::RegisterRange *
+const RootSignatureVerifier::RegisterRange *
 RootSignatureVerifier::FindCoveringInterval(DxilDescriptorRangeType RangeType,
                                             DxilShaderVisibility VisType,
                                             unsigned Num,
@@ -591,7 +571,10 @@ RootSignatureVerifier::FindCoveringInterval(DxilDescriptorRangeType RangeType,
   RR.space = Space;
   RR.lb = LB;
   RR.ub = LB + Num - 1;
-  return GetRanges(VisType, RangeType).FindIntersectingInterval(RR);
+  const RootSignatureVerifier::RegisterRange *pRange = GetRanges(DxilShaderVisibility::All, RangeType).FindIntersectingInterval(RR);
+  if (!pRange && VisType != DxilShaderVisibility::All)
+    pRange = GetRanges(VisType, RangeType).FindIntersectingInterval(RR);
+  return pRange;
 }
 
 static DxilDescriptorRangeType GetRangeType(DxilRootParameterType RPT) {
@@ -770,7 +753,7 @@ void RootSignatureVerifier::VerifyShader(DxilShaderVisibility VisType,
                                          uint32_t PSVSize,
                                          DiagnosticPrinter &DiagPrinter) {
   DxilPipelineStateValidation PSV;
-  IFTBOOL(!PSV.InitFromPSV0(pPSVData, PSVSize), E_INVALIDARG);
+  IFTBOOL(PSV.InitFromPSV0(pPSVData, PSVSize), E_INVALIDARG);
 
   bool bShaderDeniedByRootSig = false;
   switch (VisType) {
@@ -1602,7 +1585,7 @@ void DeserializeRootSignature(const void *pSrcData,
   DxilVersionedRootSignatureDesc *pRootSignature = nullptr;
   const char *pData = (const char *)pSrcData;
   IFTBOOL(pData + sizeof(uint32_t) < pData + SrcDataSizeInBytes, E_FAIL);
-  IFTBOOL(pSrcData == nullptr || SrcDataSizeInBytes == 0 || ppRootSignature == nullptr, E_FAIL);
+  IFTBOOL(pSrcData != nullptr && SrcDataSizeInBytes != 0 && ppRootSignature != nullptr, E_FAIL);
   *ppRootSignature = nullptr;
 
   DxilRootSignatureVersion Version = (DxilRootSignatureVersion)((uint32_t*)pData)[0];

+ 8 - 8
lib/HLSL/DxilValidation.cpp

@@ -4283,10 +4283,10 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
       }
       if (pDesc) {
         try {
-          VerifyRootSignatureWithShaderPSV(pDesc,
-                                            pDxilModule->GetShaderModel()->GetKind(),
-                                            GetDxilPartData(pPSVPart), pPSVPart->PartSize,
-                                            DiagStream);
+          IFTBOOL(VerifyRootSignatureWithShaderPSV(pDesc,
+                                                   pDxilModule->GetShaderModel()->GetKind(),
+                                                   GetDxilPartData(pPSVPart), pPSVPart->PartSize,
+                                                   DiagStream), DXC_E_INCORRECT_ROOT_SIGNATURE);
         } catch (...) {
           DeleteRootSignature(pDesc);
           ValCtx.EmitError(ValidationRule::ContainerRootSignatureIncompatible);
@@ -4395,10 +4395,10 @@ HRESULT ValidateDxilBitcode(
         return DXC_E_INCORRECT_ROOT_SIGNATURE;
     }
     try {
-      VerifyRootSignatureWithShaderPSV(pDesc ? pDesc : RS.GetDesc(),
-                                       dxilModule.GetShaderModel()->GetKind(),
-                                       pPSVData.get(), pWriter->size(),
-                                       DiagStream);
+      IFTBOOL(VerifyRootSignatureWithShaderPSV(pDesc ? pDesc : RS.GetDesc(),
+                                               dxilModule.GetShaderModel()->GetKind(),
+                                               pPSVData.get(), pWriter->size(),
+                                               DiagStream), DXC_E_INCORRECT_ROOT_SIGNATURE);
     } catch (...) {
       DeleteRootSignature(pDesc);
       return DXC_E_INCORRECT_ROOT_SIGNATURE;