瀏覽代碼

unroll(n) can now override loop bound. unroll(negative) now fails correctly (#2241)

Adam Yang 6 年之前
父節點
當前提交
dc6203ad5b

+ 24 - 10
lib/Transforms/Scalar/DxilLoopUnroll.cpp

@@ -203,14 +203,13 @@ static bool IsMarkedFullUnroll(Loop *L) {
   return false;
 }
 
-static bool IsMarkedUnrollCount(Loop *L, unsigned *OutCount) {
+static bool IsMarkedUnrollCount(Loop *L, int *OutCount) {
   if (MDNode *LoopID = L->getLoopID()) {
     if (MDNode *MD = GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
       assert(MD->getNumOperands() == 2 &&
              "Unroll count hint metadata should have two operands.");
-      unsigned Count =
-        mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
-      assert(Count >= 1 && "Unroll count must be positive.");
+      ConstantInt *Val = mdconst::extract<ConstantInt>(MD->getOperand(1));
+      int Count = Val->getZExtValue();
       *OutCount = Count;
       return true;
     }
@@ -683,22 +682,32 @@ static void RecursivelyRemoveLoopFromQueue(LPPassManager &LPM, Loop *L) {
 
 bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
+  DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
+  Function *F = L->getHeader()->getParent();
+
   bool HasExplicitLoopCount = false;
-  unsigned UnrollCount = 0;
+  int ExplicitUnrollCountSigned = 0;
 
   // If the loop is not marked as [unroll], don't do anything.
-  if (IsMarkedUnrollCount(L, &UnrollCount)) {
+  if (IsMarkedUnrollCount(L, &ExplicitUnrollCountSigned)) {
     HasExplicitLoopCount = true;
   }
   else if (!IsMarkedFullUnroll(L)) {
     return false;
   }
 
+  unsigned ExplicitUnrollCount = 0;
+  if (HasExplicitLoopCount) {
+    if (ExplicitUnrollCountSigned < 1) {
+      FailLoopUnroll(false, F->getContext(), LoopLoc, "Could not unroll loop. Invalid unroll count.");
+      return false;
+    }
+    ExplicitUnrollCount = (unsigned)ExplicitUnrollCountSigned;
+  }
+
   if (!L->isSafeToClone())
     return false;
 
-  DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
-  Function *F = L->getHeader()->getParent();
   bool FxcCompatMode = false;
   if (F->getParent()->HasHLModule()) {
     HLModule &HM = F->getParent()->GetHLModule();
@@ -830,6 +839,9 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   SmallVector<std::unique_ptr<LoopIteration>, 16> Iterations; // List of cloned iterations
   bool Succeeded = false;
 
+  if (HasExplicitLoopCount) {
+    this->MaxIterationAttempt = std::max(this->MaxIterationAttempt, ExplicitUnrollCount);
+  }
   for (unsigned IterationI = 0; IterationI < this->MaxIterationAttempt; IterationI++) {
 
     LoopIteration *PrevIteration = nullptr;
@@ -945,7 +957,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     }
 
     // We've reached the N defined in [unroll(N)]
-    if (HasExplicitLoopCount && IterationI+1 >= UnrollCount) {
+    if (HasExplicitLoopCount && IterationI+1 >= ExplicitUnrollCount) {
       Succeeded = true;
       BranchInst *BI = cast<BranchInst>(CurIteration.Latch->getTerminator());
 
@@ -1049,7 +1061,9 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
   // If we were unsuccessful in unrolling the loop
   else {
-    FailLoopUnroll(FxcCompatMode /*warn only*/, F->getContext(), LoopLoc, "Could not unroll loop.");
+    FailLoopUnroll(FxcCompatMode /*warn only*/, F->getContext(), LoopLoc,
+      "Could not unroll loop. Loop bound could not be deduced at compile time. "
+      "To give an explicit unroll bound, use unroll(n).");
 
     // Remove all the cloned blocks
     for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {

+ 20 - 0
tools/clang/test/CodeGenHLSL/batch/unroll/explicit_large_count.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -Od -E main -T ps_6_0 %s | FileCheck %s
+// CHECK: @main
+
+// Confirm that the 128 limit on loop unroll can be overritten by an explicit
+// loop count
+
+[RootSignature("")]
+float main(float y : Y) : SV_Target {
+  float x = 0;
+
+  static const uint kLoopCount = 512;
+
+  [unroll(kLoopCount)]
+  for (uint i = 0; i < kLoopCount; ++i)
+  {
+    x = x * x + y;
+  }
+  return x;
+}
+

+ 18 - 0
tools/clang/test/CodeGenHLSL/batch/unroll/large_count.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -Od -E main -T ps_6_0 %s | FileCheck %s
+// CHECK: Could not unroll loop
+// CHECK: To give an explicit unroll bound, use unroll(n)
+// CHECK-NOT: @main
+
+[RootSignature("")]
+float main(float y : Y) : SV_Target {
+  float x = 0;
+
+  static const uint kLoopCount = 512;
+
+  [unroll]
+  for (uint i = 0; i < kLoopCount; ++i)
+  {
+    x = x * x + y;
+  }
+  return x;
+}