Browse Source

Fix regressions with entry props/sigs and UseMinPrecision flag

- Make sure entry/signature data is available once SetShaderModel() is
  called for shader targets (not for library).
- Add bUseMinPrecision flag to SetShaderModel() since the global switch
  is needed at that point.
- Remove SetUseMinPrecision(), since the flag should be set in
  SetShaderModel() and should not be changed afterward.
- Modify DxilModule loading to initialize UseMinPrecision flag correctly
  for entry signatures, and fix signature copy constructor.
- Remove DxilModule level entry property duplicates, use the
  DxilFunctionProps instead, except where useful (GS stream mask),
  and make sure these are always available/in-sync.
- Fix various code paths that access properties that don't match the target.
- Fix various ordering issues that caused bugs in properties.
- Don't duplicate functions to strip them of parameters when they have
  already been stripped (linker hits this).
Tex Riddell 7 years ago
parent
commit
9a2a731b7b

+ 4 - 17
include/dxc/HLSL/DxilModule.h

@@ -59,7 +59,7 @@ public:
   llvm::LLVMContext &GetCtx() const;
   llvm::Module *GetModule() const;
   OP *GetOP() const;
-  void SetShaderModel(const ShaderModel *pSM);
+  void SetShaderModel(const ShaderModel *pSM, bool bUseMinPrecision = true);
   const ShaderModel *GetShaderModel() const;
   void GetDxilVersion(unsigned &DxilMajor, unsigned &DxilMinor) const;
   void SetValidatorVersion(unsigned ValMajor, unsigned ValMinor);
@@ -224,7 +224,8 @@ public:
   bool ModuleHasMulticomponentUAVLoads();
 
   // Compute shader.
-  unsigned m_NumThreads[3];
+  void SetNumThreads(unsigned x, unsigned y, unsigned z);
+  unsigned GetNumThreads(unsigned idx) const;
 
   // Geometry shader.
   DXIL::InputPrimitive GetInputPrimitive() const;
@@ -243,7 +244,7 @@ public:
   unsigned GetActiveStreamMask() const;
 
   // Language options
-  void SetUseMinPrecision(bool useMinPrecision);
+  // UseMinPrecision must be set at SetShaderModel time.
   bool GetUseMinPrecision() const;
   void SetDisableOptimization(bool disableOptimization);
   bool GetDisableOptimization() const;
@@ -284,27 +285,13 @@ private:
   std::vector<std::unique_ptr<DxilSampler> > m_Samplers;
 
   // Geometry shader.
-  DXIL::InputPrimitive m_InputPrimitive;
-  unsigned m_MaxVertexCount;
   DXIL::PrimitiveTopology m_StreamPrimitiveTopology;
   unsigned m_ActiveStreamMask;
-  unsigned m_NumGSInstances;
-
-  // Hull and Domain shaders.
-  unsigned m_InputControlPointCount;
-  DXIL::TessellatorDomain m_TessellatorDomain;
-
-  // Hull shader.
-  unsigned m_OutputControlPointCount;
-  DXIL::TessellatorPartitioning m_TessellatorPartitioning;
-  DXIL::TessellatorOutputPrimitive m_TessellatorOutputPrimitive;
-  float m_MaxTessellationFactor;
 
 private:
   llvm::LLVMContext &m_Ctx;
   llvm::Module *m_pModule;
   llvm::Function *m_pEntryFunc;
-  llvm::Function *m_pPatchConstantFunc;
   std::string m_EntryName;
   std::unique_ptr<DxilMDHelper> m_pMDHelper;
   std::unique_ptr<llvm::DebugInfoFinder> m_pDebugInfoFinder;

+ 1 - 1
include/dxc/HLSL/DxilSignature.h

@@ -25,7 +25,7 @@ public:
   using Kind = DXIL::SignatureKind;
 
   DxilSignature(DXIL::ShaderKind shaderKind, DXIL::SignatureKind sigKind, bool useMinPrecision);
-  DxilSignature(DXIL::SigPointKind sigPointKind);
+  DxilSignature(DXIL::SigPointKind sigPointKind, bool useMinPrecision);
   DxilSignature(const DxilSignature &src);
   virtual ~DxilSignature();
 

+ 12 - 6
lib/HLSL/DxilContainerAssembler.cpp

