|
@@ -60,6 +60,7 @@
|
|
|
#include "llvm/Analysis/LoopPass.h"
|
|
|
#include "llvm/Analysis/InstructionSimplify.h"
|
|
|
#include "llvm/Analysis/AssumptionCache.h"
|
|
|
+#include "llvm/Analysis/ScalarEvolution.h"
|
|
|
#include "llvm/Transforms/Scalar.h"
|
|
|
#include "llvm/Transforms/Utils/Cloning.h"
|
|
|
#include "llvm/Transforms/Utils/Local.h"
|
|
@@ -74,6 +75,7 @@
|
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
#include "llvm/Support/Debug.h"
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
|
+#include "llvm/IR/LegacyPassManager.h"
|
|
|
|
|
|
#include "dxc/DXIL/DxilUtil.h"
|
|
|
#include "dxc/HLSL/HLModule.h"
|
|
@@ -110,7 +112,7 @@ public:
|
|
|
std::unordered_set<Function *> CleanedUpAlloca;
|
|
|
unsigned MaxIterationAttempt = 0;
|
|
|
|
|
|
- DxilLoopUnroll(unsigned MaxIterationAttempt = 128) :
|
|
|
+ DxilLoopUnroll(unsigned MaxIterationAttempt = 1024) :
|
|
|
LoopPass(ID),
|
|
|
MaxIterationAttempt(MaxIterationAttempt)
|
|
|
{
|
|
@@ -120,16 +122,17 @@ public:
|
|
|
bool runOnLoop(Loop *L, LPPassManager &LPM) override;
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
|
AU.addRequired<LoopInfoWrapperPass>();
|
|
|
- AU.addRequiredID(LoopSimplifyID);
|
|
|
AU.addRequired<AssumptionCacheTracker>();
|
|
|
AU.addRequired<DominatorTreeWrapperPass>();
|
|
|
AU.addPreserved<DominatorTreeWrapperPass>();
|
|
|
+ AU.addRequired<ScalarEvolution>();
|
|
|
+ AU.addRequiredID(LoopSimplifyID);
|
|
|
}
|
|
|
};
|
|
|
|
|
|
char DxilLoopUnroll::ID;
|
|
|
|
|
|
-static void FailLoopUnroll(bool WarnOnly, LLVMContext &Ctx, DebugLoc DL, const char *Message) {
|
|
|
+static void FailLoopUnroll(bool WarnOnly, LLVMContext &Ctx, DebugLoc DL, const Twine &Message) {
|
|
|
if (WarnOnly) {
|
|
|
if (DL)
|
|
|
Ctx.emitWarning(hlsl::dxilutil::FormatMessageAtLocation(DL, Message));
|
|
@@ -684,6 +687,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
|
|
|
DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
|
|
|
Function *F = L->getHeader()->getParent();
|
|
|
+ ScalarEvolution *SE = &getAnalysis<ScalarEvolution>();
|
|
|
|
|
|
bool HasExplicitLoopCount = false;
|
|
|
int ExplicitUnrollCountSigned = 0;
|
|
@@ -714,6 +718,18 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
FxcCompatMode = HM.GetHLOptions().bFXCCompatMode;
|
|
|
}
|
|
|
|
|
|
+ unsigned TripCount = 0;
|
|
|
+ unsigned TripMultiple = 0;
|
|
|
+ bool HasTripCount = false;
|
|
|
+ BasicBlock *ExitingBlock = L->getLoopLatch();
|
|
|
+ if (!ExitingBlock || !L->isLoopExiting(ExitingBlock))
|
|
|
+ ExitingBlock = L->getExitingBlock();
|
|
|
+ if (ExitingBlock) {
|
|
|
+ TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
|
|
|
+ TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
|
|
|
+ HasTripCount = TripMultiple != 1 || TripCount == 1;
|
|
|
+ }
|
|
|
+
|
|
|
// Analysis passes
|
|
|
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
|
|
|
AssumptionCache *AC =
|
|
@@ -736,12 +752,6 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
- // Promote alloca's
|
|
|
- if (!CleanedUpAlloca.count(F)) {
|
|
|
- CleanedUpAlloca.insert(F);
|
|
|
- Mem2Reg(*F, *DT, *AC);
|
|
|
- }
|
|
|
-
|
|
|
SmallVector<BasicBlock *, 16> ExitBlocks;
|
|
|
L->getExitBlocks(ExitBlocks);
|
|
|
std::unordered_set<BasicBlock *> ExitBlockSet(ExitBlocks.begin(), ExitBlocks.end());
|
|
@@ -839,9 +849,15 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
SmallVector<std::unique_ptr<LoopIteration>, 16> Iterations; // List of cloned iterations
|
|
|
bool Succeeded = false;
|
|
|
|
|
|
- if (HasExplicitLoopCount) {
|
|
|
- this->MaxIterationAttempt = std::max(this->MaxIterationAttempt, ExplicitUnrollCount);
|
|
|
+ // If we were able to figure out the definitive trip count,
|
|
|
+ // just unroll that many times.
|
|
|
+ if (HasTripCount) {
|
|
|
+ this->MaxIterationAttempt = TripCount;
|
|
|
}
|
|
|
+ else if (HasExplicitLoopCount) {
|
|
|
+ this->MaxIterationAttempt = ExplicitUnrollCount;
|
|
|
+ }
|
|
|
+
|
|
|
for (unsigned IterationI = 0; IterationI < this->MaxIterationAttempt; IterationI++) {
|
|
|
|
|
|
LoopIteration *PrevIteration = nullptr;
|
|
@@ -957,7 +973,9 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
}
|
|
|
|
|
|
// We've reached the N defined in [unroll(N)]
|
|
|
- if (HasExplicitLoopCount && IterationI+1 >= ExplicitUnrollCount) {
|
|
|
+ if ((HasExplicitLoopCount && IterationI+1 >= ExplicitUnrollCount) ||
|
|
|
+ (HasTripCount && IterationI+1 >= TripCount))
|
|
|
+ {
|
|
|
Succeeded = true;
|
|
|
BranchInst *BI = cast<BranchInst>(CurIteration.Latch->getTerminator());
|
|
|
|
|
@@ -1024,6 +1042,8 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ SE->forgetLoop(L);
|
|
|
+
|
|
|
// Remove the original blocks that we've cloned from all loops.
|
|
|
for (BasicBlock *BB : ToBeCloned)
|
|
|
LI->removeBlock(BB);
|
|
@@ -1061,9 +1081,16 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
|
|
|
// If we were unsuccessful in unrolling the loop
|
|
|
else {
|
|
|
- FailLoopUnroll(FxcCompatMode /*warn only*/, F->getContext(), LoopLoc,
|
|
|
- "Could not unroll loop. Loop bound could not be deduced at compile time. "
|
|
|
- "To give an explicit unroll bound, use unroll(n).");
|
|
|
+ const char *Msg =
|
|
|
+ "Could not unroll loop. Loop bound could not be deduced at compile time. "
|
|
|
+ "Use [unroll(n)] to give an explicit count.";
|
|
|
+ if (FxcCompatMode) {
|
|
|
+ FailLoopUnroll(true /*warn only*/, F->getContext(), LoopLoc, Msg);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ FailLoopUnroll(false /*warn only*/, F->getContext(), LoopLoc,
|
|
|
+ Twine(Msg) + Twine(" Use '-HV 2016' to treat this as warning."));
|
|
|
+ }
|
|
|
|
|
|
// Remove all the cloned blocks
|
|
|
for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {
|
|
@@ -1088,8 +1115,88 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
|
|
|
|
|
|
}
|
|
|
|
|
|
+// Special Mem2Reg pass
|
|
|
+//
|
|
|
+// In order to figure out loop bounds to unroll, we must first run mem2reg pass
|
|
|
+// on the function, but we don't want to run mem2reg on functions that don't
|
|
|
+// have to be unrolled when /Od is given. This pass considers all these
|
|
|
+// conditions and runs mem2reg on functions only when needed.
|
|
|
+//
|
|
|
+class DxilConditionalMem2Reg : public FunctionPass {
|
|
|
+public:
|
|
|
+ static char ID;
|
|
|
+
|
|
|
+ // Function overrides that resolve options when used for DxOpt
|
|
|
+ void applyOptions(PassOptions O) {
|
|
|
+ GetPassOptionBool(O, "NoOpt", &NoOpt, false);
|
|
|
+ }
|
|
|
+ void dumpConfig(raw_ostream &OS) {
|
|
|
+ FunctionPass::dumpConfig(OS);
|
|
|
+ OS << ",NoOpt=" << NoOpt;
|
|
|
+ }
|
|
|
+
|
|
|
+ bool NoOpt = false;
|
|
|
+ explicit DxilConditionalMem2Reg(bool NoOpt=false) : FunctionPass(ID), NoOpt(NoOpt)
|
|
|
+ {
|
|
|
+ initializeDxilConditionalMem2RegPass(*PassRegistry::getPassRegistry());
|
|
|
+ }
|
|
|
+
|
|
|
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
|
+ AU.addRequired<LoopInfoWrapperPass>();
|
|
|
+ AU.addRequired<DominatorTreeWrapperPass>();
|
|
|
+ AU.addRequired<AssumptionCacheTracker>();
|
|
|
+ AU.addRequiredID(LoopSimplifyID);
|
|
|
+ AU.setPreservesCFG();
|
|
|
+ }
|
|
|
+
|
|
|
+ // Recursively find loops that are marked with [unroll]
|
|
|
+ static bool HasLoopsMarkedUnrollRecursive(Loop *L) {
|
|
|
+ int Count = 0;
|
|
|
+ if (IsMarkedFullUnroll(L) || IsMarkedUnrollCount(L, &Count)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ for (Loop *ChildLoop : *L) {
|
|
|
+ if (HasLoopsMarkedUnrollRecursive(ChildLoop))
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ bool runOnFunction(Function &F) {
|
|
|
+ LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
|
|
|
+ DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
|
|
|
+ AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
|
|
|
+
|
|
|
+ bool NeedPromote = false;
|
|
|
+ bool Changed = false;
|
|
|
+
|
|
|
+ if (NoOpt) {
|
|
|
+ // If any of the functions are marked as full unroll.
|
|
|
+ for (Loop *L : *LI) {
|
|
|
+ if (HasLoopsMarkedUnrollRecursive(L)) {
|
|
|
+ NeedPromote = true;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ NeedPromote = true;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (NeedPromote)
|
|
|
+ Changed |= Mem2Reg(F, *DT, *AC);
|
|
|
+
|
|
|
+ return Changed;
|
|
|
+ }
|
|
|
+};
|
|
|
+char DxilConditionalMem2Reg::ID;
|
|
|
+
|
|
|
+Pass *llvm::createDxilConditionalMem2RegPass(bool NoOpt) {
|
|
|
+ return new DxilConditionalMem2Reg(NoOpt);
|
|
|
+}
|
|
|
Pass *llvm::createDxilLoopUnrollPass(unsigned MaxIterationAttempt) {
|
|
|
return new DxilLoopUnroll(MaxIterationAttempt);
|
|
|
}
|
|
|
|
|
|
+INITIALIZE_PASS(DxilConditionalMem2Reg, "dxil-cond-mem2reg", "Dxil Conditional Mem2Reg", false, false)
|
|
|
INITIALIZE_PASS(DxilLoopUnroll, "dxil-loop-unroll", "Dxil Unroll loops", false, false)
|