Просмотр исходного кода

Unroll after SROA. Breaking apart resource arrays inside unroll pass (#1883)

Adam Yang 6 лет назад
Родитель
Сommit
830cbe0113

+ 14 - 9
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -219,10 +219,6 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, hlsl::HLSLExten
     MPM.add(createHLDeadFunctionEliminationPass());
   }
 
-  // Passes to handle [unroll]
-  MPM.add(createLoopRotatePass());
-  MPM.add(createDxilLoopUnrollPass(/*MaxIterationAttempt*/ 128));
-
   // Split struct and array of parameter.
   MPM.add(createSROA_Parameter_HLSL());
 
@@ -252,15 +248,24 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, hlsl::HLSLExten
     MPM.add(createDxilConvergentMarkPass());
   }
 
-  if (OptLevel > 2) {
-    MPM.add(createLoopRotatePass());
-    MPM.add(createLoopUnrollPass());
-  }
-
   MPM.add(createSimplifyInstPass());
 
   MPM.add(createCFGSimplificationPass());
 
+  // Passes to handle [unroll]
+  // Needs to happen after SROA since loop count may depend on
+  // struct members.
+  // Needs to happen before resources are lowered and before HL
+  // module is gone.
+  MPM.add(createLoopRotatePass());
+  MPM.add(createDxilLoopUnrollPass(/*MaxIterationAttempt*/ 128));
+
+  // Default unroll pass. This is purely for optimizing loops without
+  // attributes.
+  if (OptLevel > 2) {
+    MPM.add(createLoopUnrollPass());
+  }
+
   MPM.add(createDxilPromoteLocalResources());
   MPM.add(createDxilPromoteStaticResources());
   // Verify no undef resource again after promotion

+ 173 - 13
lib/Transforms/Scalar/DxilLoopUnroll.cpp

@@ -120,7 +120,6 @@ public:
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<LoopInfoWrapperPass>();
-    AU.addPreserved<LoopInfoWrapperPass>();
     AU.addRequiredID(LoopSimplifyID);
     AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
@@ -130,18 +129,15 @@ public:
 
 char DxilLoopUnroll::ID;
 
-static void FailLoopUnroll(bool WarnOnly, Loop *L, const char *Message) {
-  DebugLoc DL = L->getStartLoc();
-  LLVMContext &Ctx = L->getHeader()->getContext();
-
+static void FailLoopUnroll(bool WarnOnly, LLVMContext &Ctx, DebugLoc DL, const char *Message) {
   if (WarnOnly) {
-    if (DL.get())
+    if (DL)
       Ctx.emitWarning(hlsl::dxilutil::FormatMessageAtLocation(DL, Message));
     else
       Ctx.emitWarning(hlsl::dxilutil::FormatMessageWithoutLocation(Message));
   }
   else {
-    if (DL.get())
+    if (DL)
       Ctx.emitError(hlsl::dxilutil::FormatMessageAtLocation(DL, Message));
     else
       Ctx.emitError(hlsl::dxilutil::FormatMessageWithoutLocation(Message));
@@ -435,7 +431,34 @@ static bool CreateLCSSA(SetVector<BasicBlock *> &Body, const SmallVectorImpl<Bas
   return Changed;
 }
 
-static void FindProblemBlocks(BasicBlock *Header, const SmallVectorImpl<BasicBlock *> &BlocksInLoop, std::unordered_set<BasicBlock *> &ProblemBlocks) {
+static Value *GetGEPPtrOrigin(GEPOperator *GEP) {
+  Value *Ptr = GEP->getPointerOperand();
+  while (Ptr) {
+    if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
+      return AI;
+    }
+    else if (GEPOperator *NewGEP = dyn_cast<GEPOperator>(Ptr)) {
+      Ptr = NewGEP->getPointerOperand();
+    }
+    else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
+      return GV;
+    }
+    else {
+      break;
+    }
+  }
+  return nullptr;
+}
+
+// Find all blocks in the loop with instructions that
+// would require an unroll to be correct.
+//
+// For example:
+// for (int i = 0; i < 10; i++) {
+//   gep i
+// }
+//
+static void FindProblemBlocks(BasicBlock *Header, const SmallVectorImpl<BasicBlock *> &BlocksInLoop, std::unordered_set<BasicBlock *> &ProblemBlocks, SetVector<AllocaInst *> &ProblemAllocas) {
   SmallVector<Instruction *, 16> WorkList;
 
   std::unordered_set<BasicBlock *> BlocksInLoopSet(BlocksInLoop.begin(), BlocksInLoop.end());
@@ -466,7 +489,16 @@ static void FindProblemBlocks(BasicBlock *Header, const SmallVectorImpl<BasicBlo
       // problem.
       //
       if (hlsl::dxilutil::IsHLSLObjectType(EltType)) {
-        ProblemBlocks.insert(GEP->getParent());
+        if (Value *Ptr = GetGEPPtrOrigin(cast<GEPOperator>(GEP))) {
+          if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
+            if (!GV->isExternalLinkage(llvm::GlobalValue::ExternalLinkage))
+              ProblemBlocks.insert(GEP->getParent());
+          }
+          else if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
+            ProblemAllocas.insert(AI);
+            ProblemBlocks.insert(GEP->getParent());
+          }
+        }
         continue; // Stop Propagating
       }
     }
@@ -484,6 +516,114 @@ static void FindProblemBlocks(BasicBlock *Header, const SmallVectorImpl<BasicBlo
   }
 }
 
