Explorar o código

More RootSignature validation fixes.

- Implement RootSignatureHandle::Deserialize()
- Fix/clean up validation usage
Tex Riddell %!s(int64=8) %!d(string=hai) anos
pai
achega
7c2c2c5d54

+ 1 - 1
include/dxc/HLSL/DxilRootSignature.h

@@ -348,7 +348,7 @@ void SerializeRootSignature(const DxilVersionedRootSignatureDesc *pRootSignature
 
 void DeserializeRootSignature(__in_bcount(SrcDataSizeInBytes) const void *pSrcData,
                               __in uint32_t SrcDataSizeInBytes,
-                              __out DxilVersionedRootSignatureDesc **ppRootSignature);
+                              __out const DxilVersionedRootSignatureDesc **ppRootSignature);
 
 // Takes PSV - pipeline state validation data, not shader container.
 bool VerifyRootSignatureWithShaderPSV(__in const DxilVersionedRootSignatureDesc *pDesc,

+ 9 - 4
lib/HLSL/DxilRootSignature.cpp

@@ -87,6 +87,11 @@ void RootSignatureHandle::EnsureSerializedAvailable() {
   }
 }
 
+void RootSignatureHandle::Deserialize() {
+  DXASSERT_NOMSG(m_pSerialized && !m_pDesc);
+  DeserializeRootSignature((uint8_t*)m_pSerialized->GetBufferPointer(), (uint32_t)m_pSerialized->GetBufferSize(), &m_pDesc);
+}
+
 void RootSignatureHandle::LoadSerialized(const uint8_t *pData,
                                          unsigned length) {
   DXASSERT_NOMSG(IsEmpty());
@@ -1431,7 +1436,7 @@ CVersionedRootSignatureDeserializer::~CVersionedRootSignatureDeserializer() {
 
 void CVersionedRootSignatureDeserializer::Initialize(__in_bcount(SrcDataSizeInBytes) const void *pSrcData,
                                                      __in uint32_t SrcDataSizeInBytes) {
-  DxilVersionedRootSignatureDesc *pRootSignature = nullptr;
+  const DxilVersionedRootSignatureDesc *pRootSignature = nullptr;
   DeserializeRootSignature(pSrcData, SrcDataSizeInBytes, &pRootSignature);
 
   switch (pRootSignature->Version) {
@@ -1581,12 +1586,12 @@ void DeserializeRootSignatureTemplate(__in_bcount(SrcDataSizeInBytes) const void
 _Use_decl_annotations_
 void DeserializeRootSignature(const void *pSrcData,
                               uint32_t SrcDataSizeInBytes,
-                              DxilVersionedRootSignatureDesc **ppRootSignature) {
+                              const DxilVersionedRootSignatureDesc **ppRootSignature) {
   DxilVersionedRootSignatureDesc *pRootSignature = nullptr;
+  IFTBOOL(pSrcData != nullptr && SrcDataSizeInBytes != 0 && ppRootSignature != nullptr, E_INVALIDARG);
+  IFTBOOL(*ppRootSignature == nullptr, E_INVALIDARG);
   const char *pData = (const char *)pSrcData;
   IFTBOOL(pData + sizeof(uint32_t) < pData + SrcDataSizeInBytes, E_FAIL);
-  IFTBOOL(pSrcData != nullptr && SrcDataSizeInBytes != 0 && ppRootSignature != nullptr, E_FAIL);
-  *ppRootSignature = nullptr;
 
   DxilRootSignatureVersion Version = (DxilRootSignatureVersion)((uint32_t*)pData)[0];
 

+ 17 - 24
lib/HLSL/DxilValidation.cpp

@@ -4275,23 +4275,15 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
   // Validate Root Signature
   if (pPSVPart) {
     if (pRootSignaturePart) {
-      DxilVersionedRootSignatureDesc* pDesc = nullptr;
       try {
-        DeserializeRootSignature(GetDxilPartData(pRootSignaturePart), pRootSignaturePart->PartSize, &pDesc);
+        RootSignatureHandle RS;
+        RS.LoadSerialized((const uint8_t*)GetDxilPartData(pRootSignaturePart), pRootSignaturePart->PartSize);
+        RS.Deserialize();
+        IFTBOOL(VerifyRootSignatureWithShaderPSV(RS.GetDesc(),
+                                                  pDxilModule->GetShaderModel()->GetKind(),
+                                                  GetDxilPartData(pPSVPart), pPSVPart->PartSize,
+                                                  DiagStream), DXC_E_INCORRECT_ROOT_SIGNATURE);
       } catch (...) {
-        pDesc = nullptr;
-      }
-      if (pDesc) {
-        try {
-          IFTBOOL(VerifyRootSignatureWithShaderPSV(pDesc,
-                                                   pDxilModule->GetShaderModel()->GetKind(),
-                                                   GetDxilPartData(pPSVPart), pPSVPart->PartSize,
-                                                   DiagStream), DXC_E_INCORRECT_ROOT_SIGNATURE);
-        } catch (...) {
-          DeleteRootSignature(pDesc);
-          ValCtx.EmitError(ValidationRule::ContainerRootSignatureIncompatible);
-        }
-      } else {
         ValCtx.EmitError(ValidationRule::ContainerRootSignatureIncompatible);
       }
     }
@@ -4384,23 +4376,24 @@ HRESULT ValidateDxilBitcode(
 
   DxilModule &dxilModule = pModule->GetDxilModule();
   if (!dxilModule.GetRootSignature().IsEmpty()) {
-    const RootSignatureHandle &RS = dxilModule.GetRootSignature();
     unique_ptr<DxilPartWriter> pWriter(NewPSVWriter(dxilModule));
     DXASSERT_NOMSG(pWriter->size());
     unique_ptr<unsigned char[]> pPSVData(new unsigned char[pWriter->size()]);
-    DxilVersionedRootSignatureDesc* pDesc = nullptr;
-    if (!RS.GetDesc()) {
-      DeserializeRootSignature(RS.GetSerializedBytes(), RS.GetSerializedSize(), &pDesc);
-      if (!pDesc)
-        return DXC_E_INCORRECT_ROOT_SIGNATURE;
-    }
+    const DxilVersionedRootSignatureDesc* pDesc = dxilModule.GetRootSignature().GetDesc();
+    RootSignatureHandle RS;
     try {
-      IFTBOOL(VerifyRootSignatureWithShaderPSV(pDesc ? pDesc : RS.GetDesc(),
+      if (!pDesc) {
+        RS.Assign(nullptr, dxilModule.GetRootSignature().GetSerialized());
+        RS.Deserialize();
+        pDesc = RS.GetDesc();
+        if (!pDesc)
+          return DXC_E_INCORRECT_ROOT_SIGNATURE;
+      }
+      IFTBOOL(VerifyRootSignatureWithShaderPSV(pDesc,
                                                dxilModule.GetShaderModel()->GetKind(),
                                                pPSVData.get(), pWriter->size(),
                                                DiagStream), DXC_E_INCORRECT_ROOT_SIGNATURE);
     } catch (...) {
-      DeleteRootSignature(pDesc);
       return DXC_E_INCORRECT_ROOT_SIGNATURE;
     }
   }