Преглед изворни кода

Support valiadation/signing of root signature target output (#2560)

Tex Riddell пре 5 година
родитељ
комит
c54dd6b57c

+ 6 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -678,6 +678,12 @@ public:
 
           pOutputBlob.Release();
           IFT(pContainerStream.QueryInterface(&pOutputBlob));
+          if (!opts.DisableValidation) {
+            CComPtr<IDxcBlobEncoding> pValErrors;
+            // Validation failure communicated through diagnostic error
+            dxcutil::ValidateRootSignatureInContainer(
+              pOutputBlob, &compiler.getDiagnostics());
+          }
         }
       }
       // SPIRV change starts

+ 31 - 0
tools/clang/tools/dxcompiler/dxcutil.cpp

@@ -267,6 +267,37 @@ HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs) {
   return valHR;
 }
 
+HRESULT ValidateRootSignatureInContainer(
+    IDxcBlob *pRootSigContainer, clang::DiagnosticsEngine *pDiag) {
+  HRESULT valHR = S_OK;
+  CComPtr<IDxcValidator> pValidator;
+  CComPtr<IDxcOperationResult> pValResult;
+  CreateValidator(pValidator);
+  IFT(pValidator->Validate(pRootSigContainer,
+        DxcValidatorFlags_RootSignatureOnly | DxcValidatorFlags_InPlaceEdit,
+        &pValResult));
+  IFT(pValResult->GetStatus(&valHR));
+  if (pDiag) {
+    if (FAILED(valHR)) {
+      CComPtr<IDxcBlobEncoding> pErrors;
+      IFT(pValResult->GetErrorBuffer(&pErrors));
+#ifdef DBG
+      UINT32 codePage = CP_ACP;
+      BOOL known = FALSE;
+      IFT(pErrors->GetEncoding(&known, &codePage));
+      DXASSERT_NOMSG(!known || codePage == CP_ACP || codePage == CP_UTF8);
+#endif
+      StringRef errRef((const char *)pErrors->GetBufferPointer(),
+                       pErrors->GetBufferSize());
+      unsigned DiagID = pDiag->getCustomDiagID(
+        clang::DiagnosticsEngine::Error,
+        "root signature validation errors\r\n%0");
+      pDiag->Report(DiagID) << errRef;
+    }
+  }
+  return valHR;
+}
+
 void CreateOperationResultFromOutputs(
     IDxcBlob *pResultBlob, CComPtr<IStream> &pErrorStream,
     const std::string &warnings, bool hasErrorOccurred,

+ 2 - 0
tools/clang/tools/dxcompiler/dxcutil.h

@@ -61,6 +61,8 @@ struct AssembleInputs {
 };
 
 HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs);
+HRESULT ValidateRootSignatureInContainer(
+    IDxcBlob *pRootSigContainer, clang::DiagnosticsEngine *pDiag = nullptr);
 void GetValidatorVersion(unsigned *pMajor, unsigned *pMinor);
 void AssembleToContainer(AssembleInputs &inputs);
 HRESULT Disassemble(IDxcBlob *pProgram, llvm::raw_string_ostream &Stream);

+ 16 - 7
tools/clang/tools/dxcompiler/dxcvalidator.cpp

