Quellcode durchsuchen

Save root sig for entry for lib. (#3669)

* Save root sig for entry for lib.
Xiang Li vor 4 Jahren
Ursprung
Commit
5fa34d1fd6

+ 9 - 2
include/dxc/DXIL/DxilFunctionProps.h

@@ -21,7 +21,9 @@ class Constant;
 namespace hlsl {
 struct DxilFunctionProps {
   DxilFunctionProps() {
-    memset(this, 0, sizeof(DxilFunctionProps));
+    memset(&ShaderProps, 0, sizeof(ShaderProps));
+    shaderKind = DXIL::ShaderKind::Invalid;
+    waveSize = 0;
   }
   union {
     // Compute shader.
@@ -83,7 +85,12 @@ struct DxilFunctionProps {
   } ShaderProps;
   DXIL::ShaderKind shaderKind;
   // WaveSize is currently allowed only on compute shaders, but could be supported on other shader types in the future
-  unsigned waveSize; 
+  unsigned waveSize;
+  // Save root signature for lib profile entry.
+  std::vector<uint8_t> serializedRootSignature;
+  void SetSerializedRootSignature(const uint8_t *pData, unsigned size) {
+    serializedRootSignature.assign(pData, pData+size);
+  }
 
   // TODO: Should we have an unmangled name here for ray tracing shaders?
   bool IsPS() const     { return shaderKind == DXIL::ShaderKind::Pixel; }

+ 1 - 0
include/dxc/DXIL/DxilMetadataHelper.h

@@ -272,6 +272,7 @@ public:
   static const unsigned kDxilMSStateTag         = 9;
   static const unsigned kDxilASStateTag         = 10;
   static const unsigned kDxilWaveSizeTag        = 11;
+  static const unsigned kDxilEntryRootSigTag    = 12;
 
   // GSState.
   static const unsigned kDxilGSStateNumFields               = 5;

+ 45 - 20
lib/DXIL/DxilMetadataHelper.cpp

@@ -41,6 +41,37 @@ using std::string;
 using std::vector;
 using std::unique_ptr;
 
+namespace {
+void LoadSerializedRootSignature(MDNode *pNode,
+                                 std::vector<uint8_t> &SerializedRootSignature,
+                                 LLVMContext &Ctx) {
+  IFTBOOL(pNode->getNumOperands() == 1, DXC_E_INCORRECT_DXIL_METADATA);
+  const MDOperand &MDO = pNode->getOperand(0);
+
+  const ConstantAsMetadata *pMetaData = dyn_cast<ConstantAsMetadata>(MDO.get());
+  IFTBOOL(pMetaData != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  const ConstantDataArray *pData =
+      dyn_cast<ConstantDataArray>(pMetaData->getValue());
+  IFTBOOL(pData != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
+  IFTBOOL(pData->getElementType() == Type::getInt8Ty(Ctx),
+          DXC_E_INCORRECT_DXIL_METADATA);
+
+  SerializedRootSignature.assign(pData->getRawDataValues().begin(),
+                                 pData->getRawDataValues().end());
+}
+
+MDNode *
+EmitSerializedRootSignature(const std::vector<uint8_t> &SerializedRootSignature,
+                            LLVMContext &Ctx) {
+  if (SerializedRootSignature.empty())
+    return nullptr;
+  Constant *V = llvm::ConstantDataArray::get(
+      Ctx, llvm::ArrayRef<uint8_t>(SerializedRootSignature.data(),
+                                   SerializedRootSignature.size()));
+  return MDNode::get(Ctx, {ConstantAsMetadata::get(V)});
+}
+
+} // namespace
 
 namespace hlsl {
 
@@ -386,14 +417,12 @@ void DxilMDHelper::EmitRootSignature(
     return;
   }
 
-  Constant *V = llvm::ConstantDataArray::get(
-      m_Ctx, llvm::ArrayRef<uint8_t>(SerializedRootSignature.data(),
-                                     SerializedRootSignature.size()));
+  MDNode *Node = EmitSerializedRootSignature(SerializedRootSignature, m_Ctx);
 
   NamedMDNode *pRootSignatureNamedMD = m_pModule->getNamedMetadata(kDxilRootSignatureMDName);
   IFTBOOL(pRootSignatureNamedMD == nullptr, DXC_E_INCORRECT_DXIL_METADATA);
   pRootSignatureNamedMD = m_pModule->getOrInsertNamedMetadata(kDxilRootSignatureMDName);
-  pRootSignatureNamedMD->addOperand(MDNode::get(m_Ctx, {ConstantAsMetadata::get(V)}));
+  pRootSignatureNamedMD->addOperand(Node);
   return ;
 }
 
@@ -447,22 +476,7 @@ void DxilMDHelper::LoadRootSignature(std::vector<uint8_t> &SerializedRootSignatu
   IFTBOOL(pRootSignatureNamedMD->getNumOperands() == 1, DXC_E_INCORRECT_DXIL_METADATA);
 
   MDNode *pNode = pRootSignatureNamedMD->getOperand(0);
-  IFTBOOL(pNode->getNumOperands() == 1, DXC_E_INCORRECT_DXIL_METADATA);
-  const MDOperand &MDO = pNode->getOperand(0);
-
-  const ConstantAsMetadata *pMetaData = dyn_cast<ConstantAsMetadata>(MDO.get());
-  IFTBOOL(pMetaData != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
-  const ConstantDataArray *pData =
-      dyn_cast<ConstantDataArray>(pMetaData->getValue());
-  IFTBOOL(pData != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
-  IFTBOOL(pData->getElementType() == Type::getInt8Ty(m_Ctx),
-          DXC_E_INCORRECT_DXIL_METADATA);
-
-  SerializedRootSignature.clear();
-  unsigned size = pData->getRawDataValues().size();
-  SerializedRootSignature.resize(size);
-  memcpy(SerializedRootSignature.data(),
-         (const uint8_t *)pData->getRawDataValues().begin(), size);
+  LoadSerializedRootSignature(pNode, SerializedRootSignature, m_Ctx);
 }
 
 static const MDTuple *CastToTupleOrNull(const MDOperand &MDO) {
@@ -1483,6 +1497,13 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
         MDNode::get(m_Ctx, {Uint32ToConstMD(autoBindingSpace)}));
   }
 
+  if (!props.serializedRootSignature.empty() &&
+      DXIL::CompareVersions(m_MinValMajor, m_MinValMinor, 1, 6) > 0) {
+    MDVals.emplace_back(Uint32ToConstMD(DxilMDHelper::kDxilEntryRootSigTag));
+    MDVals.emplace_back(
+        EmitSerializedRootSignature(props.serializedRootSignature, m_Ctx));
+  }
+
   if (!MDVals.empty())
     return MDNode::get(m_Ctx, MDVals);
   else
@@ -1606,6 +1627,10 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
       MDNode *pNode = cast<MDNode>(MDO.get());
       props.waveSize = ConstMDToUint32(pNode->getOperand(0));
     } break;
+    case DxilMDHelper::kDxilEntryRootSigTag: {
+      MDNode *pNode = cast<MDNode>(MDO.get());
+      LoadSerializedRootSignature(pNode, props.serializedRootSignature, m_Ctx);
+    } break;
     default:
       DXASSERT(false, "Unknown extended shader properties tag");
       m_bExtraMetadata = true;

+ 0 - 2
lib/DXIL/DxilModule.cpp

@@ -1309,8 +1309,6 @@ void DxilModule::UpdateValidatorVersionMetadata() {
 }
 
 void DxilModule::ResetSerializedRootSignature(std::vector<uint8_t> &Value) {
-  m_SerializedRootSignature.clear();
-  m_SerializedRootSignature.reserve(Value.size());
   m_SerializedRootSignature.assign(Value.begin(), Value.end());
 }
 

+ 6 - 0
lib/HLSL/DxilLinker.cpp

@@ -788,6 +788,12 @@ DxilLinkJob::Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
     if (newPatchConstantFunc->hasFnAttribute(llvm::Attribute::AlwaysInline))
       newPatchConstantFunc->removeFnAttr(llvm::Attribute::AlwaysInline);
   }
+
+  // Set root sig if exist.
+  if (!props.serializedRootSignature.empty()) {
+    DM.ResetSerializedRootSignature(props.serializedRootSignature);
+    props.serializedRootSignature.clear();
+  }
   // Set EntryProps
   DM.SetShaderProperties(&props);
 

+ 1 - 3
lib/HLSL/HLModule.cpp

@@ -322,9 +322,7 @@ std::vector<uint8_t> &HLModule::GetSerializedRootSignature() {
 }
 
 void HLModule::SetSerializedRootSignature(const uint8_t *pData, unsigned size) {
-  m_SerializedRootSignature.clear();
-  m_SerializedRootSignature.resize(size);
-  memcpy(m_SerializedRootSignature.data(), pData, size);
+  m_SerializedRootSignature.assign(pData, pData+size);
 }
 
 DxilTypeSystem &HLModule::GetTypeSystem() {

+ 1 - 2
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -5775,8 +5775,7 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   if (m_pHLModule->HasDxilFunctionProps(F)) {
     DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
     std::unique_ptr<DxilFunctionProps> flatFuncProps = llvm::make_unique<DxilFunctionProps>();
-    flatFuncProps->shaderKind = funcProps.shaderKind;
-    flatFuncProps->ShaderProps = funcProps.ShaderProps;
+    *flatFuncProps = funcProps;
     m_pHLModule->AddDxilFunctionProps(flatF, flatFuncProps);
     if (funcProps.shaderKind == ShaderModel::Kind::Vertex) {
       auto &VS = funcProps.ShaderProps.VS;

+ 30 - 13
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -179,8 +179,8 @@ private:
                               QualType Type, QualType SrcType,
                               llvm::Type *Ty);
 
-  void EmitHLSLRootSignature(CodeGenFunction &CGF, HLSLRootSignatureAttr *RSA,
-                             llvm::Function *Fn) override;
+  void EmitHLSLRootSignature(HLSLRootSignatureAttr *RSA,
+                             Function *Fn, DxilFunctionProps &props);
 
   void CheckParameterAnnotation(SourceLocation SLoc,
                                 const DxilParameterAnnotation &paramInfo,
@@ -2325,6 +2325,12 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     }
   }
 
+  // Only parse root signature for entry function.
+  if (HLSLRootSignatureAttr *RSA = FD->getAttr<HLSLRootSignatureAttr>()) {
+    if (isExportedEntry || isEntry)
+      EmitHLSLRootSignature(RSA, F, *funcProps);
+  }
+
   // Only add functionProps when exist.
   if (isExportedEntry || isEntry)
     m_pHLModule->AddDxilFunctionProps(F, funcProps);
@@ -5744,22 +5750,33 @@ void CGMSHLSLRuntime::EmitHLSLFlatConversion(CodeGenFunction &CGF,
   }
 }
 
-void CGMSHLSLRuntime::EmitHLSLRootSignature(CodeGenFunction &CGF,
-                                            HLSLRootSignatureAttr *RSA,
-                                            Function *Fn) {
-  // Only parse root signature for entry function.
-  if (Fn != Entry.Func)
-    return;
-
+void CGMSHLSLRuntime::EmitHLSLRootSignature(HLSLRootSignatureAttr *RSA,
+                                            Function *Fn,
+                                            DxilFunctionProps &props) {
   StringRef StrRef = RSA->getSignatureName();
-  DiagnosticsEngine &Diags = CGF.getContext().getDiagnostics();
+  DiagnosticsEngine &Diags = CGM.getDiags();
   SourceLocation SLoc = RSA->getLocation();
   RootSignatureHandle RootSigHandle;
-  clang::CompileRootSignature(StrRef, Diags, SLoc, rootSigVer, DxilRootSignatureCompilationFlags::GlobalRootSignature, &RootSigHandle);
+  clang::CompileRootSignature(
+      StrRef, Diags, SLoc, rootSigVer,
+      DxilRootSignatureCompilationFlags::GlobalRootSignature, &RootSigHandle);
   if (!RootSigHandle.IsEmpty()) {
     RootSigHandle.EnsureSerializedAvailable();
-    m_pHLModule->SetSerializedRootSignature(RootSigHandle.GetSerializedBytes(),
-                                            RootSigHandle.GetSerializedSize());
+    if (!m_bIsLib) {
+      m_pHLModule->SetSerializedRootSignature(
+          RootSigHandle.GetSerializedBytes(),
+          RootSigHandle.GetSerializedSize());
+    } else {
+      if (!props.IsRay()) {
+        props.SetSerializedRootSignature(RootSigHandle.GetSerializedBytes(),
+                                         RootSigHandle.GetSerializedSize());
+      } else {
+        unsigned DiagID = Diags.getCustomDiagID(
+            DiagnosticsEngine::Error, "root signature attribute not supported "
+                                      "for raytracing entry functions");
+        Diags.Report(RSA->getLocation(), DiagID);
+      }
+    }
   }
 }
 

+ 0 - 3
tools/clang/lib/CodeGen/CGHLSLRuntime.h

@@ -120,9 +120,6 @@ public:
                                    clang::QualType SrcTy,
                                    llvm::Value *DestPtr,
                                    clang::QualType DestTy) = 0;
-  virtual void EmitHLSLRootSignature(CodeGenFunction &CGF,
-                                     clang::HLSLRootSignatureAttr *RSA,
-                                     llvm::Function *Fn) = 0;
   virtual llvm::Value *EmitHLSLLiteralCast(CodeGenFunction &CGF, llvm::Value *Src, clang::QualType SrcType,
                                                clang::QualType DstType) = 0;
 

+ 0 - 6
tools/clang/lib/CodeGen/CodeGenFunction.cpp

@@ -866,12 +866,6 @@ void CodeGenFunction::GenerateCode(GlobalDecl GD, llvm::Function *Fn,
   FunctionArgList Args;
   QualType ResTy = FD->getReturnType();
 
-  // HLSL Change Start - emit root signature associated with function
-  if (HLSLRootSignatureAttr *RSA = FD->getAttr<HLSLRootSignatureAttr>()) {
-    CGM.getHLSLRuntime().EmitHLSLRootSignature(*this, RSA, Fn);
-  }
-  // HLSL Change Ends - emit root signature associated with function
-
   CurGD = GD;
   const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD);
   if (MD && MD->isInstance()) {

+ 29 - 0
tools/clang/test/CodeGenHLSL/linker/link_with_root_sig.hlsl

@@ -0,0 +1,29 @@
+
+
+
+
+#define     RS \
+    RootSignature\
+    (\
+       "RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT)"\
+    )
+
+struct Vertex
+{
+    float4 position     : POSITION0;
+    float4 color        : COLOR0;
+};
+
+struct Interpolants
+{
+    float4 position : SV_POSITION0;
+    float4 color    : COLOR0;
+};
+
+
+[shader("vertex")]
+[RS]
+Interpolants vs_main( Vertex In )
+{
+    return In;
+}

+ 13 - 0
tools/clang/test/HLSLFileCheck/hlsl/diagnostics/errors/root_sig_on_raytracing_entry.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -T lib_6_6 %s | FileCheck %s
+
+// CHECK:error: root signature attribute not supported for raytracing entry functions
+
+
+RWStructuredBuffer<int64_t> myBuf : register(u0);
+
+[shader("raygeneration")]
+[RootSignature("UAV(u0)")]
+void RGInt64OnDescriptorHeapIndex()
+{
+    InterlockedAdd(myBuf[0], 1);
+}

+ 0 - 1
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/atomic/atomic_i64_root_resource.hlsl

@@ -8,7 +8,6 @@
 RWStructuredBuffer<int64_t> myBuf : register(u0);
 
 [shader("raygeneration")]
-[RootSignature("UAV(u0)")]
 void RGInt64OnDescriptorHeapIndex()
 {
     InterlockedAdd(myBuf[0], 1);

+ 1 - 1
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/createHandleFromHeap/dynamic_res_global_for_lib.hlsl

@@ -23,7 +23,7 @@ static float x = ID + 3;
 //  static Buffer<float> g_bufs[2] = {ResourceDescriptorHeap[ID+2], ResourceDescriptorHeap[ID+3]};
 
 [NumThreads(1, 1, 1)]
-[RootSignature("RootFlags(CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED | SAMPLER_HEAP_DIRECTLY_INDEXED), RootConstants(num32BitConstants=1, b0))")]
+[RootSignature("RootFlags(CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED | SAMPLER_HEAP_DIRECTLY_INDEXED), RootConstants(num32BitConstants=1, b0)")]
 void csmain(uint ix : SV_GroupIndex)
 {
   g_result[ix] = g_rawBuf.Load<float>(ix);// + g_bufs[0].Load(ix);

+ 1 - 1
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/createHandleFromHeap/dynamic_res_global_for_lib2.hlsl

@@ -26,7 +26,7 @@ static float x = ID + 3;
 //  static Buffer<float> g_bufs[2] = {ResourceDescriptorHeap[ID+2], ResourceDescriptorHeap[ID+3]};
 
 [NumThreads(1, 1, 1)]
-[RootSignature("RootFlags(CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED | SAMPLER_HEAP_DIRECTLY_INDEXED), RootConstants(num32BitConstants=1, b0))")]
+[RootSignature("RootFlags(CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED | SAMPLER_HEAP_DIRECTLY_INDEXED), RootConstants(num32BitConstants=1, b0)")]
 void csmain(uint ix : SV_GroupIndex)
 {
   g_result[ix] = g_rawBuf.Load<float>(ix);// + g_bufs[0].Load(ix);

+ 64 - 4
tools/clang/unittests/HLSL/LinkerTest.cpp

@@ -22,6 +22,7 @@
 #include "dxc/Test/HlslTestUtils.h"
 #include "dxc/Test/DxcTestUtils.h"
 #include "dxc/dxcapi.h"
+#include "dxc/DxilContainer/DxilContainer.h"
 
 using namespace std;
 using namespace hlsl;
@@ -67,6 +68,7 @@ public:
   TEST_METHOD(RunLinkWithTempReg);
   TEST_METHOD(RunLinkToLibWithGlobalCtor);
   TEST_METHOD(LinkSm63ToSm66);
+  TEST_METHOD(RunLinkWithRootSig);
 
 
   dxc::DxcDllSupport m_dllSupport;
@@ -77,9 +79,9 @@ public:
         m_dllSupport.CreateInstance(CLSID_DxcLinker, pResultLinker));
   }
 
-  void CompileLib(LPCWSTR filename, IDxcBlob **pResultBlob,
-                  llvm::ArrayRef<LPCWSTR> pArguments = {},
-                  LPCWSTR pShaderTarget = L"lib_6_x") {
+  void Compile(LPCWSTR filename, IDxcBlob **pResultBlob,
+               llvm::ArrayRef<LPCWSTR> pArguments = {}, LPCWSTR pEntry = L"",
+               LPCWSTR pShaderTarget = L"lib_6_x") {
     std::wstring fullPath = hlsl_test::GetPathToHlslDataFile(filename);
     CComPtr<IDxcBlobEncoding> pSource;
     CComPtr<IDxcLibrary> pLibrary;
@@ -97,13 +99,19 @@ public:
 
     VERIFY_SUCCEEDED(
         m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
-    VERIFY_SUCCEEDED(pCompiler->Compile(pSource, fullPath.c_str(), L"", pShaderTarget,
+    VERIFY_SUCCEEDED(pCompiler->Compile(pSource, fullPath.c_str(), pEntry, pShaderTarget,
                                         const_cast<LPCWSTR*>(pArguments.data()), pArguments.size(),
                                         nullptr, 0,
                                         pIncludeHandler, &pResult));
     CheckOperationSucceeded(pResult, pResultBlob);
   }
 
+  void CompileLib(LPCWSTR filename, IDxcBlob **pResultBlob,
+                  llvm::ArrayRef<LPCWSTR> pArguments = {},
+                  LPCWSTR pShaderTarget = L"lib_6_x") {
+    Compile(filename, pResultBlob, pArguments, L"", pShaderTarget);
+  }
+
   void AssembleLib(LPCWSTR filename, IDxcBlob **pResultBlob) {
     std::wstring fullPath = hlsl_test::GetPathToHlslDataFile(filename);
     CComPtr<IDxcLibrary> pLibrary;
@@ -831,3 +839,55 @@ TEST_F(LinkerTest, LinkSm63ToSm66) {
         "%(.*), %dx.types.ResourceProperties { i32 13, i32 4 }\\)"},
        {}, {}, true);
 }
+
+TEST_F(LinkerTest, RunLinkWithRootSig) {
+  CComPtr<IDxcBlob> pLib0;
+  CompileLib(L"..\\CodeGenHLSL\\linker\\link_with_root_sig.hlsl", &pLib0, {},
+             L"lib_6_x");
+
+  CComPtr<IDxcLinker> pLinker;
+  CreateLinker(&pLinker);
+
+  LPCWSTR libName = L"foo";
+  RegisterDxcModule(libName, pLib0, pLinker);
+
+  LPCWSTR pEntryName = L"vs_main";
+  LPCWSTR pShaderModel = L"vs_6_6";
+
+  LPCWSTR libNames[] = {libName};
+  CComPtr<IDxcOperationResult> pResult;
+
+  VERIFY_SUCCEEDED(pLinker->Link(pEntryName, pShaderModel, libNames,
+                                 1, {}, 0, &pResult));
+  CComPtr<IDxcBlob> pLinkedProgram;
+  CheckOperationSucceeded(pResult, &pLinkedProgram);
+  VERIFY_IS_TRUE(pLinkedProgram);
+
+  CComPtr<IDxcBlob> pProgram;
+  Compile(L"..\\CodeGenHLSL\\linker\\link_with_root_sig.hlsl", &pProgram, {},
+          pEntryName, pShaderModel);
+  VERIFY_IS_TRUE(pProgram);
+
+  const DxilContainerHeader *pLinkedContainer = IsDxilContainerLike(
+      pLinkedProgram->GetBufferPointer(), pLinkedProgram->GetBufferSize());
+
+  VERIFY_IS_TRUE(pLinkedContainer);
+
+  const DxilContainerHeader *pContainer = IsDxilContainerLike(
+      pProgram->GetBufferPointer(), pProgram->GetBufferSize());
+  VERIFY_IS_TRUE(pContainer);
+
+  const DxilPartHeader *pLinkedRSPart =
+      GetDxilPartByType(pLinkedContainer, DFCC_RootSignature);
+  VERIFY_IS_TRUE(pLinkedRSPart);
+  const DxilPartHeader *pRSPart =
+      GetDxilPartByType(pContainer, DFCC_RootSignature);
+  VERIFY_IS_TRUE(pRSPart);
+  VERIFY_IS_TRUE(pRSPart->PartSize == pLinkedRSPart->PartSize);
+
+  const uint8_t *pRS = (const uint8_t *)GetDxilPartData(pRSPart);
+  const uint8_t *pLinkedRS = (const uint8_t *)GetDxilPartData(pLinkedRSPart);
+  for (unsigned i = 0; i < pLinkedRSPart->PartSize; i++) {
+    VERIFY_IS_TRUE(pRS[i] == pLinkedRS[i]);
+  }
+}