Browse Source

[spirv] Translation of do-while loops (#539)

Ehsan 8 years ago
parent
commit
0a04055996

+ 14 - 10
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -150,16 +150,20 @@ public:
                     uint32_t defaultLabel,
                     llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target);
 
-  // \brief Creates an unconditional branch to the given target label.
-  void createBranch(uint32_t targetLabel);
-
-  // \brief Creates a conditional branch. An OpSelectionMerge instruction
-  // will be created if mergeLabel is not 0 and continueLabel is 0.
-  // An OpLoopMerge instruction will also be created if both continueLabel
-  // and mergeLabel are not 0. For other cases, mergeLabel and continueLabel
-  // will be ignored. If selection control mask and/or loop control mask are
-  // provided, they will be applied to the corresponding SPIR-V instruction.
-  // Otherwise, MaskNone will be used.
+  /// \brief Creates an unconditional branch to the given target label.
+  /// If mergeBB and continueBB are non-zero, it creates an OpLoopMerge
+  /// instruction followed by an unconditional branch to the given target label.
+  void createBranch(
+      uint32_t targetLabel, uint32_t mergeBB = 0, uint32_t continueBB = 0,
+      spv::LoopControlMask loopControl = spv::LoopControlMask::MaskNone);
+
+  /// \brief Creates a conditional branch. An OpSelectionMerge instruction
+  /// will be created if mergeLabel is not 0 and continueLabel is 0.
+  /// An OpLoopMerge instruction will also be created if both continueLabel
+  /// and mergeLabel are not 0. For other cases, mergeLabel and continueLabel
+  /// will be ignored. If selection control mask and/or loop control mask are
+  /// provided, they will be applied to the corresponding SPIR-V instruction.
+  /// Otherwise, MaskNone will be used.
   void createConditionalBranch(
       uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
       uint32_t mergeLabel = 0, uint32_t continueLabel = 0,

+ 8 - 1
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -234,9 +234,16 @@ void ModuleBuilder::createSwitch(
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
-void ModuleBuilder::createBranch(uint32_t targetLabel) {
+void ModuleBuilder::createBranch(uint32_t targetLabel, uint32_t mergeBB,
+                                 uint32_t continueBB,
+                                 spv::LoopControlMask loopControl) {
   assert(insertPoint && "null insert point");
 
+  if (mergeBB && continueBB) {
+    instBuilder.opLoopMerge(mergeBB, continueBB, loopControl).x();
+    insertPoint->appendInstruction(std::move(constructSite));
+  }
+
   instBuilder.opBranch(targetLabel).x();
   insertPoint->appendInstruction(std::move(constructSite));
 }

+ 89 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -203,6 +203,8 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
     processCaseStmtOrDefaultStmt(stmt);
   } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
     doBreakStmt(breakStmt);
+  } else if (const auto *theDoStmt = dyn_cast<DoStmt>(stmt)) {
+    doDoStmt(theDoStmt, attrs);
   } else if (const auto *whileStmt = dyn_cast<WhileStmt>(stmt)) {
     doWhileStmt(whileStmt, attrs);
   } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
@@ -469,6 +471,93 @@ spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {
   return spv::LoopControlMask::MaskNone;
 }
 
+void SPIRVEmitter::doDoStmt(const DoStmt *theDoStmt,
+                            llvm::ArrayRef<const Attr *> attrs) {
+  // do-while loops are composed of:
+  //
+  // do {
+  //   <body>
+  // } while(<check>);
+  //
+  // SPIR-V requires loops to have a merge basic block as well as a continue
+  // basic block. Even though do-while loops do not have an explicit continue
+  // block as in for-loops, we still do need to create a continue block.
+  //
+  // Since SPIR-V requires structured control flow, we need two more basic
+  // blocks, <header> and <merge>. <header> is the block before control flow
+  // diverges, and <merge> is the block where control flow subsequently
+  // converges. The <check> can be performed in the <continue> basic block.
+  // The final CFG should normally be like the following. Exceptions
+  // will occur with non-local exits like loop breaks or early returns.
+  //
+  //            +----------+
+  //            |  header  | <-----------------------------------+
+  //            +----------+                                     |
+  //                 |                                           |  (true)
+  //                 v                                           |
+  //             +------+       +--------------------+           |
+  //             | body | ----> | continue (<check>) |-----------+
+  //             +------+       +--------------------+
+  //                                     |
+  //                                     | (false)
+  //             +-------+               |
+  //             | merge | <-------------+
+  //             +-------+
+  //
+  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
+
+  const spv::LoopControlMask loopControl =
+      attrs.empty() ? spv::LoopControlMask::MaskNone
+                    : translateLoopAttribute(*attrs.front());
+
+  // Create basic blocks
+  const uint32_t headerBB = theBuilder.createBasicBlock("do_while.header");
+  const uint32_t bodyBB = theBuilder.createBasicBlock("do_while.body");
+  const uint32_t continueBB = theBuilder.createBasicBlock("do_while.continue");
+  const uint32_t mergeBB = theBuilder.createBasicBlock("do_while.merge");
+
+  // Branch from the current insert point to the header block.
+  theBuilder.createBranch(headerBB);
+  theBuilder.addSuccessor(headerBB);
+
+  // Process the <header> block
+  // The header block must always branch to the body.
+  theBuilder.setInsertPoint(headerBB);
+  theBuilder.createBranch(bodyBB, mergeBB, continueBB, loopControl);
+  theBuilder.addSuccessor(bodyBB);
+  // The current basic block has OpLoopMerge instruction. We need to set its
+  // continue and merge target.
+  theBuilder.setContinueTarget(continueBB);
+  theBuilder.setMergeTarget(mergeBB);
+
+  // Process the <body> block
+  theBuilder.setInsertPoint(bodyBB);
+  if (const Stmt *body = theDoStmt->getBody()) {
+    doStmt(body);
+  }
+  theBuilder.createBranch(continueBB);
+  theBuilder.addSuccessor(continueBB);
+
+  // Process the <continue> block. The check for whether the loop should
+  // continue lies in the continue block.
+  // *NOTE*: There's a SPIR-V rule that when a conditional branch is to occur in
+  // a continue block of a loop, there should be no OpSelectionMerge. Only an
+  // OpBranchConditional must be specified.
+  theBuilder.setInsertPoint(continueBB);
+  uint32_t condition = 0;
+  if (const Expr *check = theDoStmt->getCond()) {
+    condition = doExpr(check);
+  } else {
+    condition = theBuilder.getConstantBool(true);
+  }
+  theBuilder.createConditionalBranch(condition, headerBB, mergeBB);
+  theBuilder.addSuccessor(headerBB);
+  theBuilder.addSuccessor(mergeBB);
+
+  // Set insertion point to the <merge> block for subsequent statements
+  theBuilder.setInsertPoint(mergeBB);
+}
+
 void SPIRVEmitter::doWhileStmt(const WhileStmt *whileStmt,
                                llvm::ArrayRef<const Attr *> attrs) {
   // While loops are composed of:

+ 1 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -80,6 +80,7 @@ private:
   void doSwitchStmt(const SwitchStmt *stmt,
                     llvm::ArrayRef<const Attr *> attrs = {});
   void doWhileStmt(const WhileStmt *, llvm::ArrayRef<const Attr *> attrs = {});
+  void doDoStmt(const DoStmt *, llvm::ArrayRef<const Attr *> attrs = {});
 
   uint32_t doBinaryOperator(const BinaryOperator *expr);
   uint32_t doCallExpr(const CallExpr *callExpr);

+ 70 - 0
tools/clang/test/CodeGenSPIRV/do-stmt.nested.hlsl

@@ -0,0 +1,70 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+  int val=0, i=0, j=0, k=0;
+
+// CHECK:      OpBranch %do_while_header
+// CHECK-NEXT: %do_while_header = OpLabel
+// CHECK-NEXT: OpLoopMerge %do_while_merge %do_while_continue DontUnroll
+  [loop] do {
+// CHECK-NEXT: OpBranch %do_while_body
+// CHECK-NEXT: %do_while_body = OpLabel
+// CHECK-NEXT: [[val0:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[val_plus_i:%\d+]] = OpIAdd %int [[val0]] [[i0]]
+// CHECK-NEXT: OpStore %val [[val_plus_i]]
+// CHECK-NEXT: OpBranch %do_while_header_0
+    val = val + i;
+// CHECK-NEXT: %do_while_header_0 = OpLabel
+// CHECK-NEXT: OpLoopMerge %do_while_merge_0 %do_while_continue_0 Unroll
+// CHECK-NEXT: OpBranch %do_while_body_0
+    [unroll(20)] do {
+// CHECK-NEXT: %do_while_body_0 = OpLabel
+// CHECK-NEXT: OpBranch %do_while_header_1
+
+// CHECK-NEXT: %do_while_header_1 = OpLabel
+// CHECK-NEXT: OpLoopMerge %do_while_merge_1 %do_while_continue_1 DontUnroll
+// CHECK-NEXT: OpBranch %do_while_body_1
+      [fastopt] do {
+// CHECK-NEXT: %do_while_body_1 = OpLabel
+// CHECK-NEXT: [[k0:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[k_plus_1:%\d+]] = OpIAdd %int [[k0]] %int_1
+// CHECK-NEXT: OpStore %k [[k_plus_1]]
+// CHECK-NEXT: OpBranch %do_while_continue_1
+        ++k;
+// CHECK-NEXT: %do_while_continue_1 = OpLabel
+// CHECK-NEXT: [[k1:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[k_lt_30:%\d+]] = OpSLessThan %bool [[k1]] %int_30
+// CHECK-NEXT: OpBranchConditional [[k_lt_30]] %do_while_header_1 %do_while_merge_1
+      } while (k < 30);
+
+// CHECK-NEXT: %do_while_merge_1 = OpLabel
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[j_plus_1:%\d+]] = OpIAdd %int [[j0]] %int_1
+// CHECK-NEXT: OpStore %j [[j_plus_1]]
+// CHECK-NEXT: OpBranch %do_while_continue_0
+      ++j;
+// CHECK-NEXT: %do_while_continue_0 = OpLabel
+// CHECK-NEXT: [[j1:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[j_lt_20:%\d+]] = OpSLessThan %bool [[j1]] %int_20
+// CHECK-NEXT: OpBranchConditional [[j_lt_20]] %do_while_header_0 %do_while_merge_0
+    } while (j < 20);
+
+// CHECK-NEXT: %do_while_merge_0 = OpLabel
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[i_plus_1:%\d+]] = OpIAdd %int [[i0]] %int_1
+// CHECK-NEXT: OpStore %i [[i_plus_1]]
+// CHECK-NEXT: OpBranch %do_while_continue
+    ++i;
+
+// CHECK-NEXT: %do_while_continue = OpLabel
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[i_lt_10:%\d+]] = OpSLessThan %bool [[i1]] %int_10
+// CHECK-NEXT: OpBranchConditional [[i_lt_10]] %do_while_header %do_while_merge
+  } while (i < 10);
+// CHECK-NEXT: %do_while_merge = OpLabel
+
+
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+}

