Jelajahi Sumber

Add copy function for DxilSignature and DxilTypeSystem. (#359)

* Add copy function for DxilSignature and DxilTypeSystem.
Also make OP::RefreshCache() public.
Xiang Li 8 tahun lalu
induk
melakukan
fba18dd59c

+ 2 - 1
include/dxc/HLSL/DxilOperations.h

@@ -38,6 +38,8 @@ public:
   OP() = delete;
   OP(llvm::LLVMContext &Ctx, llvm::Module *pModule);
 
+  void RefreshCache();
+
   llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
   llvm::ArrayRef<llvm::Function *> GetOpFuncList(OpCode OpCode) const;
   void RemoveFunction(llvm::Function *F);
@@ -108,7 +110,6 @@ private:
   };
   OpCodeCacheItem m_OpCodeClassCache[(unsigned)OpCodeClass::NumOpClasses];
   std::unordered_map<const llvm::Function *, OpCodeClass> m_FunctionToOpClass;
-  void RefreshCache(llvm::Module *pModule);
   void UpdateCache(OpCodeClass opClass, unsigned typeSlot, llvm::Function *F);
 private:
   // Static properties.

+ 13 - 0
include/dxc/HLSL/DxilSignature.h

@@ -26,6 +26,7 @@ public:
 
   DxilSignature(DXIL::ShaderKind shaderKind, DXIL::SignatureKind sigKind);
   DxilSignature(DXIL::SigPointKind sigPointKind);
+  DxilSignature(const DxilSignature &src);
   virtual ~DxilSignature();
 
   bool IsInput() const;
@@ -53,4 +54,16 @@ private:
   std::vector<std::unique_ptr<DxilSignatureElement> > m_Elements;
 };
 
+struct DxilEntrySignature {
+  DxilEntrySignature(DXIL::ShaderKind shaderKind)
+      : InputSignature(shaderKind, DxilSignature::Kind::Input),
+        OutputSignature(shaderKind, DxilSignature::Kind::Output),
+        PatchConstantSignature(shaderKind, DxilSignature::Kind::PatchConstant) {
+  }
+  DxilEntrySignature(const DxilEntrySignature &src);
+  DxilSignature InputSignature;
+  DxilSignature OutputSignature;
+  DxilSignature PatchConstantSignature;
+};
+
 } // namespace hlsl

+ 8 - 0
include/dxc/HLSL/DxilTypeSystem.h

@@ -165,12 +165,14 @@ public:
 
   DxilStructAnnotation *AddStructAnnotation(const llvm::StructType *pStructType);
   DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType);
+  const DxilStructAnnotation *GetStructAnnotation(const llvm::StructType *pStructType) const;
   void EraseStructAnnotation(const llvm::StructType *pStructType);
 
   StructAnnotationMap &GetStructAnnotationMap();
 
   DxilFunctionAnnotation *AddFunctionAnnotation(const llvm::Function *pFunction);
   DxilFunctionAnnotation *GetFunctionAnnotation(const llvm::Function *pFunction);
+  const DxilFunctionAnnotation *GetFunctionAnnotation(const llvm::Function *pFunction) const;
   void EraseFunctionAnnotation(const llvm::Function *pFunction);
 
   FunctionAnnotationMap &GetFunctionAnnotationMap();
@@ -180,6 +182,12 @@ public:
   llvm::StructType *GetSNormF32Type(unsigned NumComps);
   llvm::StructType *GetUNormF32Type(unsigned NumComps);
 
+  // Methods to copy annotation from another DxilTypeSystem.
+  void CopyTypeAnnotation(const llvm::Type *Ty, const DxilTypeSystem &src);
+  void CopyFunctionAnnotation(const llvm::Function *pDstFunction,
+                              const llvm::Function *pSrcFunction,
+                              const DxilTypeSystem &src);
+
 private:
   llvm::Module *m_pModule;
   StructAnnotationMap m_StructAnnotations;

+ 3 - 3
lib/HLSL/DxilOperations.cpp

@@ -422,11 +422,11 @@ OP::OP(LLVMContext &Ctx, Module *pModule)
   Type *Int4Types[4] = { Type::getInt32Ty(m_Ctx), Type::getInt32Ty(m_Ctx), Type::getInt32Ty(m_Ctx), Type::getInt32Ty(m_Ctx) }; // HiHi, HiLo, LoHi, LoLo
   m_pInt4Type = GetOrCreateStructType(m_Ctx, Int4Types, "dx.types.fouri32", pModule);
   // Try to find existing intrinsic function.
-  RefreshCache(pModule);
+  RefreshCache();
 }
 