@@ -288,18 +288,21 @@ public:
 };
 
 DxilPartWriter *hlsl::NewProgramSignatureWriter(const DxilModule &M, DXIL::SignatureKind Kind) {
+  DXIL::TessellatorDomain domain = DXIL::TessellatorDomain::Undefined;
+  if (M.GetShaderModel()->IsHS() || M.GetShaderModel()->IsDS())
+    domain = M.GetTessellatorDomain();
   switch (Kind) {
   case DXIL::SignatureKind::Input:
     return new DxilProgramSignatureWriter(
-        M.GetInputSignature(), M.GetTessellatorDomain(), true,
+        M.GetInputSignature(), domain, true,
         M.GetUseMinPrecision());
   case DXIL::SignatureKind::Output:
     return new DxilProgramSignatureWriter(
-        M.GetOutputSignature(), M.GetTessellatorDomain(), false,
+        M.GetOutputSignature(), domain, false,
         M.GetUseMinPrecision());
   case DXIL::SignatureKind::PatchConstant:
     return new DxilProgramSignatureWriter(
-        M.GetPatchConstantSignature(), M.GetTessellatorDomain(),
+        M.GetPatchConstantSignature(), domain,
         /*IsInput*/ M.GetShaderModel()->IsDS(),
         /*UseMinPrecision*/M.GetUseMinPrecision());
   }
@@ -1286,12 +1289,15 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
   std::unique_ptr<DxilProgramSignatureWriter> pOutputSigWriter = nullptr;
   std::unique_ptr<DxilProgramSignatureWriter> pPatchConstantSigWriter = nullptr;
   if (!pModule->GetShaderModel()->IsLib()) {
+    DXIL::TessellatorDomain domain = DXIL::TessellatorDomain::Undefined;
+    if (pModule->GetShaderModel()->IsHS() || pModule->GetShaderModel()->IsDS())
+      domain = pModule->GetTessellatorDomain();
     pInputSigWriter = llvm::make_unique<DxilProgramSignatureWriter>(
-        pModule->GetInputSignature(), pModule->GetTessellatorDomain(),
+        pModule->GetInputSignature(), domain,
         /*IsInput*/ true,
         /*UseMinPrecision*/ pModule->GetUseMinPrecision());
     pOutputSigWriter = llvm::make_unique<DxilProgramSignatureWriter>(
-        pModule->GetOutputSignature(), pModule->GetTessellatorDomain(),
+        pModule->GetOutputSignature(), domain,
         /*IsInput*/ false,
         /*UseMinPrecision*/ pModule->GetUseMinPrecision());
     // Write the input and output signature parts.
@@ -1305,7 +1311,7 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
                    });
 
     pPatchConstantSigWriter = llvm::make_unique<DxilProgramSignatureWriter>(
-        pModule->GetPatchConstantSignature(), pModule->GetTessellatorDomain(),
+        pModule->GetPatchConstantSignature(), domain,
         /*IsInput*/ pModule->GetShaderModel()->IsDS(),
         /*UseMinPrecision*/ pModule->GetUseMinPrecision());
     if (pModule->GetPatchConstantSignature().GetElements().size()) {

+ 15 - 5
lib/HLSL/DxilContainerReflection.cpp

@@ -2035,6 +2035,8 @@ UINT DxilShaderReflection::GetConversionInstructionCount() { return 0; }
 UINT DxilShaderReflection::GetBitwiseInstructionCount() { return 0; }
 
 D3D_PRIMITIVE DxilShaderReflection::GetGSInputPrimitive() {
+  if (!m_pDxilModule->GetShaderModel()->IsGS())
+    return D3D_PRIMITIVE::D3D10_PRIMITIVE_UNDEFINED;
   return (D3D_PRIMITIVE)m_pDxilModule->GetInputPrimitive();
 }
 
@@ -2053,11 +2055,19 @@ HRESULT DxilShaderReflection::GetMinFeatureLevel(enum D3D_FEATURE_LEVEL* pLevel)
 
 _Use_decl_annotations_
 UINT DxilShaderReflection::GetThreadGroupSize(UINT *pSizeX, UINT *pSizeY, UINT *pSizeZ) {
-  UINT *pNumThreads = m_pDxilModule->m_NumThreads;
-  AssignToOutOpt(pNumThreads[0], pSizeX);
-  AssignToOutOpt(pNumThreads[1], pSizeY);
-  AssignToOutOpt(pNumThreads[2], pSizeZ);
-  return pNumThreads[0] * pNumThreads[1] * pNumThreads[2];
+  if (!m_pDxilModule->GetShaderModel()->IsCS()) {
+    AssignToOutOpt((UINT)0, pSizeX);
+    AssignToOutOpt((UINT)0, pSizeY);
+    AssignToOutOpt((UINT)0, pSizeZ);
+    return 0;
+  }
+  unsigned x = m_pDxilModule->GetNumThreads(0);
+  unsigned y = m_pDxilModule->GetNumThreads(1);
+  unsigned z = m_pDxilModule->GetNumThreads(2);
+  AssignToOutOpt(x, pSizeX);
+  AssignToOutOpt(y, pSizeY);
+  AssignToOutOpt(z, pSizeZ);
+  return x * y * z;
 }
 
 UINT64 DxilShaderReflection::GetRequiresFlags() {

+ 10 - 10
lib/HLSL/DxilGenerationPass.cpp

@@ -141,12 +141,14 @@ void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, bool HasDebugInfo) {
   unsigned ValMajor, ValMinor;
   H.GetValidatorVersion(ValMajor, ValMinor);
   M.SetValidatorVersion(ValMajor, ValMinor);
-  M.SetShaderModel(H.GetShaderModel());
+  M.SetShaderModel(H.GetShaderModel(), H.GetHLOptions().bUseMinPrecision);
 
   // Entry function.
-  Function *EntryFn = H.GetEntryFunction();
-  M.SetEntryFunction(EntryFn);
-  M.SetEntryFunctionName(H.GetEntryFunctionName());
+  if (!M.GetShaderModel()->IsLib()) {
+    Function *EntryFn = H.GetEntryFunction();
+    M.SetEntryFunction(EntryFn);
+    M.SetEntryFunctionName(H.GetEntryFunctionName());
+  }
 
   std::vector<GlobalVariable* > &LLVMUsed = M.GetLLVMUsed();
 
@@ -194,8 +196,6 @@ void InitDxilModuleFromHLModule(HLModule &H, DxilModule &M, bool HasDebugInfo) {
   //bool m_bEnableMSAD;
   //M.m_ShaderFlags.SetAllResourcesBound(H.GetHLOptions().bAllResourcesBound);
 
-  M.SetUseMinPrecision(H.GetHLOptions().bUseMinPrecision);
-
   // DXIL type system.
   M.ResetTypeSystem(H.ReleaseTypeSystem());
   // Dxil OP.
@@ -309,12 +309,12 @@ public:
     // High-level metadata should now be turned into low-level metadata.
     const bool SkipInit = true;
     hlsl::DxilModule &DxilMod = M.GetOrCreateDxilModule(SkipInit);
-    if (!SM->IsLib()) {
-      DxilMod.SetShaderProperties(&EntryPropsMap.begin()->second->props);
-    }
+    auto pProps = &EntryPropsMap.begin()->second->props;
     InitDxilModuleFromHLModule(*m_pHLModule, DxilMod, m_HasDbgInfo);
-
     DxilMod.ResetEntryPropsMap(std::move(EntryPropsMap));
+    if (!SM->IsLib()) {
+      DxilMod.SetShaderProperties(pProps);
+    }
 
     HLModule::ClearHLMetadata(M);
     M.ResetHLModule();

+ 2 - 2
lib/HLSL/DxilLinker.cpp

@@ -720,7 +720,7 @@ DxilLinkJob::Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
   // Create DxilModule.
   const bool bSkipInit = true;
   DxilModule &DM = pM->GetOrCreateDxilModule(bSkipInit);
-  DM.SetShaderModel(pSM);
+  DM.SetShaderModel(pSM, entryDM.GetUseMinPrecision());
 
   // Set Validator version.
   DM.SetValidatorVersion(m_valMajor, m_valMinor);
@@ -812,7 +812,7 @@ DxilLinkJob::LinkToLib(const ShaderModel *pSM) {
   // Create DxilModule.
   const bool bSkipInit = true;
   DxilModule &DM = pM->GetOrCreateDxilModule(bSkipInit);
-  DM.SetShaderModel(pSM);
+  DM.SetShaderModel(pSM, tmpDM.GetUseMinPrecision());
 
   // Set Validator version.
   DM.SetValidatorVersion(m_valMajor, m_valMinor);

+ 274 - 114
lib/HLSL/DxilModule.cpp

@@ -87,23 +87,13 @@ DxilModule::DxilModule(Module *pModule)
 , m_pDebugInfoFinder(nullptr)
 , m_pEntryFunc(nullptr)
 , m_EntryName("")
-, m_pPatchConstantFunc(nullptr)
 , m_pSM(nullptr)
 , m_DxilMajor(DXIL::kDxilMajor)
 , m_DxilMinor(DXIL::kDxilMinor)
 , m_ValMajor(1)
 , m_ValMinor(0)
-, m_InputPrimitive(DXIL::InputPrimitive::Undefined)
-, m_MaxVertexCount(0)
 , m_StreamPrimitiveTopology(DXIL::PrimitiveTopology::Undefined)
 , m_ActiveStreamMask(0)
-, m_NumGSInstances(1)
-, m_InputControlPointCount(0)
-, m_TessellatorDomain(DXIL::TessellatorDomain::Undefined)
-, m_OutputControlPointCount(0)
-, m_TessellatorPartitioning(DXIL::TessellatorPartitioning::Undefined)
-, m_TessellatorOutputPrimitive(DXIL::TessellatorOutputPrimitive::Undefined)
-, m_MaxTessellationFactor(0.f)
 , m_RootSignature(nullptr)
 , m_bUseMinPrecision(true) // use min precision by default
 , m_bDisableOptimizations(false)
@@ -111,8 +101,6 @@ DxilModule::DxilModule(Module *pModule)
 , m_AutoBindingSpace(UINT_MAX) {
   DXASSERT_NOMSG(m_pModule != nullptr);
 
-  m_NumThreads[0] = m_NumThreads[1] = m_NumThreads[2] = 0;
-
 #if defined(_DEBUG) || defined(DBG)
   // Pin LLVM dump methods.
   void (__thiscall Module::*pfnModuleDump)() const = &Module::dump;
@@ -130,13 +118,21 @@ LLVMContext &DxilModule::GetCtx() const { return m_Ctx; }
 Module *DxilModule::GetModule() const { return m_pModule; }
 OP *DxilModule::GetOP() const { return m_pOP.get(); }
 
-void DxilModule::SetShaderModel(const ShaderModel *pSM) {
+void DxilModule::SetShaderModel(const ShaderModel *pSM, bool bUseMinPrecision) {
   DXASSERT(m_pSM == nullptr || (pSM != nullptr && *m_pSM == *pSM), "shader model must not change for the module");
   DXASSERT(pSM != nullptr && pSM->IsValidForDxil(), "shader model must be valid");
   DXASSERT(pSM->IsValidForModule(), "shader model must be valid for top-level module use");
   m_pSM = pSM;
   m_pSM->GetDxilVersion(m_DxilMajor, m_DxilMinor);
   m_pMDHelper->SetShaderModel(m_pSM);
+  m_bUseMinPrecision = bUseMinPrecision;
+  if (!m_pSM->IsLib()) {
+    // Always have valid entry props for non-lib case from this point on.
+    DxilFunctionProps props;
+    props.shaderKind = m_pSM->GetKind();
+    m_DxilEntryPropsMap[nullptr] =
+      llvm::make_unique<DxilEntryProps>(props, m_bUseMinPrecision);
+  }
   m_RootSignature.reset(new RootSignatureHandle());
 }
 
@@ -198,7 +194,18 @@ const Function *DxilModule::GetEntryFunction() const {
 }
 
 void DxilModule::SetEntryFunction(Function *pEntryFunc) {
+  if (m_pSM->IsLib()) {
+    DXASSERT(pEntryFunc == nullptr,
+             "Otherwise, trying to set an entry function on library");
+    m_pEntryFunc = nullptr;
+    return;
+  }
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
   m_pEntryFunc = pEntryFunc;
+  // Move entry props to new function in order to preserve them.
+  std::unique_ptr<DxilEntryProps> Props = std::move(m_DxilEntryPropsMap.begin()->second);
+  m_DxilEntryPropsMap.clear();
+  m_DxilEntryPropsMap[m_pEntryFunc] = std::move(Props);
 }
 
 const string &DxilModule::GetEntryFunctionName() const {
@@ -210,15 +217,37 @@ void DxilModule::SetEntryFunctionName(const string &name) {
 }
 
 llvm::Function *DxilModule::GetPatchConstantFunction() {
-  return m_pPatchConstantFunc;
+  if (!m_pSM->IsHS())
+    return nullptr;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  return props.ShaderProps.HS.patchConstantFunc;
 }
 
 const llvm::Function *DxilModule::GetPatchConstantFunction() const {
-  return m_pPatchConstantFunc;
+  if (!m_pSM->IsHS())
+    return nullptr;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  return props.ShaderProps.HS.patchConstantFunc;
 }
 
-void DxilModule::SetPatchConstantFunction(llvm::Function *pFunc) {
-  m_pPatchConstantFunc = pFunc;
+void DxilModule::SetPatchConstantFunction(llvm::Function *patchConstantFunc) {
+  if (!m_pSM->IsHS())
+    return;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  auto &HS = props.ShaderProps.HS;
+  if (HS.patchConstantFunc != patchConstantFunc) {
+    if (HS.patchConstantFunc)
+      m_PatchConstantFunctions.erase(HS.patchConstantFunc);
+    HS.patchConstantFunc = patchConstantFunc;
+    if (patchConstantFunc)
+      m_PatchConstantFunctions.insert(patchConstantFunc);
+  }
 }
 
 unsigned DxilModule::GetGlobalFlags() const {
@@ -326,24 +355,64 @@ void DxilModule::CollectShaderFlagsForModule() {
   CollectShaderFlagsForModule(m_ShaderFlags);
 }
 
+void DxilModule::SetNumThreads(unsigned x, unsigned y, unsigned z) {
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsCS(),
+           "only works for CS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsCS(), "Must be CS profile");
+  unsigned *numThreads = props.ShaderProps.CS.numThreads;
+  numThreads[0] = x;
+  numThreads[1] = y;
+  numThreads[2] = z;
+}
+unsigned DxilModule::GetNumThreads(unsigned idx) const {
+  DXASSERT(idx < 3, "Thread dimension index must be 0-2");
+  if (!m_pSM->IsCS())
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  __analysis_assume(idx < 3);
+  const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsCS(), "Must be CS profile");
+  return props.ShaderProps.CS.numThreads[idx];
+}
+
 DXIL::InputPrimitive DxilModule::GetInputPrimitive() const {
-  return m_InputPrimitive;
+  if (!m_pSM->IsGS())
+    return DXIL::InputPrimitive::Undefined;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsGS(), "Must be GS profile");
+  return props.ShaderProps.GS.inputPrimitive;
 }
 
 void DxilModule::SetInputPrimitive(DXIL::InputPrimitive IP) {
-  DXASSERT_NOMSG(m_InputPrimitive == DXIL::InputPrimitive::Undefined);
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsGS(),
+           "only works for GS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsGS(), "Must be GS profile");
+  auto &GS = props.ShaderProps.GS;
   DXASSERT_NOMSG(DXIL::InputPrimitive::Undefined < IP && IP < DXIL::InputPrimitive::LastEntry);
-  m_InputPrimitive = IP;
+  GS.inputPrimitive = IP;
 }
 
 unsigned DxilModule::GetMaxVertexCount() const {
-  DXASSERT_NOMSG(m_MaxVertexCount != 0);
-  return m_MaxVertexCount;
+  if (!m_pSM->IsGS())
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsGS(), "Must be GS profile");
+  auto &GS = props.ShaderProps.GS;
+  DXASSERT_NOMSG(GS.maxVertexCount != 0);
+  return GS.maxVertexCount;
 }
 
 void DxilModule::SetMaxVertexCount(unsigned Count) {
-  DXASSERT_NOMSG(m_MaxVertexCount == 0);
-  m_MaxVertexCount = Count;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsGS(),
+           "only works for GS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsGS(), "Must be GS profile");
+  auto &GS = props.ShaderProps.GS;
+  GS.maxVertexCount = Count;
 }
 
 DXIL::PrimitiveTopology DxilModule::GetStreamPrimitiveTopology() const {
@@ -352,6 +421,7 @@ DXIL::PrimitiveTopology DxilModule::GetStreamPrimitiveTopology() const {
 
 void DxilModule::SetStreamPrimitiveTopology(DXIL::PrimitiveTopology Topology) {
   m_StreamPrimitiveTopology = Topology;
+  SetActiveStreamMask(m_ActiveStreamMask);  // Update props
 }
 
 bool DxilModule::HasMultipleOutputStreams() const {
@@ -384,11 +454,20 @@ unsigned DxilModule::GetOutputStream() const {
 }
 
 unsigned DxilModule::GetGSInstanceCount() const {
-  return m_NumGSInstances;
+  if (!m_pSM->IsGS())
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsGS(), "Must be GS profile");
+  return props.ShaderProps.GS.instanceCount;
 }
 
 void DxilModule::SetGSInstanceCount(unsigned Count) {
-  m_NumGSInstances = Count;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsGS(),
+           "only works for GS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsGS(), "Must be GS profile");
+  props.ShaderProps.GS.instanceCount = Count;
 }
 
 bool DxilModule::IsStreamActive(unsigned Stream) const {
@@ -401,20 +480,27 @@ void DxilModule::SetStreamActive(unsigned Stream, bool bActive) {
   } else {
     m_ActiveStreamMask &= ~(1<<Stream);
   }
+  SetActiveStreamMask(m_ActiveStreamMask);
 }
 
 void DxilModule::SetActiveStreamMask(unsigned Mask) {
   m_ActiveStreamMask = Mask;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsGS(),
+           "only works for GS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsGS(), "Must be GS profile");
+  for (unsigned i = 0; i < 4; i++) {
+    if (IsStreamActive(i))
+      props.ShaderProps.GS.streamPrimitiveTopologies[i] = m_StreamPrimitiveTopology;
+    else
+      props.ShaderProps.GS.streamPrimitiveTopologies[i] = DXIL::PrimitiveTopology::Undefined;
+  }
 }
 
 unsigned DxilModule::GetActiveStreamMask() const {
   return m_ActiveStreamMask;
 }
 
-void DxilModule::SetUseMinPrecision(bool UseMinPrecision) {
-  m_bUseMinPrecision = UseMinPrecision;
-}
-
 bool DxilModule::GetUseMinPrecision() const {
   return m_bUseMinPrecision;
 }
@@ -436,51 +522,117 @@ bool DxilModule::GetAllResourcesBound() const {
 }
 
 unsigned DxilModule::GetInputControlPointCount() const {
-  return m_InputControlPointCount;
+  if (!(m_pSM->IsHS() || m_pSM->IsDS()))
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS() || props.IsDS(), "Must be HS or DS profile");
+  if (props.IsHS())
+    return props.ShaderProps.HS.inputControlPoints;
+  else
+    return props.ShaderProps.DS.inputControlPoints;
 }
 
 void DxilModule::SetInputControlPointCount(unsigned NumICPs) {
-  m_InputControlPointCount = NumICPs;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1
+           && (m_pSM->IsHS() || m_pSM->IsDS()),
+           "only works for non-lib profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS() || props.IsDS(), "Must be HS or DS profile");
+  if (props.IsHS())
+    props.ShaderProps.HS.inputControlPoints = NumICPs;
+  else
+    props.ShaderProps.DS.inputControlPoints = NumICPs;
 }
 
 DXIL::TessellatorDomain DxilModule::GetTessellatorDomain() const {
-  return m_TessellatorDomain;
+  if (!(m_pSM->IsHS() || m_pSM->IsDS()))
+    return DXIL::TessellatorDomain::Undefined;
+  DXASSERT_NOMSG(m_DxilEntryPropsMap.size() == 1);
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  if (props.IsHS())
+    return props.ShaderProps.HS.domain;
+  else
+    return props.ShaderProps.DS.domain;
 }
 
 void DxilModule::SetTessellatorDomain(DXIL::TessellatorDomain TessDomain) {
-  m_TessellatorDomain = TessDomain;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1
+           && (m_pSM->IsHS() || m_pSM->IsDS()),
+           "only works for HS or DS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS() || props.IsDS(), "Must be HS or DS profile");
+  if (props.IsHS())
+    props.ShaderProps.HS.domain = TessDomain;
+  else
+    props.ShaderProps.DS.domain = TessDomain;
 }
 
 unsigned DxilModule::GetOutputControlPointCount() const {
-  return m_OutputControlPointCount;
+  if (!m_pSM->IsHS())
+    return 0;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  return props.ShaderProps.HS.outputControlPoints;
 }
 
 void DxilModule::SetOutputControlPointCount(unsigned NumOCPs) {
-  m_OutputControlPointCount = NumOCPs;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsHS(),
+           "only works for HS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  props.ShaderProps.HS.outputControlPoints = NumOCPs;
 }
 
 DXIL::TessellatorPartitioning DxilModule::GetTessellatorPartitioning() const {
-  return m_TessellatorPartitioning;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsHS(),
+           "only works for HS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  return props.ShaderProps.HS.partition;
 }
 
 void DxilModule::SetTessellatorPartitioning(DXIL::TessellatorPartitioning TessPartitioning) {
-  m_TessellatorPartitioning = TessPartitioning;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsHS(),
+           "only works for HS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  props.ShaderProps.HS.partition = TessPartitioning;
 }
 
 DXIL::TessellatorOutputPrimitive DxilModule::GetTessellatorOutputPrimitive() const {
-  return m_TessellatorOutputPrimitive;
+  if (!m_pSM->IsHS())
+    return DXIL::TessellatorOutputPrimitive::Undefined;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  return props.ShaderProps.HS.outputPrimitive;
 }
 
 void DxilModule::SetTessellatorOutputPrimitive(DXIL::TessellatorOutputPrimitive TessOutputPrimitive) {
-  m_TessellatorOutputPrimitive = TessOutputPrimitive;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsHS(),
+           "only works for HS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  props.ShaderProps.HS.outputPrimitive = TessOutputPrimitive;
 }
 
 float DxilModule::GetMaxTessellationFactor() const {
-  return m_MaxTessellationFactor;
+  if (!m_pSM->IsHS())
+    return 0.0F;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  return props.ShaderProps.HS.maxTessFactor;
 }
 
 void DxilModule::SetMaxTessellationFactor(float MaxTessellationFactor) {
-  m_MaxTessellationFactor = MaxTessellationFactor;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsHS(),
+           "only works for HS profile");
+  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+  DXASSERT(props.IsHS(), "Must be HS profile");
+  props.ShaderProps.HS.maxTessFactor = MaxTessellationFactor;
 }
 
 void DxilModule::SetAutoBindingSpace(uint32_t Space) {
@@ -493,51 +645,39 @@ uint32_t DxilModule::GetAutoBindingSpace() const {
 void DxilModule::SetShaderProperties(DxilFunctionProps *props) {
   if (!props)
     return;
+  DxilFunctionProps &ourProps = GetDxilFunctionProps(GetEntryFunction());
+  if (props != &ourProps) {
+    ourProps.shaderKind = props->shaderKind;
+    ourProps.ShaderProps = props->ShaderProps;
+  }
   switch (props->shaderKind) {
   case DXIL::ShaderKind::Pixel: {
     auto &PS = props->ShaderProps.PS;
     m_ShaderFlags.SetForceEarlyDepthStencil(PS.EarlyDepthStencil);
   } break;
-  case DXIL::ShaderKind::Compute: {
-    auto &CS = props->ShaderProps.CS;
-    for (size_t i = 0; i < _countof(m_NumThreads); ++i)
-      m_NumThreads[i] = CS.numThreads[i];
-  } break;
-  case DXIL::ShaderKind::Domain: {
-    auto &DS = props->ShaderProps.DS;
-    SetTessellatorDomain(DS.domain);
-    SetInputControlPointCount(DS.inputControlPoints);
-  } break;
-  case DXIL::ShaderKind::Hull: {
-    auto &HS = props->ShaderProps.HS;
-    SetPatchConstantFunction(HS.patchConstantFunc);
-    SetTessellatorDomain(HS.domain);
-    SetTessellatorPartitioning(HS.partition);
-    SetTessellatorOutputPrimitive(HS.outputPrimitive);
-    SetInputControlPointCount(HS.inputControlPoints);
-    SetOutputControlPointCount(HS.outputControlPoints);
-    SetMaxTessellationFactor(HS.maxTessFactor);
-  } break;
+  case DXIL::ShaderKind::Compute:
+  case DXIL::ShaderKind::Domain:
+  case DXIL::ShaderKind::Hull:
   case DXIL::ShaderKind::Vertex:
     break;
   default: {
     DXASSERT(props->shaderKind == DXIL::ShaderKind::Geometry,
              "else invalid shader kind");
     auto &GS = props->ShaderProps.GS;
-    SetInputPrimitive(GS.inputPrimitive);
-    SetMaxVertexCount(GS.maxVertexCount);
+    m_ActiveStreamMask = 0;
     for (size_t i = 0; i < _countof(GS.streamPrimitiveTopologies); ++i) {
       if (GS.streamPrimitiveTopologies[i] !=
           DXIL::PrimitiveTopology::Undefined) {
-        SetStreamActive(i, true);
-        DXASSERT_NOMSG(GetStreamPrimitiveTopology() ==
+        m_ActiveStreamMask |= (1 << i);
+        DXASSERT_NOMSG(m_StreamPrimitiveTopology ==
                            DXIL::PrimitiveTopology::Undefined ||
-                       GetStreamPrimitiveTopology() ==
+                       m_StreamPrimitiveTopology ==
                            GS.streamPrimitiveTopologies[i]);
-        SetStreamPrimitiveTopology(GS.streamPrimitiveTopologies[i]);
+        m_StreamPrimitiveTopology = GS.streamPrimitiveTopologies[i];
       }
     }
-    SetGSInstanceCount(GS.instanceCount);
+    // Refresh props:
+    SetActiveStreamMask(m_ActiveStreamMask);
   } break;
   }
 }
@@ -774,37 +914,37 @@ void DxilModule::RemoveUnusedResourceSymbols() {
 
 DxilSignature &DxilModule::GetInputSignature() {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
-           "only works for none lib profile");
+           "only works for non-lib profile");
   return m_DxilEntryPropsMap.begin()->second->sig.InputSignature;
 }
 
 const DxilSignature &DxilModule::GetInputSignature() const {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
-           "only works for none lib profile");
+           "only works for non-lib profile");
   return m_DxilEntryPropsMap.begin()->second->sig.InputSignature;
 }
 
 DxilSignature &DxilModule::GetOutputSignature() {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
-           "only works for none lib profile");
+           "only works for non-lib profile");
   return m_DxilEntryPropsMap.begin()->second->sig.OutputSignature;
 }
 
 const DxilSignature &DxilModule::GetOutputSignature() const {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
-           "only works for none lib profile");
+           "only works for non-lib profile");
   return m_DxilEntryPropsMap.begin()->second->sig.OutputSignature;
 }
 
 DxilSignature &DxilModule::GetPatchConstantSignature() {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
-           "only works for none lib profile");
+           "only works for non-lib profile");
   return m_DxilEntryPropsMap.begin()->second->sig.PatchConstantSignature;
 }
 
 const DxilSignature &DxilModule::GetPatchConstantSignature() const {
   DXASSERT(m_DxilEntryPropsMap.size() == 1 && !m_pSM->IsLib(),
-           "only works for none lib profile");
+           "only works for non-lib profile");
   return m_DxilEntryPropsMap.begin()->second->sig.PatchConstantSignature;
 }
 
