Răsfoiți Sursa

Added support for [unroll(n)] (#1783)

Adam Yang 6 ani în urmă
părinte
comite
45bbafc13e

+ 45 - 1
lib/Transforms/Scalar/DxilLoopUnroll.cpp

@@ -207,6 +207,21 @@ static bool IsMarkedFullUnroll(Loop *L) {
   return false;
 }
 
+static bool IsMarkedUnrollCount(Loop *L, unsigned *OutCount) {
+  if (MDNode *LoopID = L->getLoopID()) {
+    if (MDNode *MD = GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
+      assert(MD->getNumOperands() == 2 &&
+             "Unroll count hint metadata should have two operands.");
+      unsigned Count =
+        mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
+      assert(Count >= 1 && "Unroll count must be positive.");
+      *OutCount = Count;
+      return true;
+    }
+  }
+  return false;
+}
+
 static bool HasSuccessorsInLoop(BasicBlock *BB, Loop *L) {
   for (BasicBlock *Succ : successors(BB)) {
     if (L->contains(Succ)) {
@@ -513,11 +528,19 @@ static bool Mem2Reg(Function &F, DominatorTree &DT, AssumptionCache &AC) {
 }
 
 
+
 bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
 
+  bool HasExplicitLoopCount = false;
+  unsigned UnrollCount = 0;
+
   // If the loop is not marked as [unroll], don't do anything.
-  if (!IsMarkedFullUnroll(L))
+  if (IsMarkedUnrollCount(L, &UnrollCount)) {
+    HasExplicitLoopCount = true;
+  }
+  else if (!IsMarkedFullUnroll(L)) {
     return false;
+  }
 
   if (!L->isSafeToClone())
     return false;
@@ -750,6 +773,27 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
         }
       }
     }
+
+    // We've reached the N defined in [unroll(N)]
+    if (HasExplicitLoopCount && IterationI+1 >= UnrollCount) {
+      Succeeded = true;
+      BranchInst *BI = cast<BranchInst>(CurIteration.Latch->getTerminator());
+
+      BasicBlock *ExitBlock = nullptr;
+      for (unsigned i = 0; i < BI->getNumSuccessors(); i++) {
+        BasicBlock *Succ = BI->getSuccessor(i);
+        if (Succ != CurIteration.Header) {
+          ExitBlock = Succ;
+          break;
+        }
+      }
+
+      BranchInst *NewBI = BranchInst::Create(ExitBlock, BI);
+      BI->replaceAllUsesWith(NewBI);
+      BI->eraseFromParent();
+
+      break;
+    }
   }
 
   if (Succeeded) {

+ 1 - 1
tools/clang/test/CodeGenHLSL/evalMat.hlsl

@@ -7,4 +7,4 @@ float4 main(float4x4 a : A) : SV_Target
   float4 r = EvaluateAttributeCentroid(a)[0];
 
   return r;
-}
+}

+ 18 - 0
tools/clang/test/CodeGenHLSL/unroll/count_cbuff.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: call float @dx.op.dot3
+// CHECK: call float @dx.op.dot3
+// CHECK: call float @dx.op.dot3
+// CHECK-NOT: call float @dx.op.dot3
+
+uint g_cond;
+float main(float3 a : A, float3 b : B) : SV_Target {
+
+  float result = 0;
+  [unroll(3)]
+  for (int i = 0; i < g_cond; i++) {
+    result += dot(a*i, b);
+  }
+  return result;
+}
+

+ 29 - 0
tools/clang/test/CodeGenHLSL/unroll/count_cbuff_br.hlsl

@@ -0,0 +1,29 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Identical to count_cbuff.hlsl, except checks for number of br's
+
+// entry
+// CHECK: br
+// loop iteration
+// CHECK: call float @dx.op.dot3
+// CHECK: br
+// loop iteration
+// CHECK: call float @dx.op.dot3
+// CHECK: br
+// loop iteration, unconditional
+// CHECK: call float @dx.op.dot3
+// CHECK: br
+// return
+// CHECK-NOT: br
+
+uint g_cond;
+float main(float3 a : A, float3 b : B) : SV_Target {
+
+  float result = 0;
+  [unroll(3)]
+  for (int i = 0; i < g_cond; i++) {
+    result += dot(a*i, b);
+  }
+  return result;
+}
+

+ 15 - 0
tools/clang/test/CodeGenHLSL/unroll/count_greater_than_i.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: call float @dx.op.dot3
+// CHECK: call float @dx.op.dot3
+// CHECK-NOT: call float @dx.op.dot3
+
+float main(float3 a : A, float3 b : B) : SV_Target {
+  float result = 0;
+  [unroll(3)]
+  for (int i = 0; i < 2; i++) {
+    result += dot(a*i, b);
+  }
+  return result;
+}
+

+ 16 - 0
tools/clang/test/CodeGenHLSL/unroll/count_less_than_i.hlsl

@@ -0,0 +1,16 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: call float @dx.op.dot3
+// CHECK: call float @dx.op.dot3
+// CHECK: call float @dx.op.dot3
+// CHECK-NOT: call float @dx.op.dot3
+
+float main(float3 a : A, float3 b : B) : SV_Target {
+  float result = 0;
+  [unroll(3)]
+  for (int i = 0; i < 10; i++) {
+    result += dot(a*i, b);
+  }
+  return result;
+}
+

+ 17 - 0
tools/clang/test/CodeGenHLSL/unroll/count_negative.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: attribute 'unroll' must have a uint literal argument
+// CHECK-NOT: @main
+
+uint g_cond;
+
+float main() : SV_Target {
+  float result = 0;
+  [unroll(-1)]
+  for (int i = 0; i < g_cond; i++) {
+    result += i;
+  }
+
+  return 0;
+}
+

+ 15 - 0
tools/clang/test/CodeGenHLSL/unroll/count_zero.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK-DAG: Could not unroll loop.
+// CHECK-NOT: @main
+uint g_cond;
+float main() : SV_Target {
+
+  float result = 0;
+  [unroll(0)]
+  for (int i = 0; i < g_cond; i++) {
+    result += i;
+  }
+  return result;
+}
+