+ 77 - 0
tools/clang/test/CodeGenSPIRV/do-stmt.plain.hlsl

@@ -0,0 +1,77 @@
+// Run: %dxc -T ps_6_0 -E main
+
+int foo() { return true; }
+
+void main() {
+  int val = 0;
+  int i = 0;
+
+
+
+    /////////////////////////////
+    //// Basic do-while loop ////
+    /////////////////////////////
+
+// CHECK:      OpBranch %do_while_header
+// CHECK-NEXT: %do_while_header = OpLabel
+// CHECK-NEXT: OpLoopMerge %do_while_merge %do_while_continue None
+// CHECK-NEXT: OpBranch %do_while_body
+  do {
+// CHECK-NEXT: %do_while_body = OpLabel
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: OpStore %val [[i0]]
+// CHECK-NEXT: OpBranch %do_while_continue
+      val = i;
+// CHECK-NEXT: %do_while_continue = OpLabel
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[i_lt_10:%\d+]] = OpSLessThan %bool [[i1]] %int_10
+// CHECK-NEXT: OpBranchConditional [[i_lt_10]] %do_while_header %do_while_merge
+  } while (i < 10);
+// CHECK-NEXT: %do_while_merge = OpLabel
+
+
+
+    //////////////////////////
+    ////  infinite loop   ////
+    //////////////////////////
+
+// CHECK-NEXT: OpBranch %do_while_header_0
+// CHECK-NEXT: %do_while_header_0 = OpLabel
+// CHECK-NEXT: OpLoopMerge %do_while_merge_0 %do_while_continue_0 None
+// CHECK-NEXT: OpBranch %do_while_body_0
+  do {
+// CHECK-NEXT: %do_while_body_0 = OpLabel
+// CHECK-NEXT: OpStore %val %int_0
+// CHECK-NEXT: OpBranch %do_while_continue_0
+    val = 0;
+// CHECK-NEXT: %do_while_continue_0 = OpLabel
+// CHECK-NEXT: OpBranchConditional %true %do_while_header_0 %do_while_merge_0
+  } while (true);
+// CHECK-NEXT: %do_while_merge_0 = OpLabel
+
+
+
+
+    //////////////////////////
+    ////    Null Body     ////
+    //////////////////////////
+// CHECK-NEXT: OpBranch %do_while_header_1
+// CHECK-NEXT: %do_while_header_1 = OpLabel
+// CHECK-NEXT: OpLoopMerge %do_while_merge_1 %do_while_continue_1 None
+// CHECK-NEXT: OpBranch %do_while_body_1
+  do {
+// CHECK-NEXT: %do_while_body_1 = OpLabel
+// CHECK-NEXT: OpBranch %do_while_continue_1
+
+// CHECK-NEXT: %do_while_continue_1 = OpLabel
+// CHECK-NEXT: [[val:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[val_lt_20:%\d+]] = OpSLessThan %bool [[val]] %int_20
+// CHECK-NEXT: OpBranchConditional [[val_lt_20]] %do_while_header_1 %do_while_merge_1
+  } while (val < 20);
+// CHECK-NEXT: %do_while_merge_1 = OpLabel
+
+
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -209,6 +209,10 @@ TEST_F(FileTest, ForStmtNestedForStmt) { runFileTest("for-stmt.nested.hlsl"); }
 TEST_F(FileTest, WhileStmtPlain) { runFileTest("while-stmt.plain.hlsl"); }
 TEST_F(FileTest, WhileStmtNested) { runFileTest("while-stmt.nested.hlsl"); }
 
+// For do statements
+TEST_F(FileTest, DoStmtPlain) { runFileTest("do-stmt.plain.hlsl"); }
+TEST_F(FileTest, DoStmtNested) { runFileTest("do-stmt.nested.hlsl"); }
+
 // For control flows
 TEST_F(FileTest, ControlFlowNestedIfForStmt) { runFileTest("cf.if.for.hlsl"); }
 TEST_F(FileTest, ControlFlowLogicalAnd) { runFileTest("cf.logical-and.hlsl"); }