Ver código fonte

Merged PR 25: Fix res IDs after pruning in DxilLowerCreateHandleForLib

Fix res IDs after pruning in DxilLowerCreateHandleForLib
Tex Riddell 7 anos atrás
pai
commit
62ab8fc3d2

+ 20 - 10
lib/HLSL/DxilCondenseResources.cpp

@@ -463,22 +463,29 @@ public:
     m_DM->ClearLLVMUsed();
     m_bIsLib = DM.GetShaderModel()->IsLib();
 
+    bool bChanged = false;
+    unsigned numResources = DM.GetCBuffers().size() + DM.GetUAVs().size() +
+                            DM.GetSRVs().size() + DM.GetSamplers().size();
+
+    if (!numResources)
+      return false;
+
     // Switch tbuffers to SRVs, as they have been treated as cbuffers up to this
     // point.
     if (DM.GetCBuffers().size())
-      PatchTBuffers(DM);
+      bChanged = PatchTBuffers(DM) || bChanged;
 
     // Remove unused resource.
     DM.RemoveUnusedResourceSymbols();
 
-    bool hasResource = DM.GetCBuffers().size() || DM.GetUAVs().size() ||
-                       DM.GetSRVs().size() || DM.GetSamplers().size();
+    unsigned newResources = DM.GetCBuffers().size() + DM.GetUAVs().size() +
+                            DM.GetSRVs().size() + DM.GetSamplers().size();
+    bChanged = bChanged || (numResources != newResources);
 
-    if (!hasResource || m_bIsLib)
-      return false;
+    if (0 == newResources || m_bIsLib)
+      return bChanged;
 
-    BuildRewriteMap(m_rewrites, DM);
-    ApplyRewriteMapOnResTable(m_rewrites, DM);
+    bChanged = true;
 
     // Load up debug information, to cross-reference values and the instructions
     // used to load them.
@@ -498,7 +505,7 @@ public:
     dxilutil::RemoveUnusedFunctions(M, DM.GetEntryFunction(),
                                     DM.GetPatchConstantFunction(), m_bIsLib);
 
-    return true;
+    return bChanged;
   }
 
 private:
@@ -508,7 +515,7 @@ private:
   void AddCreateHandleForPhiNodeAndSelect(OP *hlslOP);
   void UpdateStructTypeForLegacyLayout();
   // Switch CBuffer for SRV for TBuffers.
-  void PatchTBuffers(DxilModule &DM);
+  bool PatchTBuffers(DxilModule &DM);
   void PatchTBufferUse(Value *V, DxilModule &DM);
 };
 
@@ -1032,7 +1039,8 @@ void DxilLowerCreateHandleForLib::PatchTBufferUse(Value *V, DxilModule &DM) {
   }
 }
 
-void DxilLowerCreateHandleForLib::PatchTBuffers(DxilModule &DM) {
+bool DxilLowerCreateHandleForLib::PatchTBuffers(DxilModule &DM) {
+  bool bChanged = false;
   // move tbuffer resources to SRVs
   unsigned offset = DM.GetSRVs().size();
   Module &M = *DM.GetModule();
@@ -1054,8 +1062,10 @@ void DxilLowerCreateHandleForLib::PatchTBuffers(DxilModule &DM) {
           /*InsertBefore*/ nullptr, GV->getThreadLocalMode(),
           GV->getType()->getAddressSpace(), GV->isExternallyInitialized());
       CB->SetGlobalSymbol(NewGV);
+      bChanged = true;
     }
   }
+  return bChanged;
 }
 
 // Select on handle.

+ 6 - 0
lib/HLSL/DxilModule.cpp

@@ -998,6 +998,7 @@ void DxilModule::RemoveUnusedResources() {
 namespace {
 template <typename TResource>
 static void RemoveResourceSymbols(std::vector<std::unique_ptr<TResource>> &vec) {
+  unsigned resID = 0;
   for (std::vector<std::unique_ptr<TResource>>::iterator p = vec.begin(); p != vec.end();) {
     std::vector<std::unique_ptr<TResource>>::iterator c = p++;
     GlobalVariable *GV = cast<GlobalVariable>((*c)->GetGlobalSymbol());
@@ -1005,7 +1006,12 @@ static void RemoveResourceSymbols(std::vector<std::unique_ptr<TResource>> &vec)
     if (GV->user_empty()) {
       p = vec.erase(c);
       GV->eraseFromParent();
+      continue;
+    }
+    if ((*c)->GetID() != resID) {
+      (*c)->SetID(resID);
     }
+    resID++;
   }
 }
 }

+ 22 - 0
tools/clang/test/CodeGenHLSL/quick-test/lib_remove_res.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -T lib_6_3 %s | FileCheck %s
+
+// Ensure UnusedBuffer is removed:
+// CHECK-NOT: @"\01?UnusedBuffer@@3UByteAddressBuffer@@A"
+
+// Ensure resource ID is 0 for ReadBuffer1 after UnusedBuffer global is removed.
+// CHECK: !{i32 0, %struct.ByteAddressBuffer* @"\01?ReadBuffer1@@3UByteAddressBuffer@@A", !"ReadBuffer1",
+
+RWByteAddressBuffer outputBuffer : register(u0);
+ByteAddressBuffer UnusedBuffer : register(t0);
+ByteAddressBuffer ReadBuffer1 : register(t1);
+
+void test()
+{
+  ByteAddressBuffer buffer = UnusedBuffer;
+
+  if (true)
+     buffer = ReadBuffer1;
+
+  uint v = buffer.Load(0);
+  outputBuffer.Store(0, v);
+}