@@ -259,18 +259,27 @@ HRESULT DxcValidator::RunRootSignatureValidation(
   const DxilProgramHeader *pProgramHeader = GetDxilProgramHeader(pDxilContainer, DFCC_DXIL);
   const DxilPartHeader *pPSVPart = GetDxilPartByType(pDxilContainer, DFCC_PipelineStateValidation);
   const DxilPartHeader *pRSPart = GetDxilPartByType(pDxilContainer, DFCC_RootSignature);
-  IFRBOOL(pPSVPart && pRSPart, DXC_E_MISSING_PART);
+  IFRBOOL(pRSPart, DXC_E_MISSING_PART);
+  if (pProgramHeader) {
+    // Container has shader part, make sure we have PSV.
+    IFRBOOL(pPSVPart, DXC_E_MISSING_PART);
+  }
   try {
     RootSignatureHandle RSH;
     RSH.LoadSerialized((const uint8_t*)GetDxilPartData(pRSPart), pRSPart->PartSize);
     RSH.Deserialize();
     raw_stream_ostream DiagStream(pDiagStream);
-    IFRBOOL(VerifyRootSignatureWithShaderPSV(RSH.GetDesc(),
-                                             GetVersionShaderType(pProgramHeader->ProgramVersion),
-                                             GetDxilPartData(pPSVPart),
-                                             pPSVPart->PartSize,
-                                             DiagStream),
-      DXC_E_INCORRECT_ROOT_SIGNATURE);
+    if (pProgramHeader) {
+      IFRBOOL(VerifyRootSignatureWithShaderPSV(RSH.GetDesc(),
+                                               GetVersionShaderType(pProgramHeader->ProgramVersion),
+                                               GetDxilPartData(pPSVPart),
+                                               pPSVPart->PartSize,
+                                               DiagStream),
+              DXC_E_INCORRECT_ROOT_SIGNATURE);
+    } else {
+      IFRBOOL(VerifyRootSignature(RSH.GetDesc(), DiagStream, false),
+              DXC_E_INCORRECT_ROOT_SIGNATURE);
+    }
   } catch(...) {
     return DXC_E_IR_VERIFICATION_FAILED;
   }

+ 24 - 5
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -1532,20 +1532,39 @@ TEST_F(CompilerTest, CompileWithRootSignatureThenStripRootSignature) {
   hlsl::DxilPartHeader *pPartHeader = hlsl::GetDxilPartByType(
       pContainerHeader, hlsl::DxilFourCC::DFCC_RootSignature);
   VERIFY_IS_NOT_NULL(pPartHeader);
+  pResult.Release();
   
   // Remove root signature
-  CComPtr<IDxcBlob> pNewProgram;
+  CComPtr<IDxcBlob> pProgramRootSigRemoved;
   CComPtr<IDxcContainerBuilder> pBuilder;
   VERIFY_SUCCEEDED(CreateContainerBuilder(&pBuilder));
   VERIFY_SUCCEEDED(pBuilder->Load(pProgram));
   VERIFY_SUCCEEDED(pBuilder->RemovePart(hlsl::DxilFourCC::DFCC_RootSignature));
-  pResult.Release();
   VERIFY_SUCCEEDED(pBuilder->SerializeContainer(&pResult));
-  VERIFY_SUCCEEDED(pResult->GetResult(&pNewProgram));
-  pContainerHeader = (hlsl::DxilContainerHeader *)(pNewProgram->GetBufferPointer());
+  VERIFY_SUCCEEDED(pResult->GetResult(&pProgramRootSigRemoved));
+  pContainerHeader = (hlsl::DxilContainerHeader *)(pProgramRootSigRemoved->GetBufferPointer());
+  hlsl::DxilPartHeader *pPartHeaderShouldBeNull = hlsl::GetDxilPartByType(pContainerHeader,
+                                        hlsl::DxilFourCC::DFCC_RootSignature);
+  VERIFY_IS_NULL(pPartHeaderShouldBeNull);
+  pBuilder.Release();
+  pResult.Release();
+
+  // Add root signature back
+  CComPtr<IDxcBlobEncoding> pRootSignatureBlob;
+  CComPtr<IDxcLibrary> pLibrary;
+  CComPtr<IDxcBlob> pProgramRootSigAdded;
+  VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
+  VERIFY_SUCCEEDED(pLibrary->CreateBlobWithEncodingFromPinned(
+    hlsl::GetDxilPartData(pPartHeader), pPartHeader->PartSize, 0, &pRootSignatureBlob));
+  VERIFY_SUCCEEDED(CreateContainerBuilder(&pBuilder));
+  VERIFY_SUCCEEDED(pBuilder->Load(pProgramRootSigRemoved));
+  pBuilder->AddPart(hlsl::DxilFourCC::DFCC_RootSignature, pRootSignatureBlob);
+  pBuilder->SerializeContainer(&pResult);
+  VERIFY_SUCCEEDED(pResult->GetResult(&pProgramRootSigAdded));
+  pContainerHeader = (hlsl::DxilContainerHeader *)(pProgramRootSigAdded->GetBufferPointer());
   pPartHeader = hlsl::GetDxilPartByType(pContainerHeader,
                                         hlsl::DxilFourCC::DFCC_RootSignature);
-  VERIFY_IS_NULL(pPartHeader);
+  VERIFY_IS_NOT_NULL(pPartHeader);
 }
 #endif // Container builder unsupported
 

+ 23 - 5
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -292,6 +292,8 @@ public:
   TEST_METHOD(AmplificationGreaterThanMaxZ)
   TEST_METHOD(AmplificationGreaterThanMaxXYZ)
 
+  TEST_METHOD(ValidateRootSigContainer)
+
   dxc::DxcDllSupport m_dllSupport;
   VersionSupportInfo m_ver;
 
@@ -305,11 +307,10 @@ public:
     }
   }
 
