Explorar el Código

Resolve name collison for dxil types in BitcodeReader. (#411)

Resolve name collison for dxil types in BitcodeReader.
Xiang Li hace 8 años
padre
commit
f0c3cbd454

+ 5 - 0
include/dxc/HLSL/DxilOperations.h

@@ -15,6 +15,7 @@ namespace llvm {
 class LLVMContext;
 class Module;
 class Type;
+class StructType;
 class Function;
 class Constant;
 class Value;
@@ -86,6 +87,9 @@ public:
   static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
   static bool IsDxilOpWave(OpCode C);
   static bool IsDxilOpGradient(OpCode C);
+  static bool IsDupDxilOpType(llvm::StructType *ST);
+  static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
+                                                 llvm::Module &M);
 
 private:
   // Per-module properties.
@@ -125,6 +129,7 @@ private:
 
   static const char *m_OverloadTypeName[kNumTypeOverloads];
   static const char *m_NamePrefix;
+  static const char *m_TypePrefix;
   static unsigned GetTypeSlot(llvm::Type *pType);
   static const char *GetOverloadTypeName(unsigned TypeSlot);
 };

+ 11 - 0
lib/Bitcode/Reader/BitcodeReader.cpp

@@ -34,6 +34,7 @@
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
 #include <deque>
