Переглянути джерело

Don't reprocess instr in CullSensitiveBlocks (#3029)

Previously, only PHI instructions were collected and prevented from
repeating. This was the minimum required to prevent infinite compiles,
but in complicated shaders with a lot of wave dependencies that feed in
through multiple operands, compilation can take a very long time.

By including all encountered instructions, no repetition of traversal is
performed. The cost is lookups into a much larger set, but that seems
the vastly more performant option.

By collecting call instructions and grouping them by innermost
containing loop, the collection of visited instructions can be shared,
since the test for any break loop containing the wave loop yields the
same result so long as the wave loop and the instruction belonging to
the break loop are the same.

The instigating shader doesn't actually make use of this optimization,
so I added a test that does.
Greg Roth 5 роки тому
батько
коміт
9a93d5dcfb

+ 25 - 11
lib/HLSL/DxilPreparePasses.cpp

@@ -1085,9 +1085,9 @@ namespace {
 // Cull blocks from BreakBBs that contain instructions that are sensitive to the wave-sensitive Inst
 // Sensitivity entails being an eventual user of the Inst and also belonging to a block with
 // a break conditional on dx.break that breaks out of a loop that contains WaveCI
-// LInfo is needed to determine loop contents. VisitedPhis is needed to prevent infinit looping.
-static void CullSensitiveBlocks(LoopInfo *LInfo, BasicBlock *WaveBB, BasicBlock *LastBB, Instruction *Inst,
-                                SmallPtrSet<Instruction *, 16> &VisitedPhis,
+// LInfo is needed to determine loop contents. Visited is needed to prevent infinite looping.
+static void CullSensitiveBlocks(LoopInfo *LInfo, Loop *WaveLoop, BasicBlock *LastBB, Instruction *Inst,
+                                std::unordered_set<Instruction *> &Visited,
                                 SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
   BasicBlock *BB = Inst->getParent();
   Loop *BreakLoop = LInfo->getLoopFor(BB);
@@ -1095,9 +1095,8 @@ static void CullSensitiveBlocks(LoopInfo *LInfo, BasicBlock *WaveBB, BasicBlock
   if (!BreakLoop || BreakBBs.empty())
     return;
 
-  // To prevent infinite looping, only visit each PHI once
-  // If we've seen this PHI before, don't reprocess it
-  if (isa<PHINode>(Inst) && !VisitedPhis.insert(Inst).second)
+  // To prevent infinite looping, only visit each instruction once
+  if (!Visited.insert(Inst).second)
     return;
 
   // If this BB wasn't already just processed, handle it now
@@ -1105,14 +1104,14 @@ static void CullSensitiveBlocks(LoopInfo *LInfo, BasicBlock *WaveBB, BasicBlock
     // Determine if the instruction's block has an artificially-conditional break
     // and breaks out of a loop that contains the waveCI
     BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
-    if (BI && BI->isConditional() && BreakLoop->contains(WaveBB))
+    if (BI && BI->isConditional() && BreakLoop->contains(WaveLoop))
       BreakBBs.erase(BB);
   }
 
   // Recurse on the users
   for (User *U : Inst->users()) {
     Instruction *I = cast<Instruction>(U);
-    CullSensitiveBlocks(LInfo, WaveBB, BB, I, VisitedPhis, BreakBBs);
+    CullSensitiveBlocks(LInfo, WaveLoop, BB, I, Visited, BreakBBs);
   }
 }
 
@@ -1171,7 +1170,11 @@ public:
     if (BreakBBs.empty())
       return false;
 
-    // For each wave operation, remove all the dx.break blocks that are sensitive to it
+
+
+    // Collect all wave calls in this function and group by loop
+    SmallDenseMap<Loop *, SmallVector<CallInst *, 8>, 16> WaveCalls;
+
     for (Function &IF : M->functions()) {
       HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&IF);
       // Only consider wave-sensitive intrinsics or extintrinsics
@@ -1181,13 +1184,24 @@ public:
         for (User *U : IF.users()) {
           CallInst *CI = cast<CallInst>(U);
           if (CI->getParent()->getParent() == &F) {
-            SmallPtrSet<Instruction *, 16> VisitedPhis;
-            CullSensitiveBlocks(LInfo, CI->getParent(), nullptr, CI, VisitedPhis, BreakBBs);
+            Loop *WaveLoop = LInfo->getLoopFor(CI->getParent());
+            WaveCalls[WaveLoop].emplace_back(CI);
           }
         }
       }
     }
 
+    // For each wave operation, remove all the dx.break blocks that are sensitive to it
+    for (DenseMap<Loop*, SmallVector<CallInst *, 8>>::iterator I =
+           WaveCalls.begin(), E = WaveCalls.end();
+         I != E; ++I) {
+      Loop *loop = I->first;
+      std::unordered_set<Instruction *> Visited;
+      for (CallInst *CI : I->second) {
+        CullSensitiveBlocks(LInfo, loop, nullptr, CI, Visited, BreakBBs);
+      }
+    }
+
     bool Changed = false;
     // Revert artificially conditional breaks in non-wave-sensitive blocks that remain in BreakBBs
     Constant *C = ConstantInt::get(Type::getInt1Ty(M->getContext()), 1);

+ 131 - 8
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/reduction/WaveAndBreakPS.hlsl

@@ -16,6 +16,7 @@ int WaveInBreakBlock(int a : A, int b : B);
 int WaveInEntry(int a : A, int b : B);
 int WaveInSubLoop(int a : A, int b : B);
 int WaveInOtherLoop(int a : A, int b : B, int c : C);
+int MultiWaveInMultiLoops(int a : A, int b : B, int c : C, uint d : D);
 
 // CHECK: @dx.break.cond = internal constant
 
@@ -37,7 +38,7 @@ int WaveInOtherLoop(int a : A, int b : B, int c : C);
 // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
 // CHECK-SAME: %mainBuf
 
-int main(int a : A, int b : B, int c : C) : SV_Target
+int main(int a : A, int b : B, int c : C, int d : D) : SV_Target
 {
   int res = 0;
   int i = 0;
@@ -57,7 +58,7 @@ int main(int a : A, int b : B, int c : C) : SV_Target
         }
     }
   return res + WaveInPostLoop(a, b) + WaveInBreakBlock(a, b) + WaveInEntry(a, b) +
