Browse Source

Handle more complicated dx.break circumstances (#2825)

This includes a number of different changes to allow the dx.break
artificial branch conditional to function properly in different
situations that might arise as a result of optimization passes.

First of all, the dx.break usage is eliminated entirely for any function
found to not have need of it during code gen finishing. The dependencies
are not all clear at this stage, so it only determines if the function
makes use of any wave operations.

During the original dx.break elimination pass, the code allows for any
kind of instruction. This requires a bit more information about the user
of dx.break since the immediate user might not be a branch. So first the
dx.break users are traversed the same way that the wave users are. Then
the wave users are used to cull the list of dx.break users. Then the
remaining dx.break users, which are not wave sensitive, have their branch
conditionals removed by traversing the users until the branch is found.

An incidental optimization in the finding of wave sensitive functions is
used which requires depth first traversal. Since certain checks are
performed on a per-block basis and the user chain can be expected to
stay in the same block at least for a bit, by using depth first search,
we can skip when the new block matches the one just processed.

There is a bit of consolidation of logic here. Instead of repeatedly
checking for the pattern that indicates a dx.break modified branch, it
is done once to collect all such blocks and thereafter, its presence
there is taken to mean that it qualifies.

A test is added to verify that this solves the original motivating case:
a circumstance where merging of conditionals resulted in the dx.break
call instruction not being immediately used by the branch.

Add additional early exit for waveless compiles

If the shader has no wave ops at all, there's no reason to create the
function at all.
Greg Roth 5 years ago
parent
commit
2ec5dd9c62

+ 67 - 53
lib/HLSL/DxilPreparePasses.cpp

@@ -1074,42 +1074,60 @@ INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "
 
 namespace {
 
-// Append all blocks containing instructions that are sensitive to WaveCI into SensitiveBB
-// Sensitivity entails being an eventual user of WaveCI and also belonging to a block with
-// an break conditional on the global breakCmp that breaks out of a loop that contains WaveCI
-static void CollectSensitiveBlocks(LoopInfo *LInfo, CallInst *WaveCI, Function *BreakFunc,
-                                   SmallPtrSet<BasicBlock *, 16> &SensitiveBB) {
-  BasicBlock *WaveBB = WaveCI->getParent();
-  // If this wave operation isn't in a loop, there is no need to track its sensitivity
-  if (!LInfo->getLoopFor(WaveBB))
+// Cull blocks from BreakBBs that contain instructions that are sensitive to the wave-sensitive Inst
+// Sensitivity entails being an eventual user of the Inst and also belonging to a block with
+// a break conditional on dx.break that breaks out of a loop that contains WaveCI
+// LInfo is needed to determine loop contents. VisitedPhis is needed to prevent infinite looping.
+static void CullSensitiveBlocks(LoopInfo *LInfo, BasicBlock *WaveBB, BasicBlock *LastBB, Instruction *Inst,
+                                SmallPtrSet<Instruction *, 16> VisitedPhis,
+                                SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
+  BasicBlock *BB = Inst->getParent();
+  Loop *BreakLoop = LInfo->getLoopFor(BB);
+  // If this instruction isn't in a loop, there is no need to track its sensitivity further
+  if (!BreakLoop)
     return;
 
-  SmallVector<User *, 16> WorkList;
-  SmallPtrSet<Instruction *, 16> VisitedPhis; // To prevent infinite looping, only visit each PHI once
-  WorkList.append(WaveCI->user_begin(), WaveCI->user_end());
+  // To prevent infinite looping, only visit each PHI once
+  // If we've seen this PHI before, don't reprocess it
+  if (isa<PHINode>(Inst) && !VisitedPhis.insert(Inst).second)
+    return;
 
-  while (!WorkList.empty()) {
-    Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
-    if (I && LInfo->getLoopFor(I->getParent())) {
-      // If we've seen this PHI before, don't reprocess it
-      if (isa<PHINode>(I)) {
-        if(!VisitedPhis.insert(I).second)
-          continue;
-      }
-      // Determine if the instruction's block has an artificially-conditional break
-      // and breaks out of a loop that contains the waveCI
-      BasicBlock *BB = I->getParent();
-      BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
-      if (BI && BI->isConditional()) {
-        CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
-        if (Cond && Cond->getCalledFunction() == BreakFunc) {
-          Loop *BreakLoop = LInfo->getLoopFor(BB);
-          if (BreakLoop && BreakLoop->contains(WaveBB))
-            SensitiveBB.insert(BB);
-        }
-      }
-      // TODO: hit the brakes if we've left any loop that might contain WaveCI
-      WorkList.append(I->user_begin(), I->user_end());
+  // If this BB wasn't already just processed, handle it now
+  if (LastBB != BB) {
+    // Determine if the instruction's block has an artificially-conditional break
+    // and breaks out of a loop that contains the waveCI
+    BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+    if (BI && BI->isConditional() && BreakLoop->contains(WaveBB))
+      BreakBBs.erase(BB);
+  }
+
+  // Recurse on the users
+  for (User *U : Inst->users()) {
+    Instruction *I = cast<Instruction>(U);
+    CullSensitiveBlocks(LInfo, WaveBB, BB, I, VisitedPhis, BreakBBs);
+  }
+}
+
+// Collect blocks that end in a dx.break dependent branch by tracing the descendants of BreakFunc
+// that are found in ThisFunc and store the block and call instruction in BreakBBs
+static void CollectBreakBlocks(Function *BreakFunc, Function *ThisFunc,
+                               SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
+  for (User *U : BreakFunc->users()) {
+    SmallVector<User *, 16> WorkList;
+    Instruction *CI = cast<Instruction>(U);
+    // If this user doesn't pertain to the current function, skip it.
+    if (CI->getParent()->getParent() != ThisFunc)
+      continue;
+    WorkList.append(CI->user_begin(), CI->user_end());
+    while (!WorkList.empty()) {
+      Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
+      // When we find a Branch that depends on dx.break, save it and stop
+      // This should almost always be the first user of the Call Inst
+      // If not, iterate on the users
+      if (BranchInst *BI = dyn_cast<BranchInst>(I))
+        BreakBBs[BI->getParent()] = CI;
+      else
+        WorkList.append(I->user_begin(), I->user_end());
     }
   }
 }
@@ -1131,7 +1149,6 @@ public:
   bool runOnFunction(Function &F) override {
     if (F.isDeclaration())
       return false;
-    // Only check ps and lib profile.
     Module *M = F.getEntryBlock().getModule();
 
     Function *BreakFunc = M->getFunction(DXIL::kDxBreakFuncName);
@@ -1139,35 +1156,32 @@ public:
       return false;
 
     LInfo = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-    // For each wave operation, collect the blocks sensitive to it
-    SmallPtrSet<BasicBlock *, 16> SensitiveBBs;
+    // Collect the blocks that depend on dx.break and the instructions that call dx.break()
+    SmallDenseMap<BasicBlock *, Instruction *, 16> BreakBBs;
+    CollectBreakBlocks(BreakFunc, &F, BreakBBs);
+
+    // For each wave operation, remove all the dx.break blocks that are sensitive to it
     for (Function &IF : M->functions()) {
       HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&IF);
-      if (&IF != &F && IF.getNumUses() && IF.isDeclaration() && IsHLWaveSensitive(&IF) &&
+      // Only consider wave-sensitive intrinsics or extintrinsics
+      if (IF.isDeclaration() && IsHLWaveSensitive(&IF) &&
           (opgroup == HLOpcodeGroup::HLIntrinsic || opgroup == HLOpcodeGroup::HLExtIntrinsic)) {
+        // For each user of the function, trace all its users to remove the blocks
         for (User *U : IF.users()) {
           CallInst *CI = cast<CallInst>(U);
-          CollectSensitiveBlocks(LInfo, CI, BreakFunc, SensitiveBBs);
+          SmallPtrSet<Instruction *, 16> VisitedPhis;
+          CullSensitiveBlocks(LInfo, CI->getParent(), nullptr, CI, VisitedPhis, BreakBBs);
         }
       }
     }
 
     bool Changed = false;
-    // Revert artificially conditional breaks in blocks not included in SensitiveBBs
-    for (auto &BB : F) {
-      if (!SensitiveBBs.count(&BB)) {
-        BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator());
-        if (BI && BI->isConditional()) {
-          CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
-          if (Cond && Cond->getCalledFunction() == BreakFunc) {
-            // Make branch conditional always true and erase the conditional
-            Constant *C = ConstantInt::get(Type::getInt1Ty(BI->getContext()), 1);
-            BI->setCondition(C);
-            Cond->eraseFromParent();
-            Changed = true;
-          }
-        }
-      }
+    // Revert artificially conditional breaks in non-wave-sensitive blocks that remain in BreakBBs
+    Constant *C = ConstantInt::get(Type::getInt1Ty(M->getContext()), 1);
+    for (auto &BB : BreakBBs) {
+      // Replace the call instruction with a constant boolean
+      BB.second->replaceAllUsesWith(C);
+      Changed = true;
     }
     return Changed;
   }

+ 28 - 3
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -2571,18 +2571,43 @@ void FinishIntrinsics(
   AddOpcodeParamForIntrinsics(HLM, intrinsicMap, valToResPropertiesMap);
 }
 
-void AddDxBreak(Module &M, SmallVector<llvm::BranchInst*, 16> DxBreaks) {
+// Add the dx.break temporary intrinsic and create Call Instructions
+// to it for each branch that requires the artificial conditional.
+void AddDxBreak(Module &M, const SmallVector<llvm::BranchInst*, 16> &DxBreaks) {
   if (DxBreaks.empty())
     return;
 
+  // Collect functions that make use of any wave operations
+  // Only they will need the dx.break condition added
+  SmallPtrSet<Function *, 16> WaveUsers;
+  for (Function &F : M.functions()) {
+    HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&F);
+    if (F.isDeclaration() && IsHLWaveSensitive(&F) &&
+        (opgroup == HLOpcodeGroup::HLIntrinsic || opgroup == HLOpcodeGroup::HLExtIntrinsic)) {
+      for (User *U : F.users()) {
+        CallInst *CI = cast<CallInst>(U);
+        WaveUsers.insert(CI->getParent()->getParent());
+      }
+    }
+  }
+
+  // If there are no wave users, not even the function declaration is needed
+  if (WaveUsers.empty())
+    return;
+
   // Create the dx.break function
   FunctionType *FT = llvm::FunctionType::get(llvm::Type::getInt1Ty(M.getContext()), false);
   Function *func = cast<llvm::Function>(M.getOrInsertFunction(DXIL::kDxBreakFuncName, FT));
   func->addFnAttr(Attribute::AttrKind::NoUnwind);
 
+  // For all break branches recorded previously, if the function they are in makes
+  // any use of a wave op, it may need to be artificially conditional. Make it so now.
+  // The CleanupDxBreak pass will remove those that aren't needed when more is known.
   for(llvm::BranchInst *BI : DxBreaks) {
-    CallInst *Call = CallInst::Create(FT, func, ArrayRef<Value *>(), "", BI);
-    BI->setCondition(Call);
+    if (WaveUsers.count(BI->getParent()->getParent())) {
+      CallInst *Call = CallInst::Create(FT, func, ArrayRef<Value *>(), "", BI);
+      BI->setCondition(Call);
+    }
   }
 }
 

+ 1 - 1
tools/clang/lib/CodeGen/CGHLSLMSHelper.h

@@ -95,7 +95,7 @@ void FinishIntrinsics(
     llvm::DenseMap<llvm::Value *, hlsl::DxilResourceProperties>
         &valToResPropertiesMap);
 
-void AddDxBreak(llvm::Module &M, llvm::SmallVector<llvm::BranchInst*, 16> DxBreaks);
+void AddDxBreak(llvm::Module &M, const llvm::SmallVector<llvm::BranchInst*, 16> &DxBreaks);
 
 void ReplaceConstStaticGlobals(
     std::unordered_map<llvm::GlobalVariable *, std::vector<llvm::Constant *>>

+ 1 - 2
tools/clang/lib/CodeGen/CGHLSLRuntime.h

@@ -12,11 +12,10 @@
 #pragma once
 
 #include <functional>
-#include <llvm/ADT/DenseMap.h> // HLSL Change
+#include <llvm/ADT/SmallVector.h> // HLSL Change
 
 namespace llvm {
 class Function;
-template <typename T, unsigned N> class SmallVector;
 class Value;
 class Constant;
 class TerminatorInst;

+ 49 - 0
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/reduction/WaveAndBreakMerge.hlsl

@@ -0,0 +1,49 @@
+// RUN: %dxc -T lib_6_3 %s | FileCheck %s
+// When a conditional break block follows a conditional while loop entry,
+// There can be some merging of conditionals, particularly when the dx.break
+// adds a conditional of its own. This ensures they are handled appropriately.
+
+// CHECK: @dx.break.cond = internal constant
+
+// CHECK: define i32
+// CHECK-SAME: CondMergeWave
+// CHECK: load i32
+// CHECK-SAME: @dx.break.cond
+// CHECK: icmp eq i32
+
+// These verify the break block keeps the merged conditional
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+// CHECK: and i1
+// CHECK: br i1
+// CHECK: ret i32
+export
+int CondMergeWave(uint Bits, float4 Bobs)
+{
+  while (Bits) {
+    if (Bobs.a < 0.001) {
+      Bits = WaveReadLaneFirst(Bits);
+      break;
+    }
+    Bits >>= 1;
+  }
+  return Bits;
+}
+
+// CHECK: define i32
+// CHECK-SAME: CondMerge
+
+// Shouldn't be any use of dx.break nor any need to and anything
+// CHECK-NOT: dx.break.cond
+// CHECK-NOT: and i1
+// CHECK: ret i32
+export
+int CondMerge(uint Bits, float4 Bobs)
+{
+  while (Bits) {
+    if (Bobs.a < 0.001)
+      break;
+    Bits >>= 1;
+  }
+  return Bits;
+}
+