-void OP::RefreshCache(llvm::Module *pModule) {
-  for (Function &F : pModule->functions()) {
+void OP::RefreshCache() {
+  for (Function &F : m_pModule->functions()) {
     if (OP::IsDxilOpFunc(&F) && !F.user_empty()) {
       CallInst *CI = cast<CallInst>(*F.user_begin());
       OpCode OpCode = OP::GetDxilOpFuncCallInst(CI);

+ 21 - 0
lib/HLSL/DxilSignature.cpp

@@ -28,6 +28,19 @@ DxilSignature::DxilSignature(DXIL::ShaderKind shaderKind, DXIL::SignatureKind si
 DxilSignature::DxilSignature(DXIL::SigPointKind sigPointKind)
 : m_sigPointKind(sigPointKind) {}
 
+DxilSignature::DxilSignature(const DxilSignature &src)
+    : m_sigPointKind(src.m_sigPointKind) {
+  const bool bSetID = false;
+  for (auto &Elt : src.GetElements()) {
+    std::unique_ptr<DxilSignatureElement> newElt = CreateElement();
+    newElt->Initialize(Elt->GetName(), Elt->GetCompType(),
+                       Elt->GetInterpolationMode()->GetKind(), Elt->GetRows(),
+                       Elt->GetCols(), Elt->GetStartRow(), Elt->GetStartCol(),
+                       Elt->GetID(), Elt->GetSemanticIndexVec());
+    AppendElement(std::move(newElt), bSetID);
+  }
+}
+
 DxilSignature::~DxilSignature() {
 }
 
@@ -199,6 +212,14 @@ unsigned DxilSignature::PackElements(DXIL::PackingStrategy packing) {
   return rowsUsed;
 }
 
+//------------------------------------------------------------------------------
+//
+// EntrySingnature methods.
+//
+DxilEntrySignature::DxilEntrySignature(const DxilEntrySignature &src)
+    : InputSignature(src.InputSignature), OutputSignature(src.OutputSignature),
+      PatchConstantSignature(src.PatchConstantSignature) {}
+
 } // namespace hlsl
 
 #include <algorithm>

+ 72 - 0
lib/HLSL/DxilTypeSystem.cpp

@@ -182,6 +182,16 @@ DxilStructAnnotation *DxilTypeSystem::GetStructAnnotation(const StructType *pStr
   }
 }
 
+const DxilStructAnnotation *
+DxilTypeSystem::GetStructAnnotation(const StructType *pStructType) const {
+  auto it = m_StructAnnotations.find(pStructType);
+  if (it != m_StructAnnotations.end()) {
+    return it->second.get();
+  } else {
+    return nullptr;
+  }
+}
+
 void DxilTypeSystem::EraseStructAnnotation(const StructType *pStructType) {
   DXASSERT_NOMSG(m_StructAnnotations.count(pStructType));
   m_StructAnnotations.remove_if([pStructType](
@@ -211,6 +221,16 @@ DxilFunctionAnnotation *DxilTypeSystem::GetFunctionAnnotation(const Function *pF
   }
 }
 
+const DxilFunctionAnnotation *
+DxilTypeSystem::GetFunctionAnnotation(const Function *pFunction) const {
+  auto it = m_FunctionAnnotations.find(pFunction);
+  if (it != m_FunctionAnnotations.end()) {
+    return it->second.get();
+  } else {
+    return nullptr;
+  }
+}
+
 void DxilTypeSystem::EraseFunctionAnnotation(const Function *pFunction) {
   DXASSERT_NOMSG(m_FunctionAnnotations.count(pFunction));
   m_FunctionAnnotations.remove_if([pFunction](
@@ -253,6 +273,58 @@ StructType *DxilTypeSystem::GetNormFloatType(CompType CT, unsigned NumComps) {
   return pStructType;
 }
 
+void DxilTypeSystem::CopyTypeAnnotation(const llvm::Type *Ty,
+                                        const DxilTypeSystem &src) {
+  if (isa<PointerType>(Ty))
+    Ty = Ty->getPointerElementType();
+
+  while (isa<ArrayType>(Ty))
+    Ty = Ty->getArrayElementType();
+
+  // Only struct type has annotation.
+  if (!isa<StructType>(Ty))
+    return;
+
+  const StructType *ST = cast<StructType>(Ty);
+  // Already exist.
+  if (GetStructAnnotation(ST))
+    return;
+
+  if (const DxilStructAnnotation *annot = src.GetStructAnnotation(ST)) {
+    DxilStructAnnotation *dstAnnot = AddStructAnnotation(ST);
+    // Copy the annotation.
+    *dstAnnot = *annot;
+    // Copy field type annotations.
+    for (Type *Ty : ST->elements()) {
+      CopyTypeAnnotation(Ty, src);
+    }
+  }
+}
+
+void DxilTypeSystem::CopyFunctionAnnotation(const llvm::Function *pDstFunction,
+                                            const llvm::Function *pSrcFunction,
+                                            const DxilTypeSystem &src) {
+  const DxilFunctionAnnotation *annot = src.GetFunctionAnnotation(pSrcFunction);
+  // Don't have annotation.
+  if (!annot)
+    return;
+  // Already exist.
+  if (GetFunctionAnnotation(pDstFunction))
+    return;
+
+  DxilFunctionAnnotation *dstAnnot = AddFunctionAnnotation(pDstFunction);
+
+  // Copy the annotation.
+  *dstAnnot = *annot;
+
+  // Clone ret type annotation.
+  CopyTypeAnnotation(pDstFunction->getReturnType(), src);
+  // Clone param type annotations.
+  for (const Argument &arg : pDstFunction->args()) {
+    CopyTypeAnnotation(arg.getType(), src);
+  }
+}
+
 DXIL::SigPointKind SigPointFromInputQual(DxilParamInputQual Q, DXIL::ShaderKind SK, bool isPC) {
   DXASSERT(Q != DxilParamInputQual::Inout, "Inout not expected for SigPointFromInputQual");
   switch (SK) {