+// Helper function for getting GEP's const index value
+inline static int64_t GetGEPIndex(GEPOperator *GEP, unsigned idx) {
+  return cast<ConstantInt>(GEP->getOperand(idx + 1))->getSExtValue();
+} 
+
+// Replace allocas with all constant indices with scalar allocas, then promote
+// them to values where possible (mem2reg).
+//
+// Before loop unroll, we did not have constant indices for arrays and SROA was
+// unable to break them into scalars. Now that unroll has potentially given
+// them constant values, we need to turn them into scalars.
+//
+// if "AllowOOBIndex" is true, it turns any out of bound index into 0.
+// Otherwise it emits an error and fails compilation.
+//
+template<typename IteratorT>
+static bool BreakUpArrayAllocas(bool AllowOOBIndex, IteratorT ItBegin, IteratorT ItEnd, DominatorTree *DT, AssumptionCache *AC) { 
+  bool Success = true;
+
+  SmallVector<AllocaInst *, 8> WorkList(ItBegin, ItEnd);
+
+  SmallVector<GEPOperator *, 16> GEPs;
+  while (WorkList.size()) {
+    AllocaInst *AI = WorkList.pop_back_val();
+
+    Type *AllocaType = AI->getAllocatedType();
+
+    // Only deal with array allocas.
+    if (!AllocaType->isArrayTy())
+      continue;
+
+    unsigned ArraySize = AI->getAllocatedType()->getArrayNumElements();
+    Type *ElementType = AllocaType->getArrayElementType();
+    if (!ArraySize)
+      continue;
+
+    GEPs.clear(); // Re-use array
+    for (User *U : AI->users()) {
+      if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
+        if (!GEP->hasAllConstantIndices() || GEP->getNumIndices() < 2 ||
+          GetGEPIndex(GEP, 0) != 0)
+        {
+          GEPs.clear();
+          break;
+        }
+        else {
+          GEPs.push_back(GEP);
+        }
+      }
+      else {
+        GEPs.clear();
+        break;
+      }
+    }
+
+    if (!GEPs.size())
+      continue;
+
+    SmallVector<AllocaInst *, 8> ScalarAllocas;
+    ScalarAllocas.resize(ArraySize);
+
+    IRBuilder<> B(AI);
+    for (GEPOperator *GEP : GEPs) {
+      int64_t idx = GetGEPIndex(GEP, 1);
+      GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP);
+
+      if (idx < 0 || idx >= ArraySize) {
+        if (AllowOOBIndex)
+          idx = 0;
+        else {
+          Success = false;
+          if (GEPInst)
+            hlsl::dxilutil::EmitErrorOnInstruction(GEPInst, "Array access out of bound.");
+          continue;
+        }
+      } 
+      AllocaInst *ScalarAlloca = ScalarAllocas[idx];
+      if (!ScalarAlloca) {
+        ScalarAlloca = B.CreateAlloca(ElementType);
+        ScalarAllocas[idx] = ScalarAlloca;
+        if (ElementType->isArrayTy()) {
+          WorkList.push_back(ScalarAlloca);
+        }
+      }
+      Value *NewPointer = nullptr;
+      if (ElementType->isArrayTy()) {
+        SmallVector<Value *, 2> Indices;
+        Indices.push_back(B.getInt32(0));
+        for (unsigned i = 2; i < GEP->getNumIndices(); i++) {
+          Indices.push_back(GEP->getOperand(i + 1));
+        }
+        NewPointer = B.CreateGEP(ScalarAlloca, Indices);
+      } else {
+        NewPointer = ScalarAlloca;
+      }
+
+      GEP->replaceAllUsesWith(NewPointer);
+    } 
+
+    if (!ElementType->isArrayTy()) {
+      std::remove(ScalarAllocas.begin(), ScalarAllocas.end(), nullptr);
+      PromoteMemToReg(ScalarAllocas, *DT, nullptr, AC);
+    }
+  }
+
+  return Success;
+}
+
 static bool ContainsFloatingPointType(Type *Ty) {
   if (Ty->isFloatingPointTy()) {
     return true;
@@ -545,11 +685,12 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   if (!L->isSafeToClone())
     return false;
 
+  DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
   Function *F = L->getHeader()->getParent();
-  bool OnlyWarnOnFail = false;
+  bool FxcCompatMode = false;
   if (F->getParent()->HasHLModule()) {
     HLModule &HM = F->getParent()->GetHLModule();
-    OnlyWarnOnFail = HM.GetHLOptions().bFXCCompatMode;
+    FxcCompatMode = HM.GetHLOptions().bFXCCompatMode;
   }
 
   // Analysis passes
@@ -589,8 +730,9 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   BlocksInLoop.append(ExitBlocks.begin(), ExitBlocks.end());
 
   // Heuristically find blocks that likely need to be unrolled
+  SetVector<AllocaInst *> ProblemAllocas;
   std::unordered_set<BasicBlock *> ProblemBlocks;
-  FindProblemBlocks(L->getHeader(), BlocksInLoop, ProblemBlocks);
+  FindProblemBlocks(L->getHeader(), BlocksInLoop, ProblemBlocks, ProblemAllocas);
 
   // Keep track of the PHI nodes at the header.
   SmallVector<PHINode *, 16> PHIs;
@@ -810,6 +952,12 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   }
 
   if (Succeeded) {
+    // We are going to be cleaning them up later. Maker sure
+    // they're in entry block so deleting loop blocks don't 
+    // kill them too.
+    for (AllocaInst *AI : ProblemAllocas)
+      DXASSERT(AI->getParent() == &F->getEntryBlock(), "Alloca is not in entry block.");
+
     LoopIteration &FirstIteration = *Iterations.front().get();
     // Make the predecessor branch to the first new header.
     {
@@ -822,6 +970,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     }
 
     if (OuterL) {
+
       // Core body blocks need to be added to outer loop
       for (size_t i = 0; i < Iterations.size(); i++) {
         LoopIteration &Iteration = *Iterations[i].get();
@@ -862,18 +1011,29 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     for (BasicBlock *BB : ToBeCloned)
       BB->eraseFromParent();
 
+    // Blocks need to be removed from DomTree. There's no easy way
+    // to remove them in the right order, so just make DomTree
+    // recalculate.
+    DT->recalculate(*F);
+
     if (OuterL) {
       // This process may have created multiple back edges for the
       // parent loop. Simplify to keep it well-formed.
       simplifyLoop(OuterL, DT, LI, this, nullptr, nullptr, AC);
     }
 
+    // Now that we potentially turned some GEP indices into constants,
+    // try to clean up their allocas.
+    if (!BreakUpArrayAllocas(FxcCompatMode /* allow oob index */, ProblemAllocas.begin(), ProblemAllocas.end(), DT, AC)) {
+      FailLoopUnroll(false, F->getContext(), LoopLoc, "Could not unroll loop due to out of bound array access.");
+    }
+
     return true;
   }
 
   // If we were unsuccessful in unrolling the loop
   else {
-    FailLoopUnroll(OnlyWarnOnFail, L, "Could not unroll loop.");
+    FailLoopUnroll(FxcCompatMode /*warn only*/, F->getContext(), LoopLoc, "Could not unroll loop.");
 
     // Remove all the cloned blocks
     for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {

+ 28 - 0
tools/clang/test/CodeGenHLSL/unroll/2d_array.hlsl

@@ -0,0 +1,28 @@
+// RUN: %dxc -Od -E main -T ps_6_0 %s | FileCheck %s
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK-NOT: call i32 @dx.op.bufferUpdateCounter
+
+AppendStructuredBuffer<float4> buf0;
+AppendStructuredBuffer<float4> buf1;
+AppendStructuredBuffer<float4> buf2;
+AppendStructuredBuffer<float4> buf3;
+uint g_cond;
+
+float main() : SV_Target {
+
+  AppendStructuredBuffer<float4> buffers[2][2] = { buf0, buf1, buf2, buf3, };
+
+  [unroll]
+  for (uint j = 0; j < 4; j++) {
+    if (g_cond == j) {
+      buffers[j/2][j%2].Append(1);
+      return 10;
+    }
+  }
+
+  return 0;
+}
+

+ 24 - 0
tools/clang/test/CodeGenHLSL/unroll/extern.hlsl

@@ -0,0 +1,24 @@
+// RUN: %dxc -T lib_6_3 %s | FileCheck %s
+
+// Global array with external linkage does not need constant indexing.
+// Check that the block is not included in the unroll and only happens
+// once
+
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK-NOT: call i32 @dx.op.bufferUpdateCounter
+
+extern AppendStructuredBuffer<float> buffs[4];
+
+export float f(int arg : A) {
+  
+  float result = 0;
+
+  [unroll]
+  for (int i = 0; i < 4; i++) {
+    if (i == arg) {
+      buffs[i].Append(arg);
+      return 1;
+    }
+  }
+  return 0;
+}

+ 30 - 0
tools/clang/test/CodeGenHLSL/unroll/oob.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -Od -E main -T ps_6_0 %s | FileCheck %s
+// CHECK-DAG: Could not unroll loop due to out of bound array access.
+// CHECK-DAG: Array access out of bound.
+// CHECK-DAG: Could not unroll loop due to out of bound array access.
+// CHECK-NOT: @main
+
+AppendStructuredBuffer<float> buf0;
+AppendStructuredBuffer<float> buf1;
+AppendStructuredBuffer<float> buf2;
+AppendStructuredBuffer<float> buf3;
+
+uint g_cond;
+
+float main() : SV_Target {
+  AppendStructuredBuffer<float> buffs[4] = {
+    buf0, buf1, buf2, buf3,
+  };
+  
+  float result = 0;
+  [unroll]
+  for (int j = -1; j < 4+1; j++) {
+    if (j == g_cond) {
+      buffs[j].Append(g_cond);
+      break;
+    }
+    result += 1;
+  }
+  return result;
+}
+

+ 33 - 0
tools/clang/test/CodeGenHLSL/unroll/oob_2016.hlsl

@@ -0,0 +1,33 @@
+// RUN: %dxc -Od -E main -T ps_6_0 -HV 2016 %s | FileCheck %s
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK-NOT: call i32 @dx.op.bufferUpdateCounter
+
+AppendStructuredBuffer<float> buf0;
+AppendStructuredBuffer<float> buf1;
+AppendStructuredBuffer<float> buf2;
+AppendStructuredBuffer<float> buf3;
+
+uint g_cond;
+
+float main() : SV_Target {
+  AppendStructuredBuffer<float> buffs[4] = {
+    buf0, buf1, buf2, buf3,
+  };
+  
+  float result = 0;
+  [unroll]
+  for (int j = -1; j < 4+1; j++) {
+    if (j == g_cond) {
+      buffs[j].Append(g_cond);
+      break;
+    }
+    result += 1;
+  }
+  return result;
+}
+

+ 40 - 0
tools/clang/test/CodeGenHLSL/unroll/struct_member.hlsl

@@ -0,0 +1,40 @@
+// RUN: %dxc -Od -E main -T ps_6_0 %s | FileCheck %s
+// CHECK: @main
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK: call i32 @dx.op.bufferUpdateCounter
+// CHECK-NOT: call i32 @dx.op.bufferUpdateCounter
+
+AppendStructuredBuffer<float4> buf0;
+AppendStructuredBuffer<float4> buf1;
+AppendStructuredBuffer<float4> buf2;
+AppendStructuredBuffer<float4> buf3;
+uint g_cond;
+
+struct Params {
+  int foo;
+};
+
+float f(Params p) {
+
+  AppendStructuredBuffer<float4> buffers[2][2] = { buf0, buf1, buf2, buf3, };
+
+  [unroll]
+  for (uint j = 0; j < p.foo; j++) {
+    if (g_cond == j) {
+      buffers[j/2][j%2].Append(1);
+      return 10;
+    }
+  }
+
+  return 0;
+}
+
+float main() : SV_Target {
+  Params p;
+  p.foo = 4;
+
+  return f(p);
+}
+