-  void CheckValidationMsgs(IDxcBlob *pBlob, llvm::ArrayRef<LPCSTR> pErrorMsgs, bool bRegex = false) {
+  void CheckValidationMsgs(IDxcBlob *pBlob, llvm::ArrayRef<LPCSTR> pErrorMsgs, bool bRegex = false, UINT32 Flags = DxcValidatorFlags_Default) {
     CComPtr<IDxcValidator> pValidator;
     CComPtr<IDxcOperationResult> pResult;
 
-    UINT32 Flags = DxcValidatorFlags_Default;
     if (!IsDxilContainerLike(pBlob->GetBufferPointer(), pBlob->GetBufferSize())) {
       // Validation of raw bitcode as opposed to DxilContainer is not supported through DXIL.dll
       if (!m_ver.m_InternalValidator) {
@@ -325,12 +326,12 @@ public:
     CheckOperationResultMsgs(pResult, pErrorMsgs, false, bRegex);
   }
 
-  void CheckValidationMsgs(const char *pBlob, size_t blobSize, llvm::ArrayRef<LPCSTR> pErrorMsgs, bool bRegex = false) {
+  void CheckValidationMsgs(const char *pBlob, size_t blobSize, llvm::ArrayRef<LPCSTR> pErrorMsgs, bool bRegex = false, UINT32 Flags = DxcValidatorFlags_Default) {
     CComPtr<IDxcLibrary> pLibrary;
     CComPtr<IDxcBlobEncoding> pBlobEncoding; // Encoding doesn't actually matter, it's binary.
     VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
     VERIFY_SUCCEEDED(pLibrary->CreateBlobWithEncodingFromPinned(pBlob, blobSize, CP_UTF8, &pBlobEncoding));
-    CheckValidationMsgs(pBlobEncoding, pErrorMsgs, bRegex);
+    CheckValidationMsgs(pBlobEncoding, pErrorMsgs, bRegex, Flags);
   }
 
   bool CompileSource(IDxcBlobEncoding *pSource, LPCSTR pShaderModel,
@@ -3781,4 +3782,21 @@ TEST_F(ValidationTest, AmplificationGreaterThanMaxXYZ) {
                           "= !{i32 32, i32 1, i32 1}",
                           "= !{i32 32, i32 2, i32 4}",
                           "Declared Thread Group Count 256 (X*Y*Z) is beyond the valid maximum of 128");
-}
+}
+
+TEST_F(ValidationTest, ValidateRootSigContainer) {
+  // Validation of root signature-only container not supported until 1.5
+  if (m_ver.SkipDxilVersion(1, 5)) return;
+
+  LPCSTR pSource = "#define main \"DescriptorTable(UAV(u0))\"";
+  CComPtr<IDxcBlob> pObject;
+  if (!CompileSource(pSource, "rootsig_1_0", &pObject))
+    return;
+  CheckValidationMsgs(pObject, {}, false,
+    DxcValidatorFlags_RootSignatureOnly | DxcValidatorFlags_InPlaceEdit);
+  pObject.Release();
+  if (!CompileSource(pSource, "rootsig_1_1", &pObject))
+    return;
+  CheckValidationMsgs(pObject, {}, false,
+    DxcValidatorFlags_RootSignatureOnly | DxcValidatorFlags_InPlaceEdit);
+}