@@ -862,11 +1002,14 @@ void DxilModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, l
            "Hull shader must already have function props!");
   DxilFunctionProps &props = propIter->second->props;
   DXASSERT(props.IsHS(), "else hullShaderFunc is not a Hull Shader");
-  if (props.ShaderProps.HS.patchConstantFunc)
-    m_PatchConstantFunctions.erase(props.ShaderProps.HS.patchConstantFunc);
-  props.ShaderProps.HS.patchConstantFunc = patchConstantFunc;
-  if (patchConstantFunc)
-    m_PatchConstantFunctions.insert(patchConstantFunc);
+  auto &HS = props.ShaderProps.HS;
+  if (HS.patchConstantFunc != patchConstantFunc) {
+    if (HS.patchConstantFunc)
+      m_PatchConstantFunctions.erase(HS.patchConstantFunc);
+    HS.patchConstantFunc = patchConstantFunc;
+    if (patchConstantFunc)
+      m_PatchConstantFunctions.insert(patchConstantFunc);
+  }
 }
 bool DxilModule::IsGraphicsShader(const llvm::Function *F) const {
   return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsGraphics();
@@ -1084,20 +1227,54 @@ bool DxilModule::IsKnownNamedMetaData(llvm::NamedMDNode &Node) {
 void DxilModule::LoadDxilMetadata() {
   m_pMDHelper->LoadDxilVersion(m_DxilMajor, m_DxilMinor);
   m_pMDHelper->LoadValidatorVersion(m_ValMajor, m_ValMinor);
-  const ShaderModel *loadedModule;
-  m_pMDHelper->LoadDxilShaderModel(loadedModule);
-  SetShaderModel(loadedModule);
+  const ShaderModel *loadedSM;
+  m_pMDHelper->LoadDxilShaderModel(loadedSM);
+
+  // This must be set before LoadDxilEntryProperties
+  m_pMDHelper->SetShaderModel(loadedSM);
+
+  // Setting module shader model requires UseMinPrecision flag,
+  // which requires loading m_ShaderFlags,
+  // which requires global entry properties,
+  // so load entry properties first, then set the shader model
 
   const llvm::NamedMDNode *pEntries = m_pMDHelper->GetDxilEntryPoints();
-  if (!loadedModule->IsLib()) {
+  if (!loadedSM->IsLib()) {
     IFTBOOL(pEntries->getNumOperands() == 1, DXC_E_INCORRECT_DXIL_METADATA);
   }
   Function *pEntryFunc;
   string EntryName;
-  const llvm::MDOperand *pSignatures, *pResources, *pProperties;
-  m_pMDHelper->GetDxilEntryPoint(pEntries->getOperand(0), pEntryFunc, EntryName, pSignatures, pResources, pProperties);
+  const llvm::MDOperand *pEntrySignatures, *pEntryResources, *pEntryProperties;
+  m_pMDHelper->GetDxilEntryPoint(pEntries->getOperand(0),
+                                 pEntryFunc, EntryName,
+                                 pEntrySignatures, pEntryResources,
+                                 pEntryProperties);
+
+  uint64_t rawShaderFlags = 0;
+  DxilFunctionProps entryFuncProps;
+  entryFuncProps.shaderKind = loadedSM->GetKind();
+  if (loadedSM->IsLib()) {
+    // Get rawShaderFlags and m_AutoBindingSpace; entryFuncProps unused.
+    m_pMDHelper->LoadDxilEntryProperties(*pEntryProperties, rawShaderFlags,
+                                         entryFuncProps, m_AutoBindingSpace);
+  }
+  else {
+    m_pMDHelper->LoadDxilEntryProperties(*pEntryProperties, rawShaderFlags,
+                                         entryFuncProps, m_AutoBindingSpace);
+  }
+
+  m_bUseMinPrecision = true;
+  if (rawShaderFlags) {
+    m_ShaderFlags.SetShaderFlagsRaw(rawShaderFlags);
+    m_bUseMinPrecision = !m_ShaderFlags.GetUseNativeLowPrecision();
+    m_bDisableOptimizations = m_ShaderFlags.GetDisableOptimizations();
+    m_bAllResourcesBound = m_ShaderFlags.GetAllResourcesBound();
+  }
 
-  if (loadedModule->IsLib()) {
+  // Now that we have the UseMinPrecision flag, set shader model:
+  SetShaderModel(loadedSM, m_bUseMinPrecision);
+
+  if (loadedSM->IsLib()) {
     for (unsigned i = 1; i < pEntries->getNumOperands(); i++) {
       Function *pFunc;
       string Name;
@@ -1116,7 +1293,7 @@ void DxilModule::LoadDxilMetadata() {
       }
 
       std::unique_ptr<DxilEntryProps> pEntryProps =
-          llvm::make_unique<DxilEntryProps>(props, GetUseMinPrecision());
+          llvm::make_unique<DxilEntryProps>(props, m_bUseMinPrecision);
       DXASSERT(pSignatures->get() == nullptr || !props.IsRay(),
                "Raytracing has no signature");
       m_pMDHelper->LoadDxilSignatures(*pSignatures, pEntryProps->sig);
@@ -1124,37 +1301,20 @@ void DxilModule::LoadDxilMetadata() {
       m_DxilEntryPropsMap[pFunc] = std::move(pEntryProps);
     }
   } else {
-    DxilFunctionProps props;
-    props.shaderKind = loadedModule->GetKind();
-
     std::unique_ptr<DxilEntryProps> pEntryProps =
-        llvm::make_unique<DxilEntryProps>(props, GetUseMinPrecision());
-    m_pMDHelper->LoadDxilSignatures(*pSignatures, pEntryProps->sig);
+        llvm::make_unique<DxilEntryProps>(entryFuncProps, m_bUseMinPrecision);
+    DxilFunctionProps *pFuncProps = &pEntryProps->props;
+    m_pMDHelper->LoadDxilSignatures(*pEntrySignatures, pEntryProps->sig);
 
+    m_DxilEntryPropsMap.clear();
     m_DxilEntryPropsMap[pEntryFunc] = std::move(pEntryProps);
-  }
-  SetEntryFunction(pEntryFunc);
-  SetEntryFunctionName(EntryName);
 
-  uint64_t rawShaderFlags = 0;
-  if (m_pSM->IsLib()) {
-    DxilFunctionProps props;
-    m_pMDHelper->LoadDxilEntryProperties(*pProperties, rawShaderFlags, props,
-                                          m_AutoBindingSpace);
-  } else {
-    DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
-    m_pMDHelper->LoadDxilEntryProperties(
-        *pProperties, rawShaderFlags,
-        props, m_AutoBindingSpace);
-    SetShaderProperties(&props);
+    SetEntryFunction(pEntryFunc);
+    SetEntryFunctionName(EntryName);
+    SetShaderProperties(pFuncProps);
   }
-  if (rawShaderFlags) {
-    m_ShaderFlags.SetShaderFlagsRaw(rawShaderFlags);
-    m_bUseMinPrecision = !m_ShaderFlags.GetUseNativeLowPrecision();
-    m_bDisableOptimizations = m_ShaderFlags.GetDisableOptimizations();
-    m_bAllResourcesBound = m_ShaderFlags.GetAllResourcesBound();
-  }
-  LoadDxilResources(*pResources);
+
+  LoadDxilResources(*pEntryResources);
 
   m_pMDHelper->LoadDxilTypeSystem(*m_pTypeSystem.get());
 

+ 29 - 12
lib/HLSL/DxilPreparePasses.cpp

@@ -133,8 +133,31 @@ INITIALIZE_PASS(DxilDeadFunctionElimination, "dxil-dfe", "Remove all unused func
 
 namespace {
 
-Function *StripFunctionParameter(Function *F, DxilModule &DM,
+static void TransferEntryFunctionAttributes(Function *F, Function *NewFunc) { // F and NewFunc may be the same function (handled in place below).
+  // Keep necessary function attributes: only the DXIL::kFP32DenormKindString function attribute found on F is carried over to NewFunc.
+  AttributeSet attributeSet = F->getAttributes();
+  StringRef attrKind, attrValue; // Both remain empty when F does not have the attribute.
+  if (attributeSet.hasAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString)) {
+    Attribute attribute = attributeSet.getAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString);
+    DXASSERT(attribute.isStringAttribute(), "otherwise we have wrong fp-denorm-mode attribute.");
+    attrKind = attribute.getKindAsString();
+    attrValue = attribute.getValueAsString();
+  }
+  if (F == NewFunc) { // In-place case: remove the attributes captured in attributeSet at the function index before re-adding the kept one.
+    NewFunc->removeAttributes(AttributeSet::FunctionIndex, attributeSet);
+  }
+  if (!attrKind.empty() && !attrValue.empty()) // Nothing is added when the attribute was absent or its value string is empty.
+    NewFunc->addFnAttr(attrKind, attrValue);
+}
+
+static Function *StripFunctionParameter(Function *F, DxilModule &DM,
     DenseMap<const Function *, DISubprogram *> &FunctionDIs) {
+  if (F->arg_empty() && F->getReturnType()->isVoidTy()) {
+    // This will strip non-entry function attributes
+    TransferEntryFunctionAttributes(F, F);
+    return nullptr;
+  }
+
   Module &M = *DM.GetModule();
   Type *VoidTy = Type::getVoidTy(M.getContext());
   FunctionType *FT = FunctionType::get(VoidTy, false);
@@ -152,13 +175,7 @@ Function *StripFunctionParameter(Function *F, DxilModule &DM,
   // Splice the body of the old function right into the new function.
   NewFunc->getBasicBlockList().splice(NewFunc->begin(), F->getBasicBlockList());
 
-  // Keep necessary function attributes
-  AttributeSet attributeSet = F->getAttributes();
-  if (attributeSet.hasAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString)) {
-    Attribute attribute = attributeSet.getAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString);
-    DXASSERT(attribute.isStringAttribute(), "otherwise we have wrong fp-denorm-mode attribute.");
-    NewFunc->addFnAttr(attribute.getKindAsString(), attribute.getValueAsString());
-  }
+  TransferEntryFunctionAttributes(F, NewFunc);
 
   // Patch the pointer to LLVM function in debug info descriptor.
   auto DI = FunctionDIs.find(F);
@@ -361,7 +378,6 @@ private:
             StripFunctionParameter(PatchConstantFunc, DM, FunctionDIs);
         if (PatchConstantFunc) {
           DM.SetPatchConstantFunction(PatchConstantFunc);
-          DM.SetPatchConstantFunctionForHS(DM.GetEntryFunction(), PatchConstantFunc);
         }
       }
 
@@ -388,16 +404,17 @@ private:
       for (Function *entry : entries) {
         DxilFunctionProps &props = DM.GetDxilFunctionProps(entry);
         if (props.IsHS()) {
+          // Strip patch constant function first.
           Function* patchConstFunc = props.ShaderProps.HS.patchConstantFunc;
           auto it = patchConstantUpdates.find(patchConstFunc);
           if (it == patchConstantUpdates.end()) {
             patchConstFunc = patchConstantUpdates[patchConstFunc] =
-              StripFunctionParameter(patchConstFunc, DM, FunctionDIs);
+                StripFunctionParameter(patchConstFunc, DM, FunctionDIs);
           } else {
             patchConstFunc = it->second;
           }
-          // Strip patch constant function first.
-          DM.SetPatchConstantFunctionForHS(entry, patchConstFunc);
+          if (patchConstFunc)
+            DM.SetPatchConstantFunctionForHS(entry, patchConstFunc);
         }
         StripFunctionParameter(entry, DM, FunctionDIs);
       }

+ 8 - 4
lib/HLSL/DxilSignature.cpp

@@ -24,17 +24,21 @@ namespace hlsl {
 // Singnature methods.
 //
 DxilSignature::DxilSignature(DXIL::ShaderKind shaderKind,
-                             DXIL::SignatureKind sigKind, bool useMinPrecision)
+                             DXIL::SignatureKind sigKind,
+                             bool useMinPrecision)
     : m_sigPointKind(SigPoint::GetKind(shaderKind, sigKind,
                                        /*isPatchConstantFunction*/ false,
                                        /*isSpecialInput*/ false)),
       m_UseMinPrecision(useMinPrecision) {}
 
-DxilSignature::DxilSignature(DXIL::SigPointKind sigPointKind)
-: m_sigPointKind(sigPointKind) {}
+DxilSignature::DxilSignature(DXIL::SigPointKind sigPointKind,
+                             bool useMinPrecision)
+    : m_sigPointKind(sigPointKind),
+      m_UseMinPrecision(useMinPrecision) {}
 
 DxilSignature::DxilSignature(const DxilSignature &src)
-    : m_sigPointKind(src.m_sigPointKind) {
+    : m_sigPointKind(src.m_sigPointKind),
+      m_UseMinPrecision(src.m_UseMinPrecision) {
   const bool bSetID = false;
   for (auto &Elt : src.GetElements()) {
     std::unique_ptr<DxilSignatureElement> newElt = CreateElement();

+ 11 - 11
lib/HLSL/DxilValidation.cpp

@@ -4197,9 +4197,9 @@ static void ValidateShaderState(ValidationContext &ValCtx) {
   DXIL::ShaderKind ShaderType = M.GetShaderModel()->GetKind();
 
   if (ShaderType == DXIL::ShaderKind::Compute) {
-    unsigned x = M.m_NumThreads[0];
-    unsigned y = M.m_NumThreads[1];
-    unsigned z = M.m_NumThreads[2];
+    unsigned x = M.GetNumThreads(0);
+    unsigned y = M.GetNumThreads(1);
+    unsigned z = M.GetNumThreads(2);
 
     unsigned threadsInGroup = x * y * z;
 
@@ -4350,6 +4350,14 @@ static void ValidateShaderState(ValidationContext &ValCtx) {
     }
 
     CheckPatchConstantSemantic(ValCtx);
+
+    unsigned outputControlPointCount = M.GetOutputControlPointCount();
+    if (outputControlPointCount > DXIL::kMaxIAPatchControlPointCount) {
+      ValCtx.EmitFormatError(
+        ValidationRule::SmOutputControlPointCountRange,
+        { std::to_string(DXIL::kMaxIAPatchControlPointCount),
+        std::to_string(outputControlPointCount) });
+    }
   } else if (ShaderType == DXIL::ShaderKind::Geometry) {
     unsigned maxVertexCount = M.GetMaxVertexCount();
     if (maxVertexCount > DXIL::kMaxGSOutputVertexCount) {
@@ -4383,14 +4391,6 @@ static void ValidateShaderState(ValidationContext &ValCtx) {
       ValCtx.EmitError(ValidationRule::SmGSValidInputPrimitive);
     }
   }
-
-  unsigned outputControlPointCount = M.GetOutputControlPointCount();
-  if (outputControlPointCount > DXIL::kMaxIAPatchControlPointCount) {
-    ValCtx.EmitFormatError(
-        ValidationRule::SmOutputControlPointCountRange,
-        {std::to_string(DXIL::kMaxIAPatchControlPointCount),
-         std::to_string(outputControlPointCount)});
-  }
 }
 
 static bool