-    WaveInSubLoop(a,b) + WaveInOtherLoop(a,b,c);
+    WaveInSubLoop(a,b) + WaveInOtherLoop(a,b,c) + MultiWaveInMultiLoops(a,b,c,d);
 }
 
 // Wave moved to after the break block. Expected to keep the block in loop
@@ -79,7 +80,6 @@ int main(int a : A, int b : B, int c : C) : SV_Target
 // CHECK: br i1
 
 // CHECK: call i32 @dx.op.waveReadLaneFirst
-export
 int WaveInPostLoop(int a : A, int b : B)
 {
   int res = 0;
@@ -117,7 +117,6 @@ int WaveInPostLoop(int a : A, int b : B)
 // CHECK-SAME: %breakBuf
 // CHECK: br i1
 
-export
 int WaveInBreakBlock(int a : A, int b : B)
 {
   int res = 0;
@@ -147,7 +146,6 @@ int WaveInBreakBlock(int a : A, int b : B)
 // CHECK: call %dx.types.Handle @dx.op.createHandle
 // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
 // CHECK-SAME: %entryBuf
-export
 int WaveInEntry(int a : A, int b : B)
 {
   int res = 0;
@@ -190,7 +188,6 @@ int WaveInEntry(int a : A, int b : B)
 // CHECK-SAME: %subBuf
 // CHECK: add
 // CHECK: br i1
-export
 int WaveInSubLoop(int a : A, int b : B)
 {
   int res = 0;
@@ -235,13 +232,11 @@ int WaveInSubLoop(int a : A, int b : B)
 // CHECK: call %dx.types.Handle @dx.op.createHandle
 // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
 // CHECK-SAME: %otherBuf
-// CHECK-NOT: br i1
 
 // These verify the third break block doesn't
 // CHECK: call %dx.types.Handle @dx.op.createHandle
 // CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
 // CHECK-SAME: %otherBuf
-// CHECK: add
 int WaveInOtherLoop(int a : A, int b : B, int c : C)
 {
   int res = 0;
@@ -275,6 +270,134 @@ int WaveInOtherLoop(int a : A, int b : B, int c : C)
   return res;
 }
 
+// Complicated case where multiple wave ops are in multiple loops, some overlapping and some not
+
+// Position all the wave ops
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+// CHECK: call i32 @dx.op.waveActiveOp
+// CHECK: call i32 @dx.op.waveActiveOp
+// CHECK: call i32 @dx.op.waveActiveBit
+
+// These verify the first four break blocks keep the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+
+// Repeat for second loop
+
+// Position all the wave ops
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+// CHECK: call i32 @dx.op.waveActiveOp
+// CHECK: call i32 @dx.op.waveActiveOp
+// CHECK: call i32 @dx.op.waveActiveBit
+
+// These verify the last four break blocks keep the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+
+
+int MultiWaveInMultiLoops(int a : A, int b : B, int c : C, uint d : D)
+{
+  int res = 0;
+  int u = 0;
+  int v = 0;
+  int w = 0;
+  int x = 0;
+
+  for (;;) {
+    u += WaveReadLaneFirst(a);
+    v += WaveActiveSum(b);
+    for (int i = 0; i < c; i++) {
+      w += WaveActiveProduct(c);
+      x += WaveActiveBitAnd(d);
+    }
+    if (a == u + b) {
+      res += mainBuf[u + c][b];
+      break;
+    }
+    if (a == v + b) {
+      res += mainBuf[v + c][b];
+      break;
+    }
+    if (a == w + b) {
+      res += mainBuf[w + c][b];
+      break;
+    }
+    if (a == x + b) {
+      res += mainBuf[x + c][b];
+      break;
+    }
+  }
+
+  u = 0;
+  v = 0;
+  w = 0;
+  x = 0;
+
+  for (;;) {
+    u += WaveReadLaneFirst(a);
+    v += WaveActiveSum(b);
+    for (int i = 0; i < c; i++) {
+      w += WaveActiveProduct(c);
+      x += WaveActiveBitAnd(d);
+    }
+    if (a == b + u) {
+      res += mainBuf[c + u][b];
+      break;
+    }
+    if (a == b + v) {
+      res += mainBuf[c + v][b];
+      break;
+    }
+    if (a == b + w) {
+      res += mainBuf[c + w][b];
+      break;
+    }
+    if (a == b + x) {
+      res += mainBuf[c + x][b];
+      break;
+    }
+  }
+
+
+  return res;
+}
+
 // Final operations
 // CHECK-NOT: br i1
 // CHECK: call void @dx.op.storeOutput