فهرست منبع

Fixed [unroll] being interpreted as [unroll(1)]. (#1669)

Tristan Labelle 6 سال پیش
والد
کامیت
d47cbb6579

+ 8 - 8
tools/clang/lib/CodeGen/CGLoopInfo.cpp

@@ -22,7 +22,7 @@ static MDNode *createMetadata(LLVMContext &Ctx, const LoopAttributes &Attrs) {
 
   if (!Attrs.IsParallel && Attrs.VectorizerWidth == 0 &&
       Attrs.VectorizerUnroll == 0 &&
-      Attrs.HlslLoop == false && // HLSL Change
+      Attrs.HlslUnrollPolicy == LoopAttributes::HlslAllowUnroll && // HLSL Change
       Attrs.HlslUnrollCount == 0 && // HLSL Change
       Attrs.VectorizerEnable == LoopAttributes::VecUnspecified)
     return nullptr;
@@ -59,15 +59,15 @@ static MDNode *createMetadata(LLVMContext &Ctx, const LoopAttributes &Attrs) {
   }
 
   // HLSL Change Begins.
-  if (Attrs.HlslLoop) {
+  if (Attrs.HlslUnrollPolicy == LoopAttributes::HlslDisableUnroll) {
     // Disable unroll.
     SmallVector<Metadata *, 1> DisableOperands;
     DisableOperands.push_back(MDString::get(Ctx, "llvm.loop.unroll.disable"));
     MDNode *DisableNode = MDNode::get(Ctx, DisableOperands);
     Args.push_back(DisableNode);
   }
-  else if (Attrs.HlslUnrollCount) {
-    if (Attrs.HlslUnrollCount == 1) {
+  else if (Attrs.HlslUnrollPolicy == LoopAttributes::HlslForceUnroll) {
+    if (Attrs.HlslUnrollCount == 0) {
       // Full unroll.
       SmallVector<Metadata *, 1> FullOperands;
       FullOperands.push_back(MDString::get(Ctx, "llvm.loop.unroll.full"));
@@ -91,14 +91,14 @@ static MDNode *createMetadata(LLVMContext &Ctx, const LoopAttributes &Attrs) {
 LoopAttributes::LoopAttributes(bool IsParallel)
     : IsParallel(IsParallel), VectorizerEnable(LoopAttributes::VecUnspecified),
       VectorizerWidth(0), VectorizerUnroll(0),
-      HlslLoop(false), HlslUnrollCount(0) {} // HLSL Change
+      HlslUnrollPolicy(LoopAttributes::HlslAllowUnroll), HlslUnrollCount(0) {} // HLSL Change
 
 void LoopAttributes::clear() {
   IsParallel = false;
   VectorizerWidth = 0;
   VectorizerUnroll = 0;
   VectorizerEnable = LoopAttributes::VecUnspecified;
-  HlslLoop = false; // HLSL Change
+  HlslUnrollPolicy = LoopAttributes::HlslAllowUnroll; // HLSL Change
   HlslUnrollCount = 0; // HLSL Change
 }
 
@@ -113,11 +113,11 @@ void LoopInfoStack::push(BasicBlock *Header,
     const LoopHintAttr *LH = dyn_cast<LoopHintAttr>(Attr);
     // HLSL Change Begins
     if (dyn_cast<HLSLLoopAttr>(Attr)) {
-      setHlslLoop(true);
+      setHlslLoop();
     } else if (const HLSLUnrollAttr *UnrollAttr =
                    dyn_cast<HLSLUnrollAttr>(Attr)) {
       unsigned count = UnrollAttr->getCount();
-      setHlslUnrollCount(count);
+      setHlslUnroll(count);
     }
     // HLSL Change Ends
     // Skip non loop hint attributes

+ 13 - 5
tools/clang/lib/CodeGen/CGLoopInfo.h

@@ -52,9 +52,12 @@ struct LoopAttributes {
   unsigned VectorizerUnroll;
 
   // HLSL Change Begins.
-  /// \brief hlsl [loop] attribute
-  bool     HlslLoop;
-  /// \brief hlsl [unroll] attribute
+  /// \brief hlsl loop unrolling policy based on [loop] and [unroll] attributes
+  enum HlslUnrollPolicyEnum { HlslAllowUnroll, HlslDisableUnroll, HlslForceUnroll };
+
+  /// \brief hlsl unrolling policy
+  HlslUnrollPolicyEnum HlslUnrollPolicy;
+  /// \brief argument to hlsl [unroll] attribute, 0 = full unroll
   unsigned HlslUnrollCount;
   // HLSL Change Ends.
 };
@@ -130,9 +133,14 @@ public:
 
   // HLSL Change Begins
   /// \brief Set the hlsl unroll count for the next loop pushed.
-  void setHlslUnrollCount(unsigned U) { StagedAttrs.HlslUnrollCount = U; }
+  void setHlslUnroll(unsigned U) {
+    StagedAttrs.HlslUnrollPolicy = LoopAttributes::HlslForceUnroll;
+    StagedAttrs.HlslUnrollCount = U;
+  }
   /// \brief Set the hlsl loop for the next loop pushed.
-  void setHlslLoop(bool Enable = true) { StagedAttrs.HlslLoop = Enable; }
+  void setHlslLoop() {
+    StagedAttrs.HlslUnrollPolicy = LoopAttributes::HlslDisableUnroll;
+  }
   // HLSL Chagne Ends
 
 private:

+ 6 - 3
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -10618,8 +10618,8 @@ static Attr* HandleClipPlanes(Sema& S, const AttributeList &A)
 static Attr* HandleUnrollAttribute(Sema& S, const AttributeList &Attr)
 {
   int argValue = ValidateAttributeIntArg(S, Attr);
-  // Default value is 1.
-  if (Attr.getNumArgs() == 0) argValue = 1;
+  // Default value is 0 (full unroll).
+  if (Attr.getNumArgs() == 0) argValue = 0;
   return ::new (S.Context) HLSLUnrollAttr(Attr.getRange(), S.Context,
     argValue, Attr.getAttributeSpellingListIndex());
 }
@@ -12238,7 +12238,10 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, con
     Attr * noconst = const_cast<Attr*>(A);
     HLSLUnrollAttr *ACast = static_cast<HLSLUnrollAttr*>(noconst);
     Indent(Indentation, Out);
-    Out << "[unroll(" << ACast->getCount() << ")]\n";
+    if (ACast->getCount() == 0)
+      Out << "[unroll]\n";
+    else
+      Out << "[unroll(" << ACast->getCount() << ")]\n";
     break;
   }
   

+ 27 - 0
tools/clang/test/HLSL/rewriter/attributes_noerr.hlsl

@@ -19,6 +19,33 @@
 //  return 0;
 //}
 
+int unroll_noarg() {
+  int result = 2;
+  
+  [unroll]
+  for (int i = 0; i < 100; i++) result++;
+  
+  return result;
+}
+
+int unroll_zero() {
+  int result = 2;
+  
+  [unroll(0)]
+  for (int i = 0; i < 100; i++) result++;
+  
+  return result;
+}
+
+int unroll_one() {
+  int result = 2;
+  
+  [unroll(1)]
+  for (int i = 0; i < 100; i++) result++;
+  
+  return result;
+}
+
 int short_unroll() {
   int result = 2;
   

+ 27 - 0
tools/clang/test/HLSL/rewriter/correct_rewrites/attributes_gold.hlsl

@@ -1,4 +1,31 @@
 // Rewrite unchanged result:
+int unroll_noarg() {
+  int result = 2;
+  [unroll]
+  for (int i = 0; i < 100; i++) 
+    result++;
+  return result;
+}
+
+
+int unroll_zero() {
+  int result = 2;
+  [unroll]
+  for (int i = 0; i < 100; i++) 
+    result++;
+  return result;
+}
+
+
+int unroll_one() {
+  int result = 2;
+  [unroll(1)]
+  for (int i = 0; i < 100; i++) 
+    result++;
+  return result;
+}
+
+
 int short_unroll() {
   int result = 2;
   [unroll(2)]