Browse Source

Taught [unroll] that barriers are safe to unroll (#3123)

Adam Yang 5 years ago
parent
commit
05115ab926

+ 30 - 1
lib/Transforms/Scalar/DxilLoopUnroll.cpp

@@ -81,6 +81,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 
 
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/DXIL/DxilUtil.h"
+#include "dxc/DXIL/DxilOperations.h"
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/HLModule.h"
 #include "llvm/Analysis/DxilValueCache.h"
 #include "llvm/Analysis/DxilValueCache.h"
 
 
@@ -130,6 +131,7 @@ public:
   }
   }
   const char *getPassName() const override { return "Dxil Loop Unroll"; }
   const char *getPassName() const override { return "Dxil Loop Unroll"; }
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+  bool IsLoopSafeToClone(Loop *L);
   void getAnalysisUsage(AnalysisUsage &AU) const override {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<AssumptionCacheTracker>();
@@ -636,6 +638,33 @@ static void RecursivelyRemoveLoopFromQueue(LPPassManager &LPM, Loop *L) {
   LPM.deleteLoopFromQueue(L);
   LPM.deleteLoopFromQueue(L);
 }
 }
 
 
+// Mostly copied from Loop::isSafeToClone, but making exception
+// for dx.op.barrier.
+//
+bool DxilLoopUnroll::IsLoopSafeToClone(Loop *L) {
+  // Return false if any loop blocks contain indirectbrs, or there are any calls
+  // to noduplicate functions.
+  for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); I != E; ++I) {
+    if (isa<IndirectBrInst>((*I)->getTerminator()))
+      return false;
+
+    if (const InvokeInst *II = dyn_cast<InvokeInst>((*I)->getTerminator()))
+      if (II->cannotDuplicate())
+        return false;
+
+    for (BasicBlock::iterator BI = (*I)->begin(), BE = (*I)->end(); BI != BE; ++BI) {
+      if (const CallInst *CI = dyn_cast<CallInst>(BI)) {
+        if (CI->cannotDuplicate() &&
+          !hlsl::OP::IsDxilOpFuncCallInst(CI, hlsl::OP::OpCode::Barrier))
+        {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
 bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
 
   DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
   DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
@@ -663,7 +692,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     ExplicitUnrollCount = (unsigned)ExplicitUnrollCountSigned;
     ExplicitUnrollCount = (unsigned)ExplicitUnrollCountSigned;
   }
   }
 
 
-  if (!L->isSafeToClone())
+  if (!IsLoopSafeToClone(L))
     return false;
     return false;
 
 
   unsigned TripCount = 0;
   unsigned TripCount = 0;

+ 63 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/attributes/unroll/barriers.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// Regression test for loops that contain barriers
+
+// CHECK: @main
+
+// CHECK: barrier
+// CHECK: bufferUpdateCounter
+// CHECK: bufferStore
+
+// CHECK: barrier
+// CHECK: bufferUpdateCounter
+// CHECK: bufferStore
+
+// CHECK: barrier
+// CHECK: bufferUpdateCounter
+// CHECK: bufferStore
+
+// CHECK: barrier
+// CHECK: bufferUpdateCounter
+// CHECK: bufferStore
+
+// CHECK: barrier
+// CHECK: bufferUpdateCounter
+// CHECK: bufferStore
+
+// CHECK: barrier
+// CHECK: bufferUpdateCounter
+// CHECK: bufferStore
+
+AppendStructuredBuffer<float4> buf0;
+AppendStructuredBuffer<float4> buf1;
+AppendStructuredBuffer<float4> buf2;
+AppendStructuredBuffer<float4> buf3;
+uint g_cond;
+
+
+[numthreads( 128, 1, 1 )]
+
+void main() {
+
+  AppendStructuredBuffer<float4> buffers[] = { buf0, buf1, buf2, buf3, };
+
+  float ret = 0;
+  [unroll]
+  for (uint i = 0; i < 2; i++) {
+
+    GroupMemoryBarrierWithGroupSync();
+    buffers[i].Append(i);
+
+    [unroll]
+    for (uint j = 0; j < 2; j++) {
+      ret++;
+
+      GroupMemoryBarrierWithGroupSync();
+      buffers[i].Append(j);
+
+      GroupMemoryBarrierWithGroupSync();
+      buffers[i].Append(j);
+    }
+  }
+}
+