|
@@ -18,6 +18,7 @@
|
|
|
#include "dxc/DXIL/DxilUtil.h"
|
|
|
#include "dxc/DXIL/DxilFunctionProps.h"
|
|
|
#include "dxc/DXIL/DxilInstructions.h"
|
|
|
+#include "dxc/HlslIntrinsicOp.h"
|
|
|
#include "llvm/IR/GetElementPtrTypeIterator.h"
|
|
|
#include "llvm/IR/IRBuilder.h"
|
|
|
#include "llvm/IR/Instructions.h"
|
|
@@ -31,6 +32,7 @@
|
|
|
#include "llvm/Pass.h"
|
|
|
#include "llvm/Transforms/Utils/Local.h"
|
|
|
#include "llvm/Analysis/AssumptionCache.h"
|
|
|
+#include "llvm/Analysis/LoopInfo.h"
|
|
|
#include <memory>
|
|
|
#include <unordered_set>
|
|
|
|
|
@@ -52,6 +54,9 @@ public:
|
|
|
};
|
|
|
}
|
|
|
|
|
|
+char *hlsl::kDxBreakFuncName = "dx.break"; // Name used to look up the dx.break placeholder function in the module
|
|
|
+char *hlsl::kDxBreakCondName = "dx.break.cond"; // Name given to the constant global created by LowerDxBreak. NOTE(review): both pointers are non-const yet bound to string literals; they presumably should be `const char *` together with their declarations -- confirm against the header
|
|
|
+
|
|
|
char InvalidateUndefResources::ID = 0; // Out-of-class definition of the pass's static ID member
|
|
|
|
|
|
ModulePass *llvm::createInvalidateUndefResourcesPass() { return new InvalidateUndefResources(); } // Factory: returns a newly allocated InvalidateUndefResources pass
|
|
@@ -404,6 +409,9 @@ public:
|
|
|
}
|
|
|
RemoveStoreUndefOutput(M, hlslOP);
|
|
|
|
|
|
+ // Turn dx.break() conditional into global
|
|
|
+ LowerDxBreak(M);
|
|
|
+
|
|
|
RemoveUnusedStaticGlobal(M);
|
|
|
|
|
|
// Remove unnecessary address space casts.
|
|
@@ -622,6 +630,42 @@ private:
|
|
|
AllocFn->eraseFromParent();
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+  // Convert all uses of dx.break() into per-function load/cmp of the
+  // dx.break.cond global constant.
+  //
+  // When the module contains the dx.break function and it still has users, a
+  // constant, internal-linkage [1 x i32] global named kDxBreakCondName is
+  // created with its single element set to 0. For every function that calls
+  // dx.break, one load of that element and one `icmp eq <load>, 0` are
+  // inserted before the entry block's terminator, and each branch that was
+  // conditional on a dx.break() call is re-pointed at that compare. The calls
+  // are erased, and finally the dx.break function itself is erased (also when
+  // it had no users at all).
+  void LowerDxBreak(Module &M) {
+    Function *BreakFunc = M.getFunction(kDxBreakFuncName);
+    if (!BreakFunc)
+      return;
+
+    // use_empty() answers the only question asked here without counting uses.
+    if (!BreakFunc->use_empty()) {
+      llvm::Type *i32Ty = llvm::Type::getInt32Ty(M.getContext());
+      Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
+      unsigned int Values[1] = { 0 };
+      Constant *InitialValue = ConstantDataArray::get(M.getContext(), Values);
+      Constant *GV = new GlobalVariable(M, i32ArrayTy, true,
+                                        GlobalValue::InternalLinkage,
+                                        InitialValue, kDxBreakCondName);
+
+      Constant *Indices[] = { ConstantInt::get(i32Ty, 0), ConstantInt::get(i32Ty, 0) };
+      Constant *Gep = ConstantExpr::getGetElementPtr(nullptr, GV, Indices);
+      // One compare per function, created lazily the first time it is needed
+      SmallDenseMap<llvm::Function*, llvm::ICmpInst*, 16> DxBreakCmpMap;
+      // Replace all uses of dx.break with references to the constant global.
+      // Erasing a call removes its use of BreakFunc from the list being walked,
+      // so the iterator must be advanced BEFORE the call is erased.
+      for (auto UI = BreakFunc->user_begin(), UE = BreakFunc->user_end(); UI != UE;) {
+        User *U = *UI++;
+        DXASSERT(U->hasOneUse() && isa<CallInst>(U), "User of dx.break function isn't call or has multiple users");
+        BranchInst *BI = cast<BranchInst>(*U->user_begin());
+        CallInst *CI = cast<CallInst>(U);
+        Function *F = BI->getParent()->getParent();
+        ICmpInst *Cmp = DxBreakCmpMap.lookup(F);
+        if (!Cmp) {
+          BasicBlock &EntryBB = F->getEntryBlock();
+          LoadInst *LI = new LoadInst(Gep, nullptr, false, EntryBB.getTerminator());
+          Cmp = new ICmpInst(EntryBB.getTerminator(), ICmpInst::ICMP_EQ, LI, llvm::ConstantInt::get(i32Ty,0));
+          DxBreakCmpMap.insert(std::make_pair(F, Cmp));
+        }
+        BI->setCondition(Cmp);
+        CI->eraseFromParent();
+      }
+    }
+    BreakFunc->eraseFromParent();
+  }
|
|
|
};
|
|
|
}
|
|
|
|
|
@@ -1028,3 +1072,118 @@ ModulePass *llvm::createDxilValidateWaveSensitivityPass() {
|
|
|
}
|
|
|
|
|
|
INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "HLSL DXIL wave sensitiveity validation", false, false)
|
|
|
+
|
|
|
+
|
|
|
+namespace {
|
|
|
+
|
|
|
+// Append all blocks containing instructions that are sensitive to WaveCI into SensitiveBB.
+// Sensitivity entails being an eventual user of WaveCI and also belonging to a block with
+// a break conditional on a call to BreakFunc that breaks out of a loop that contains WaveCI.
+//
+// LInfo       loop information used for every loop-membership query
+// WaveCI      the wave operation whose transitive users are followed
+// BreakFunc   the dx.break function; a branch counts as an artificial break
+//             when its condition is a direct call to this function
+// SensitiveBB output set; blocks are only ever added to it
+static void CollectSensitiveBlocks(LoopInfo *LInfo, CallInst *WaveCI, Function *BreakFunc,
+                                   SmallPtrSet<BasicBlock *, 16> &SensitiveBB) {
+  BasicBlock *WaveBB = WaveCI->getParent();
+  // If this wave operation isn't in a loop, there is no need to track its sensitivity
+  if (!LInfo->getLoopFor(WaveBB))
+    return;
+
+  SmallVector<User *, 16> WorkList;
+  // Every instruction is expanded at most once. Guarding PHIs alone is enough
+  // to terminate, but leaves instructions that are reachable along many use
+  // paths to be re-expanded once per path; the collected set is the same
+  // either way because expanding an instruction is idempotent.
+  SmallPtrSet<Instruction *, 16> Visited;
+  WorkList.append(WaveCI->user_begin(), WaveCI->user_end());
+
+  while (!WorkList.empty()) {
+    Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
+    // Non-instruction users and instructions outside of any loop are not followed
+    if (!I || !LInfo->getLoopFor(I->getParent()))
+      continue;
+    if (!Visited.insert(I).second)
+      continue;
+
+    // Determine if the instruction's block has an artificially-conditional break
+    // and breaks out of a loop that contains the waveCI
+    BasicBlock *BB = I->getParent();
+    BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+    if (BI && BI->isConditional()) {
+      CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
+      if (Cond && Cond->getCalledFunction() == BreakFunc) {
+        Loop *BreakLoop = LInfo->getLoopFor(BB);
+        if (BreakLoop && BreakLoop->contains(WaveBB))
+          SensitiveBB.insert(BB);
+      }
+    }
+    // TODO: hit the brakes if we've left any loop that might contain WaveCI
+    WorkList.append(I->user_begin(), I->user_end());
+  }
+}
|
|
|
+
|
|
|
+
|
|
|
+// A pass to remove conditions from breaks that do not contain instructions that
+// depend on wave operations that are in the loop that the break leaves.
+//
+// For the function being processed, every branch whose condition is a call to
+// the dx.break function and whose block was NOT collected as sensitive gets a
+// constant-true condition, and the dx.break call is erased.
+class CleanupDxBreak : public FunctionPass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit CleanupDxBreak() : FunctionPass(ID) {}
+  const char *getPassName() const override { return "HLSL Remove unnecessary dx.break conditions"; }
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<LoopInfoWrapperPass>();
+  }
+
+  // Loop info of the function currently being processed; set by runOnFunction
+  LoopInfo *LInfo = nullptr;
+
+  // Returns true when at least one branch condition was rewritten.
+  bool runOnFunction(Function &F) override {
+    if (F.isDeclaration())
+      return false;
+    Module *M = F.getParent();
+
+    // Nothing to clean up when the module has no dx.break function
+    Function *BreakFunc = M->getFunction(kDxBreakFuncName);
+    if (!BreakFunc)
+      return false;
+
+    LInfo = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    // For each wave operation, collect the blocks sensitive to it
+    SmallPtrSet<BasicBlock *, 16> SensitiveBBs;
+    for (Function &IF : M->functions()) {
+      // Only used HL intrinsic declarations are candidates. F itself has a body
+      // (checked above), so the declaration test also excludes it.
+      if (IF.use_empty() || !IF.isDeclaration() ||
+          hlsl::GetHLOpcodeGroupByName(&IF) != HLOpcodeGroup::HLIntrinsic)
+        continue;
+
+      for (User *U : IF.users()) {
+        CallInst *CI = dyn_cast<CallInst>(U);
+        // The intrinsic's users span the whole module, but the loop info in hand
+        // describes F only, so calls located in other functions are skipped.
+        if (!CI || CI->getParent()->getParent() != &F)
+          continue;
+        if (IsCallWaveSensitive(CI))
+          CollectSensitiveBlocks(LInfo, CI, BreakFunc, SensitiveBBs);
+      }
+    }
+
+    bool Changed = false;
+    // Revert artificially conditional breaks in blocks not included in SensitiveBBs
+    for (auto &BB : F) {
+      if (SensitiveBBs.count(&BB))
+        continue;
+      BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator());
+      if (!BI || !BI->isConditional())
+        continue;
+      CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
+      if (Cond && Cond->getCalledFunction() == BreakFunc) {
+        // Make branch conditional always true and erase the conditional
+        Constant *C = ConstantInt::get(Type::getInt1Ty(BI->getContext()), 1);
+        BI->setCondition(C);
+        Cond->eraseFromParent();
+        Changed = true;
+      }
+    }
+    return Changed;
+  }
+};
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+char CleanupDxBreak::ID = 0; // Out-of-class definition of the static ID member that the constructor passes to FunctionPass
|
|
|
+
|
|
|
+INITIALIZE_PASS_BEGIN(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false) // Opens the pass registration block for CleanupDxBreak
|
|
|
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) // Mirrors the addRequired<LoopInfoWrapperPass>() in getAnalysisUsage
|
|
|
+INITIALIZE_PASS_END(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false) // Closes the registration block; arguments repeat those of the BEGIN macro
|
|
|
+
|
|
|
+FunctionPass *llvm::createCleanupDxBreakPass() { // Factory for the CleanupDxBreak function pass
|
|
|
+  return new CleanupDxBreak(); // Newly allocated pass object handed to the caller
|
|
|
+}
|