+#include "dxc/HLSL/DxilOperations.h"   // HLSL Change
 using namespace llvm;
 
 namespace {
@@ -1483,6 +1484,9 @@ std::error_code BitcodeReader::parseTypeTableBody() {
         TypeList[NumRecords] = nullptr;
       } else  // Otherwise, create a new struct.
         Res = createIdentifiedStructType(Context, TypeName);
+      // HLSL Change Begin - avoid name collision for dxil types.
+      bool bNameCollision = Res->getName().size() > TypeName.size();
+      // HLSL Change End.
       TypeName.clear();
 
       SmallVector<Type*, 8> EltTys;
@@ -1495,6 +1499,13 @@ std::error_code BitcodeReader::parseTypeTableBody() {
       if (EltTys.size() != Record.size()-1)
         return error("Invalid record");
       Res->setBody(EltTys, Record[0]);
+      // HLSL Change Begin - avoid name collision for dxil types.
+      if (bNameCollision) {
+        if (hlsl::OP::IsDupDxilOpType(Res)) {
+          Res = hlsl::OP::GetOriginalDxilOpType(Res, *TheModule);
+        }
+      }
+      // HLSL Change End.
       ResultTy = Res;
       break;
     }

+ 25 - 0
lib/HLSL/DxilOperations.cpp

@@ -252,6 +252,7 @@ const char *OP::m_OverloadTypeName[kNumTypeOverloads] = {
 };
 
 const char *OP::m_NamePrefix = "dx.op.";
+const char *OP::m_TypePrefix = "dx.types.";
 
 // Keep sync with DXIL::AtomicBinOpCode
 static const char *AtomicBinOpCodeName[] = {
@@ -340,6 +341,30 @@ bool OP::IsDxilOpFunc(const llvm::Function *F) {
   return IsDxilOpFuncName(F->getName());
 }
 
+bool OP::IsDupDxilOpType(llvm::StructType *ST) {
+  if (!ST->hasName())
+    return false;
+  StringRef Name = ST->getName();
+  if (!Name.startswith(m_TypePrefix))
+    return false;
+  size_t DotPos = Name.rfind('.');
+  if (DotPos == 0 || DotPos == StringRef::npos || Name.back() == '.' ||
+      !isdigit(static_cast<unsigned char>(Name[DotPos + 1])))
+    return false;
+  return true;
+}
+
+StructType *OP::GetOriginalDxilOpType(llvm::StructType *ST, llvm::Module &M) {
+  DXASSERT(IsDupDxilOpType(ST), "else should not call GetOriginalDxilOpType");
+  StringRef Name = ST->getName();
+  size_t DotPos = Name.rfind('.');
+  StructType *OriginalST = M.getTypeByName(Name.substr(0, DotPos));
+  DXASSERT(OriginalST, "else name collison without original type");
+  DXASSERT(ST->isLayoutIdentical(OriginalST),
+           "else invalid layout for dxil types");
+  return OriginalST;
+}
+
 bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I) {
   const CallInst *CI = dyn_cast<CallInst>(I);
   if (CI == nullptr) return false;

+ 5 - 1
tools/clang/test/CodeGenHLSL/lib_cs_entry.hlsl

@@ -17,6 +17,10 @@
 // Make sure function props is correct for [numthreads(8,8,1)].
 // CHECK: @entry, i32 5, i32 8, i32 8, i32 1
 
+cbuffer A {
+  float a;
+}
+
 void StoreOutputMat(float2x2  m, uint gidx);
 float2x2 LoadInputMat(uint x, uint y);
 float2x2 RotateMat(float2x2 m, uint x, uint y);
@@ -26,7 +30,7 @@ void entry( uint2 tid : SV_DispatchThreadID, uint2 gid : SV_GroupID, uint2 gtid
 {
     float2x2 f2x2 = LoadInputMat(gid.x, gid.y);
 
-    f2x2 = RotateMat(f2x2, tid.x, tid.y);
+    f2x2 = RotateMat(f2x2, tid.x, tid.y) + a;
 
     StoreOutputMat(f2x2, gidx);
 }

+ 5 - 1
tools/clang/test/CodeGenHLSL/lib_resource2.hlsl

@@ -21,11 +21,15 @@ float2x2 LoadInputMat(uint x, uint y) {
   return mats.Load(x).f2x2 + mats2.Load(y);
 }
 
+cbuffer B {
+  float b;
+}
+
 groupshared column_major float2x2 dataC[8*8];
 
 float2x2 RotateMat(float2x2 m, uint x, uint y) {
     dataC[x%(8*8)] = m;
     GroupMemoryBarrierWithGroupSync();
     float2x2 f2x2 = dataC[8*8-1-y%(8*8)];
-    return f2x2;
+    return f2x2 + b;
 }

+ 14 - 6
tools/clang/tools/dxlib-sample/dxlib_sample.cpp

@@ -19,6 +19,7 @@
 #include "dxc/dxctools.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "llvm/Support/Path.h"
+#include <cwchar>
 
 using namespace hlsl;
 using namespace llvm;
@@ -135,6 +136,11 @@ HRESULT IncludeToLibPreprocessor::Preprocess(IDxcBlob *pSource,
 
   HRESULT status;
   if (!SUCCEEDED(pRewriteResult->GetStatus(&status)) || !SUCCEEDED(status)) {
+    CComPtr<IDxcBlobEncoding> pErr;
+    IFT(pRewriteResult->GetErrorBuffer(&pErr));
+    std::string errString =
+        std::string((char *)pErr->GetBufferPointer(), pErr->GetBufferSize());
+    IFTMSG(E_FAIL, errString);
     return E_FAIL;
   };
   // Append existing header.
@@ -226,20 +232,22 @@ private:
   std::shared_mutex m_mutex;
 };
 
+static hash_code CombineWStr(hash_code hash, LPCWSTR Arg) {
+  unsigned length = std::wcslen(Arg)*2;
+  return hash_combine(hash, StringRef((char*)(Arg), length));
+}
+
 hash_code LibCacheManager::GetHash(IDxcBlob *pSource, CompileInput &compiler) {
   hash_code libHash = hash_value(
       StringRef((char *)pSource->GetBufferPointer(), pSource->GetBufferSize()));
   // Combine compile input.
   for (auto &Arg : compiler.arguments) {
-    CW2A pUtf8Arg(Arg, CP_UTF8);
-    libHash = hash_combine(libHash, pUtf8Arg.m_psz);
+    libHash = CombineWStr(libHash, Arg);
   }
   for (auto &Define : compiler.defines) {
-    CW2A pUtf8Name(Define.Name, CP_UTF8);
-    libHash = hash_combine(libHash, pUtf8Name.m_psz);
+    libHash = CombineWStr(libHash, Define.Name);
     if (Define.Value) {
-      CW2A pUtf8Value(Define.Value, CP_UTF8);
-      libHash = hash_combine(libHash, pUtf8Value.m_psz);
+      libHash = CombineWStr(libHash, Define.Value);
     }
   }
   return libHash;