瀏覽代碼

remove-dead-blocks handles switch. Split resource array before finalize. (#3184)

Adam Yang 4 年之前
父節點
當前提交
e6b313b6ab

+ 3 - 0
include/dxc/DXIL/DxilUtil.h

@@ -35,6 +35,8 @@ class PassRegistry;
 class DebugInfoFinder;
 class DebugLoc;
 class DIGlobalVariable;
+class ConstantInt;
+class SwitchInst;
 
 ModulePass *createDxilLoadMetadataPass();
 void initializeDxilLoadMetadataPass(llvm::PassRegistry&);
@@ -113,6 +115,7 @@ namespace dxilutil {
                                       unsigned startOpIdx,
                                       unsigned numOperands);
   bool SimplifyTrivialPHIs(llvm::BasicBlock *BB);
+  llvm::BasicBlock *GetSwitchSuccessorForCond(llvm::SwitchInst *Switch, llvm::ConstantInt *Cond);
   void MigrateDebugValue(llvm::Value *Old, llvm::Value *New);
   void TryScatterDebugValueToVectorElements(llvm::Value *Val);
   std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,

+ 10 - 0
lib/DXIL/DxilUtil.cpp

@@ -430,6 +430,16 @@ bool SimplifyTrivialPHIs(BasicBlock *BB) {
   return Changed;
 }
 
+llvm::BasicBlock *GetSwitchSuccessorForCond(llvm::SwitchInst *Switch,llvm::ConstantInt *Cond) {
+  for (auto it = Switch->case_begin(), end = Switch->case_end(); it != end; it++) {
+    if (it.getCaseValue() == Cond) {
+      return it.getCaseSuccessor();
+      break;
+    }
+  }
+  return Switch->getDefaultDest();
+}
+
 static DbgValueInst *FindDbgValueInst(Value *Val) {
   if (auto *ValAsMD = LocalAsMetadata::getIfExists(Val)) {
     if (auto *ValMDAsVal = MetadataAsValue::getIfExists(Val->getContext(), ValAsMD)) {

+ 116 - 5
lib/HLSL/DxilCondenseResources.cpp

@@ -428,10 +428,123 @@ ModulePass *llvm::createDxilCondenseResourcesPass() {
 
 INITIALIZE_PASS(DxilCondenseResources, "hlsl-dxil-condense", "DXIL Condense Resources", false, false)
 
-static
-bool LegalizeResourcesPHIs(Module &M, DxilValueCache *DVC) {
+static bool GetConstantLegalGepForSplitAlloca(GetElementPtrInst *gep, DxilValueCache *DVC, int64_t *ret) {
+  if (gep->getNumIndices() != 2) {
+    return false;
+  }
+
+  if (ConstantInt *Index0 = dyn_cast<ConstantInt>(gep->getOperand(1))) {
+    if (Index0->getLimitedValue() != 0) {
+      return false;
+    }
+  }
+  else {
+    return false;
+  }
+
+  if (ConstantInt *C = DVC->GetConstInt(gep->getOperand(2))) {
+    int64_t index = C->getSExtValue();
+    *ret = index;
+    return true;
+  }
+
+  return false;
+}
+
+static bool LegalizeResourceArrays(Module &M, DxilValueCache *DVC) {
+  SmallVector<AllocaInst *,16> Allocas;
+
+  bool Changed = false;
+
+  // Find all allocas
+  for (Function &F : M) {
+    if (F.empty())
+      continue;
+
+    BasicBlock &BB = F.getEntryBlock();
+    for (Instruction &I : BB) {
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
+        Type *ty = AI->getAllocatedType();
+        // Only handle single dimentional array. Since this pass runs after MultiDimArrayToOneDimArray,
+        // it should handle all arrays.
+        if (ty->isArrayTy() && hlsl::dxilutil::IsHLSLResourceType(ty->getArrayElementType()))
+          Allocas.push_back(AI);
+      }
+    }
+  }
+
+  SmallVector<AllocaInst *,16> ScalarAllocas;
+  std::unordered_map<GetElementPtrInst *, int64_t> ConstIndices;
+
+  for (AllocaInst *AI : Allocas) {
+    Type *ty = AI->getAllocatedType();
+    Type *resType = ty->getArrayElementType();
+
+    ScalarAllocas.clear();
+    ConstIndices.clear();
+
+    bool SplitAlloca = true;
+
+    for (User *U : AI->users()) {
+      if (GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(U)) {
+        int64_t index = 0;
+        if (!GetConstantLegalGepForSplitAlloca(gep, DVC, &index)) {
+          SplitAlloca = false;
+          break;
+        }
+
+        // Out of bounds. Out of bounds GEP's will trigger and error later.
+        if (index < 0 || index >= (int64_t)ty->getArrayNumElements()) {
+          SplitAlloca = false;
+          Changed = true;
+          dxilutil::EmitErrorOnInstruction(gep, "Accessing resource array with out-out-bounds index.");
+        }
+        ConstIndices[gep] = index;
+      }
+      else {
+        SplitAlloca = false;
+        break;
+      }
+    }
+
+    if (SplitAlloca) {
+
+      IRBuilder<> B(AI);
+      ScalarAllocas.resize(ty->getArrayNumElements());
+
+      for (auto it = AI->user_begin(),end = AI->user_end(); it != end;) {
+        GetElementPtrInst *gep = cast<GetElementPtrInst>(*(it++));
+        assert(ConstIndices.count(gep));
+        int64_t idx = ConstIndices[gep];
+
+        AllocaInst *ScalarAI = ScalarAllocas[idx];
+        if (!ScalarAI) {
+          ScalarAI = B.CreateAlloca(resType);
+          ScalarAllocas[idx] = ScalarAI;
+        }
+
+        gep->replaceAllUsesWith(ScalarAI);
+        gep->eraseFromParent();
+      }
+
+      AI->eraseFromParent();
+
+      Changed = true;
+    }
+  }
+
+  return Changed;
+}
+
+static bool LegalizeResources(Module &M, DxilValueCache *DVC) {
+
+  bool Changed = false;
+
+  Changed |= LegalizeResourceArrays(M, DVC);
+
   // Simple pass to collect resource PHI's
   SmallVector<PHINode *, 8> PHIs;
+
   for (Function &F : M) {
     for (BasicBlock &BB : F) {
       for (Instruction &I : BB) {
@@ -443,12 +556,10 @@ bool LegalizeResourcesPHIs(Module &M, DxilValueCache *DVC) {
         else {
           break;
         }
-
       }
     }
   }
 
-  bool Changed = false;
 
   SmallVector<Instruction *, 8> DCEWorklist;
 
@@ -572,7 +683,7 @@ public:
 
     {
       DxilValueCache *DVC = &getAnalysis<DxilValueCache>();
-      bool bLocalChanged = LegalizeResourcesPHIs(M, DVC);
+      bool bLocalChanged = LegalizeResources(M, DVC);
       if (bLocalChanged) {
         // Remove unused resources.
         DM.RemoveResourcesWithUnusedSymbols();

+ 30 - 3
lib/Transforms/Scalar/DxilRemoveDeadBlocks.cpp

@@ -19,6 +19,7 @@
 #include "llvm/Analysis/DxilValueCache.h"
 
 #include "dxc/DXIL/DxilMetadataHelper.h"
+#include "dxc/DXIL/DxilUtil.h"
 
 #include <unordered_set>
 
@@ -34,7 +35,6 @@ static void RemoveIncomingValueFrom(BasicBlock *SuccBB, BasicBlock *BB) {
   }
 }
 
-
 static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
   std::unordered_set<BasicBlock *> Seen;
   std::vector<BasicBlock *> WorkList;
@@ -88,8 +88,33 @@ static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
       }
     }
     else if (SwitchInst *Switch = dyn_cast<SwitchInst>(BB->getTerminator())) {
-      for (unsigned i = 0; i < Switch->getNumSuccessors(); i++) {
-        Add(Switch->getSuccessor(i));
+      Value *Cond = Switch->getCondition();
+      BasicBlock *Succ = nullptr;
+      if (ConstantInt *ConstCond = DVC->GetConstInt(Cond)) {
+        Succ = hlsl::dxilutil::GetSwitchSuccessorForCond(Switch, ConstCond);
+      }
+
+      if (Succ) {
+        Add(Succ);
+
+        BranchInst *NewBr = BranchInst::Create(Succ, BB);
+        hlsl::DxilMDHelper::CopyMetadata(*NewBr, *Switch);
+
+        for (unsigned i = 0; i < Switch->getNumSuccessors(); i++) {
+          BasicBlock *NotSucc = Switch->getSuccessor(i);
+          if (NotSucc != Succ) {
+            RemoveIncomingValueFrom(NotSucc, BB);
+          }
+        }
+
+        Switch->eraseFromParent();
+        Switch = nullptr;
+        Changed = true;
+      }
+      else {
+        for (unsigned i = 0; i < Switch->getNumSuccessors(); i++) {
+          Add(Switch->getSuccessor(i));
+        }
       }
     }
   }
@@ -145,6 +170,8 @@ static bool EraseDeadBlocks(Function &F, DxilValueCache *DVC) {
     BB->eraseFromParent();
   }
 
+  DVC->ResetUnknowns();
+
   return true;
 }
 

+ 38 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_remove_dead_pass/switch_array.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc /T ps_6_0 /Od %s | FileCheck %s
+
+// CHECK: @main
+
+Texture2D t0 : register(t0);
+Texture2D t1 : register(t1);
+Texture2D t2 : register(t2);
+Texture2D t3 : register(t3);
+Texture2D t4 : register(t4);
+Texture2D t5 : register(t5);
+Texture2D t6 : register(t6);
+
+
+static Texture2D textures[] = {
+  t0, t1, t2, t3, t4, t5, t6,
+};
+
+Texture2D foo(uint i) {
+  switch (i) {
+    case 0:
+      return textures[0];
+    case 1:
+      return textures[1];
+    case 2:
+      return textures[2];
+    case 3:
+      return textures[3];
+    case 4:
+      return textures[4];
+  }
+  return textures[0];
+}
+
+float main(uint3 off : OFF) : SV_Target {
+  return foo(3).Load(off);
+}
+
+

+ 38 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_remove_dead_pass/switch_index.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc /T ps_6_0 /Od %s | FileCheck %s
+
+// CHECK: @main
+
+Texture2D t0 : register(t0);
+Texture2D t1 : register(t1);
+Texture2D t2 : register(t2);
+Texture2D t3 : register(t3);
+Texture2D t4 : register(t4);
+Texture2D t5 : register(t5);
+Texture2D t6 : register(t6);
+
+float foo(uint i) {
+  switch (i) {
+    case 0:
+      return 1;
+    case 1:
+      return 2;
+    case 2:
+      return 3;
+    case 3:
+      return 4;
+    case 4:
+      return 5;
+  }
+  return 0;
+}
+
+float main(uint3 off : OFF) : SV_Target {
+
+  Texture2D textures[] = {
+    t0, t1, t2, t3, t4, t5, t6,
+  };
+
+  return textures[foo(3)].Load(off);
+}
+
+

+ 38 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_remove_dead_pass/switch_index_default.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc /T ps_6_0 /Od %s | FileCheck %s
+
+// CHECK: @main
+
+Texture2D t0 : register(t0);
+Texture2D t1 : register(t1);
+Texture2D t2 : register(t2);
+Texture2D t3 : register(t3);
+Texture2D t4 : register(t4);
+Texture2D t5 : register(t5);
+Texture2D t6 : register(t6);
+
+float foo(uint i) {
+  switch (i) {
+    case 0:
+      return 1;
+    case 1:
+      return 2;
+    case 2:
+      return 3;
+    case 3:
+      return 4;
+    case 4:
+      return 5;
+  }
+  return 0;
+}
+
+float main(uint3 off : OFF) : SV_Target {
+
+  Texture2D textures[] = {
+    t0, t1, t2, t3, t4, t5, t6,
+  };
+
+  return textures[foo(10)].Load(off);
+}
+
+

+ 38 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_remove_dead_pass/switch_oob.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc /T ps_6_0 /Od %s | FileCheck %s
+
+// CHECK: Accessing resource array with out-out-bounds index.
+
+Texture2D t0 : register(t0);
+Texture2D t1 : register(t1);
+Texture2D t2 : register(t2);
+Texture2D t3 : register(t3);
+Texture2D t4 : register(t4);
+Texture2D t5 : register(t5);
+Texture2D t6 : register(t6);
+
+float foo(uint i) {
+  switch (i) {
+    case 0:
+      return 1;
+    case 1:
+      return 2;
+    case 2:
+      return 3;
+    case 3:
+      return 10;
+    case 4:
+      return 5;
+  }
+  return 0;
+}
+
+float4 main(uint3 off : OFF) : SV_Target {
+
+  Texture2D textures[] = {
+    t0, t1, t2, t3, t4, t5, t6,
+  };
+
+  return textures[foo(3)].Load(off);
+}
+
+

+ 37 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_remove_dead_pass/switch_res.hlsl

@@ -0,0 +1,37 @@
+// RUN: %dxc /T ps_6_0 /Od %s | FileCheck %s
+
+// CHECK: @main
+
+Texture2D t0 : register(t0);
+Texture2D t1 : register(t1);
+Texture2D t2 : register(t2);
+Texture2D t3 : register(t3);
+Texture2D t4 : register(t4);
+Texture2D t5 : register(t5);
+Texture2D t6 : register(t6);
+
+Texture2D foo(uint i) {
+  switch (i) {
+    case 0:
+      return t0;
+    case 1:
+      return t1;
+    case 2:
+      return t2;
+    case 3:
+      return t3;
+    case 4:
+      return t4;
+    case 5:
+      return t5;
+    case 6:
+      return t6;
+  }
+  return t0;
+}
+
+float main(uint3 off : OFF) : SV_Target {
+  return foo(5).Load(off);
+}
+
+