ソースを参照

Merge types in linker, and recognize class.matrix.* type for library. (#1033)

Tex Riddell 8 年 前
コミット
c225dead63

+ 2 - 0
include/dxc/HLSL/DxilOperations.h

@@ -94,6 +94,7 @@ public:
   static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
   static bool IsDxilOpWave(OpCode C);
   static bool IsDxilOpGradient(OpCode C);
+  static bool IsDxilOpTypeName(llvm::StringRef name);
   static bool IsDxilOpType(llvm::StructType *ST);
   static bool IsDupDxilOpType(llvm::StructType *ST);
   static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
@@ -140,6 +141,7 @@ private:
   static const char *m_OverloadTypeName[kNumTypeOverloads];
   static const char *m_NamePrefix;
   static const char *m_TypePrefix;
+  static const char *m_MatrixTypePrefix;
   static unsigned GetTypeSlot(llvm::Type *pType);
   static const char *GetOverloadTypeName(unsigned TypeSlot);
 };

+ 5 - 3
lib/Bitcode/Reader/BitcodeReader.cpp

@@ -1486,8 +1486,8 @@ std::error_code BitcodeReader::parseTypeTableBody() {
         Res = createIdentifiedStructType(Context, TypeName);
       // HLSL Change Begin - avoid name collision for dxil types.
       bool bNameCollision = Res->getName().size() > TypeName.size();
+      //TypeName.clear();
       // HLSL Change End.
-      TypeName.clear();
 
       SmallVector<Type*, 8> EltTys;
       for (unsigned i = 1, e = Record.size(); i != e; ++i) {
@@ -1501,10 +1501,12 @@ std::error_code BitcodeReader::parseTypeTableBody() {
       Res->setBody(EltTys, Record[0]);
       // HLSL Change Begin - avoid name collision for dxil types.
       if (bNameCollision) {
-        if (hlsl::OP::IsDupDxilOpType(Res)) {
-          Res = hlsl::OP::GetOriginalDxilOpType(Res, *TheModule);
+        StructType *otherType = TheModule->getTypeByName(TypeName);
+        if (otherType->isLayoutIdentical(Res)) {
+          Res = otherType;
         }
       }
+      TypeName.clear();
       // HLSL Change End.
       ResultTy = Res;
       break;

+ 7 - 2
lib/HLSL/DxilOperations.cpp

@@ -273,6 +273,7 @@ const char *OP::m_OverloadTypeName[kNumTypeOverloads] = {
 
 const char *OP::m_NamePrefix = "dx.op.";
 const char *OP::m_TypePrefix = "dx.types.";
+const char *OP::m_MatrixTypePrefix = "class.matrix."; // Allowed in library
 
 // Keep sync with DXIL::AtomicBinOpCode
 static const char *AtomicBinOpCodeName[] = {
@@ -363,18 +364,22 @@ bool OP::IsDxilOpFunc(const llvm::Function *F) {
   return IsDxilOpFuncName(F->getName());
 }
 
+bool OP::IsDxilOpTypeName(StringRef name) {
+  return name.startswith(m_TypePrefix) || name.startswith(m_MatrixTypePrefix);
+}
+
 bool OP::IsDxilOpType(llvm::StructType *ST) {
   if (!ST->hasName())
     return false;
   StringRef Name = ST->getName();
-  return Name.startswith(m_TypePrefix);
+  return IsDxilOpTypeName(Name);
 }
 
 bool OP::IsDupDxilOpType(llvm::StructType *ST) {
   if (!ST->hasName())
     return false;
   StringRef Name = ST->getName();
-  if (!Name.startswith(m_TypePrefix))
+  if (!IsDxilOpTypeName(Name))
     return false;
   size_t DotPos = Name.rfind('.');
   if (DotPos == 0 || DotPos == StringRef::npos || Name.back() == '.' ||