|
@@ -1074,42 +1074,60 @@ INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
-// Append all blocks containing instructions that are sensitive to WaveCI into SensitiveBB
|
|
|
-// Sensitivity entails being an eventual user of WaveCI and also belonging to a block with
|
|
|
-// an break conditional on the global breakCmp that breaks out of a loop that contains WaveCI
|
|
|
-static void CollectSensitiveBlocks(LoopInfo *LInfo, CallInst *WaveCI, Function *BreakFunc,
|
|
|
- SmallPtrSet<BasicBlock *, 16> &SensitiveBB) {
|
|
|
- BasicBlock *WaveBB = WaveCI->getParent();
|
|
|
- // If this wave operation isn't in a loop, there is no need to track its sensitivity
|
|
|
- if (!LInfo->getLoopFor(WaveBB))
|
|
|
+// Cull blocks from BreakBBs that containing instructions that are sensitive to the wave-sensitive Inst
|
|
|
+// Sensitivity entails being an eventual user of the Inst and also belonging to a block with
|
|
|
+// a break conditional on dx.break that breaks out of a loop that contains WaveCI
|
|
|
+// LInfo is needed to determine loop contents. VisitedPhis is needed to prevent infinit looping.
|
|
|
+static void CullSensitiveBlocks(LoopInfo *LInfo, BasicBlock *WaveBB, BasicBlock *LastBB, Instruction *Inst,
|
|
|
+ SmallPtrSet<Instruction *, 16> VisitedPhis,
|
|
|
+ SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
|
|
|
+ BasicBlock *BB = Inst->getParent();
|
|
|
+ Loop *BreakLoop = LInfo->getLoopFor(BB);
|
|
|
+ // If this instruction isn't in a loop, there is no need to track its sensitivity further
|
|
|
+ if (!BreakLoop)
|
|
|
return;
|
|
|
|
|
|
- SmallVector<User *, 16> WorkList;
|
|
|
- SmallPtrSet<Instruction *, 16> VisitedPhis; // To prevent infinite looping, only visit each PHI once
|
|
|
- WorkList.append(WaveCI->user_begin(), WaveCI->user_end());
|
|
|
+ // To prevent infinite looping, only visit each PHI once
|
|
|
+ // If we've seen this PHI before, don't reprocess it
|
|
|
+ if (isa<PHINode>(Inst) && !VisitedPhis.insert(Inst).second)
|
|
|
+ return;
|
|
|
|
|
|
- while (!WorkList.empty()) {
|
|
|
- Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
|
|
|
- if (I && LInfo->getLoopFor(I->getParent())) {
|
|
|
- // If we've seen this PHI before, don't reprocess it
|
|
|
- if (isa<PHINode>(I)) {
|
|
|
- if(!VisitedPhis.insert(I).second)
|
|
|
- continue;
|
|
|
- }
|
|
|
- // Determine if the instruction's block has an artificially-conditional break
|
|
|
- // and breaks out of a loop that contains the waveCI
|
|
|
- BasicBlock *BB = I->getParent();
|
|
|
- BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
|
|
|
- if (BI && BI->isConditional()) {
|
|
|
- CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
|
|
|
- if (Cond && Cond->getCalledFunction() == BreakFunc) {
|
|
|
- Loop *BreakLoop = LInfo->getLoopFor(BB);
|
|
|
- if (BreakLoop && BreakLoop->contains(WaveBB))
|
|
|
- SensitiveBB.insert(BB);
|
|
|
- }
|
|
|
- }
|
|
|
- // TODO: hit the brakes if we've left any loop that might contain WaveCI
|
|
|
- WorkList.append(I->user_begin(), I->user_end());
|
|
|
+ // If this BB wasn't already just processed, handle it now
|
|
|
+ if (LastBB != BB) {
|
|
|
+ // Determine if the instruction's block has an artificially-conditional break
|
|
|
+ // and breaks out of a loop that contains the waveCI
|
|
|
+ BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
|
|
|
+ if (BI && BI->isConditional() && BreakLoop->contains(WaveBB))
|
|
|
+ BreakBBs.erase(BB);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Recurse on the users
|
|
|
+ for (User *U : Inst->users()) {
|
|
|
+ Instruction *I = cast<Instruction>(U);
|
|
|
+ CullSensitiveBlocks(LInfo, WaveBB, BB, I, VisitedPhis, BreakBBs);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Collect blocks that end in a dx.break dependent branch by tracing the descendants of BreakFunc
|
|
|
+// that are found in ThisFunc and store the block and call instruction in BreakBBs
|
|
|
+static void CollectBreakBlocks(Function *BreakFunc, Function *ThisFunc,
|
|
|
+ SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
|
|
|
+ for (User *U : BreakFunc->users()) {
|
|
|
+ SmallVector<User *, 16> WorkList;
|
|
|
+ Instruction *CI = cast<Instruction>(U);
|
|
|
+ // If this user doesn't pertain to the current function, skip it.
|
|
|
+ if (CI->getParent()->getParent() != ThisFunc)
|
|
|
+ continue;
|
|
|
+ WorkList.append(CI->user_begin(), CI->user_end());
|
|
|
+ while (!WorkList.empty()) {
|
|
|
+ Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
|
|
|
+ // When we find a Branch that depends on dx.break, save it and stop
|
|
|
+ // This should almost always be the first user of the Call Inst
|
|
|
+ // If not, iterate on the users
|
|
|
+ if (BranchInst *BI = dyn_cast<BranchInst>(I))
|
|
|
+ BreakBBs[BI->getParent()] = CI;
|
|
|
+ else
|
|
|
+ WorkList.append(I->user_begin(), I->user_end());
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1131,7 +1149,6 @@ public:
|
|
|
bool runOnFunction(Function &F) override {
|
|
|
if (F.isDeclaration())
|
|
|
return false;
|
|
|
- // Only check ps and lib profile.
|
|
|
Module *M = F.getEntryBlock().getModule();
|
|
|
|
|
|
Function *BreakFunc = M->getFunction(DXIL::kDxBreakFuncName);
|
|
@@ -1139,35 +1156,32 @@ public:
|
|
|
return false;
|
|
|
|
|
|
LInfo = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
|
|
|
- // For each wave operation, collect the blocks sensitive to it
|
|
|
- SmallPtrSet<BasicBlock *, 16> SensitiveBBs;
|
|
|
+ // Collect the blocks that depend on dx.break and the instructions that call dx.break()
|
|
|
+ SmallDenseMap<BasicBlock *, Instruction *, 16> BreakBBs;
|
|
|
+ CollectBreakBlocks(BreakFunc, &F, BreakBBs);
|
|
|
+
|
|
|
+ // For each wave operation, remove all the dx.break blocks that are sensitive to it
|
|
|
for (Function &IF : M->functions()) {
|
|
|
HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&IF);
|
|
|
- if (&IF != &F && IF.getNumUses() && IF.isDeclaration() && IsHLWaveSensitive(&IF) &&
|
|
|
+ // Only consider wave-sensitive intrinsics or extintrinsics
|
|
|
+ if (IF.isDeclaration() && IsHLWaveSensitive(&IF) &&
|
|
|
(opgroup == HLOpcodeGroup::HLIntrinsic || opgroup == HLOpcodeGroup::HLExtIntrinsic)) {
|
|
|
+ // For each user of the function, trace all its users to remove the blocks
|
|
|
for (User *U : IF.users()) {
|
|
|
CallInst *CI = cast<CallInst>(U);
|
|
|
- CollectSensitiveBlocks(LInfo, CI, BreakFunc, SensitiveBBs);
|
|
|
+ SmallPtrSet<Instruction *, 16> VisitedPhis;
|
|
|
+ CullSensitiveBlocks(LInfo, CI->getParent(), nullptr, CI, VisitedPhis, BreakBBs);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
bool Changed = false;
|
|
|
- // Revert artificially conditional breaks in blocks not included in SensitiveBBs
|
|
|
- for (auto &BB : F) {
|
|
|
- if (!SensitiveBBs.count(&BB)) {
|
|
|
- BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator());
|
|
|
- if (BI && BI->isConditional()) {
|
|
|
- CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
|
|
|
- if (Cond && Cond->getCalledFunction() == BreakFunc) {
|
|
|
- // Make branch conditional always true and erase the conditional
|
|
|
- Constant *C = ConstantInt::get(Type::getInt1Ty(BI->getContext()), 1);
|
|
|
- BI->setCondition(C);
|
|
|
- Cond->eraseFromParent();
|
|
|
- Changed = true;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
+ // Revert artificially conditional breaks in non-wave-sensitive blocks that remain in BreakBBs
|
|
|
+ Constant *C = ConstantInt::get(Type::getInt1Ty(M->getContext()), 1);
|
|
|
+ for (auto &BB : BreakBBs) {
|
|
|
+ // Replace the call instruction with a constant boolen
|
|
|
+ BB.second->replaceAllUsesWith(C);
|
|
|
+ Changed = true;
|
|
|
}
|
|
|
return Changed;
|
|
|
}
|