Browse Source

[spirv] Translate if statement (#462)

Support plain and nested if statments now. Complicated cases (like
with early returns and other control flow statements in body) yet
to be covered.
Lei Zhang 8 years ago
parent
commit
a4538a9252

+ 7 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -96,6 +96,13 @@ public:
   uint32_t createBinaryOp(spv::Op op, uint32_t resultType, uint32_t lhs,
                           uint32_t rhs);
 
+  // \brief Creates an unconditional branch to the given target label.
+  void createBranch(uint32_t targetLabel);
+  // \brief Creates a conditional branch. The OpSelectionMerge instruction
+  // required for the conditional branch will also be created.
+  void createConditionalBranch(uint32_t condition, uint32_t trueLabel,
+                               uint32_t falseLabel, uint32_t mergeLabel);
+
   /// \brief Creates a return instruction.
   void createReturn();
   /// \brief Creates a return value instruction.

+ 72 - 7
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -178,8 +178,7 @@ public:
         theBuilder.setInsertPoint(entryLabel);
 
         // Process all statments in the body.
-        for (Stmt *stmt : cast<CompoundStmt>(decl->getBody())->body())
-          doStmt(stmt);
+        doStmt(decl->getBody());
 
         // We have processed all Stmts in this function and now in the last
         // basic block. Make sure we have OpReturn if missing.
@@ -230,14 +229,21 @@ public:
     }
   }
 
-  void doStmt(Stmt *stmt) {
-    if (auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
+  void doStmt(const Stmt *stmt) {
+    if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
+      for (auto *st : compoundStmt->body())
+        doStmt(st);
+    } else if (const auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
       doReturnStmt(retStmt);
-    } else if (auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
+    } else if (const auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
       for (auto *decl : declStmt->decls()) {
         doDecl(decl);
       }
-    } else if (auto *expr = dyn_cast<Expr>(stmt)) {
+    } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
+      doIfStmt(ifStmt);
+    } else if (const auto *nullStmt = dyn_cast<NullStmt>(stmt)) {
+      // We don't need to do anything for NullStmt
+    } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
       // All cases for expressions used as statements
       doExpr(expr);
     } else {
@@ -245,7 +251,7 @@ public:
     }
   }
 
-  void doReturnStmt(ReturnStmt *stmt) {
+  void doReturnStmt(const ReturnStmt *stmt) {
     // For normal functions, just return in the normal way.
     if (curFunction->getName() != entryFunctionName) {
       theBuilder.createReturnValue(doExpr(stmt->getRetValue()));
@@ -307,6 +313,65 @@ public:
     }
   }
 
+  void doIfStmt(const IfStmt *ifStmt) {
+    // if statements are composed of:
+    //   if (<check>) { <then> } else { <else> }
+    //
+    // To translate if statements, we'll need to emit the <check> expressions
+    // in the current basic block, and then create separate basic blocks for
+    // <then> and <else>. Additionally, we'll need a <merge> block as per
+    // SPIR-V's structured control flow requirements. Depending whether there
+    // exists the else branch, the final CFG should normally be like the
+    // following. Exceptions will occur with non-local exits like loop breaks
+    // or early returns.
+    //             +-------+                        +-------+
+    //             | check |                        | check |
+    //             +-------+                        +-------+
+    //                 |                                |
+    //         +-------+-------+                  +-----+-----+
+    //         | true          | false            | true      | false
+    //         v               v         or       v           |
+    //     +------+         +------+           +------+       |
+    //     | then |         | else |           | then |       |
+    //     +------+         +------+           +------+       |
+    //         |               |                  |           v
+    //         |   +-------+   |                  |     +-------+
+    //         +-> | merge | <-+                  +---> | merge |
+    //             +-------+                            +-------+
+
+    // First emit the instruction for evaluating the condition.
+    const uint32_t condition = doExpr(ifStmt->getCond());
+
+    // Then we need to emit the instruction for the conditional branch.
+    // We'll need the <label-id> for the then/else/merge block to do so.
+    const bool hasElse = ifStmt->getElse() != nullptr;
+    const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
+    const uint32_t elseBB = hasElse ? theBuilder.createBasicBlock("if.false")
+                                    : theBuilder.createBasicBlock("if.merge");
+    const uint32_t mergeBB =
+        hasElse ? theBuilder.createBasicBlock("if.merge") : elseBB;
+
+    // Create the branch instruction. This will end the current basic block.
+    theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB);
+
+    // Handle the then branch
+    theBuilder.setInsertPoint(thenBB);
+    doStmt(ifStmt->getThen());
+    if (!theBuilder.isCurrentBasicBlockTerminated())
+      theBuilder.createBranch(mergeBB);
+
+    // Handle the else branch (if exists)
+    if (hasElse) {
+      theBuilder.setInsertPoint(elseBB);
+      doStmt(ifStmt->getElse());
+      if (!theBuilder.isCurrentBasicBlockTerminated())
+        theBuilder.createBranch(mergeBB);
+    }
+
+    // From now on, we'll emit instructions into the merge block.
+    theBuilder.setInsertPoint(mergeBB);
+  }
+
   uint32_t doExpr(const Expr *expr) {
     if (auto *delRefExpr = dyn_cast<DeclRefExpr>(expr)) {
       // Returns the <result-id> of the referenced Decl.

+ 18 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -151,6 +151,24 @@ uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
   return id;
 }
 
+void ModuleBuilder::createBranch(uint32_t targetLabel) {
+  assert(insertPoint && "null insert point");
+  instBuilder.opBranch(targetLabel).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+}
+
+void ModuleBuilder::createConditionalBranch(uint32_t condition,
+                                            uint32_t trueLabel,
+                                            uint32_t falseLabel,
+                                            uint32_t mergeLabel) {
+  assert(insertPoint && "null insert point");
+  instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
+      .x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  instBuilder.opBranchConditional(condition, trueLabel, falseLabel, {}).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+}
+
 void ModuleBuilder::createReturn() {
   assert(insertPoint && "null insert point");
   instBuilder.opReturn().x();

+ 60 - 0
tools/clang/test/CodeGenSPIRV/if-stmt.nested.hlsl

@@ -0,0 +1,60 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+    bool c1, c2, c3, c4;
+    int val = 0;
+
+// CHECK:      [[c1:%\d+]] = OpLoad %bool %c1
+// CHECK-NEXT: OpSelectionMerge %if_merge None
+// CHECK-NEXT: OpBranchConditional [[c1]] %if_true %if_false
+    if (c1) {
+// CHECK-LABEL: %if_true = OpLabel
+// CHECK-NEXT: [[c2:%\d+]] = OpLoad %bool %c2
+// CHECK-NEXT: OpSelectionMerge %if_merge_0 None
+// CHECK-NEXT: OpBranchConditional [[c2]] %if_true_0 %if_merge_0
+
+// TODO: Move this basic block to the else branch
+// CHECK-LABEL: %if_false = OpLabel
+// CHECK-NEXT: [[c3:%\d+]] = OpLoad %bool %c3
+// CHECK-NEXT: OpSelectionMerge %if_merge_1 None
+// CHECK-NEXT: OpBranchConditional [[c3]] %if_true_1 %if_false_0
+
+// TODO: Move this basic block to the end
+// CHECK-LABEL: %if_merge = OpLabel
+// CHECK-NEXT: OpReturn
+
+        if (c2)
+// CHECK-LABEL: %if_true_0 = OpLabel
+// CHECK-NEXT: OpStore %val %int_1
+// CHECK-NEXT: OpBranch %if_merge_0
+            val = 1;
+
+// CHECK-LABEL: %if_merge_0 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge
+    } else {
+        if (c3) {
+// CHECK-LABEL: %if_true_1 = OpLabel
+// CHECK-NEXT: OpStore %val %int_2
+// CHECK-NEXT: OpBranch %if_merge_1
+            val = 2;
+        } else {
+// CHECK-LABEL: %if_false_0 = OpLabel
+// CHECK-NEXT: [[c4:%\d+]] = OpLoad %bool %c4
+// CHECK-NEXT: OpSelectionMerge %if_merge_2 None
+// CHECK-NEXT: OpBranchConditional [[c4]] %if_true_2 %if_merge_2
+
+// TODO: Make this basic block the second to last one
+// CHECK-LABEL: %if_merge_1 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge
+            if (c4) {
+// CHECK-LABEL: %if_true_2 = OpLabel
+// CHECK-NEXT: OpStore %val %int_3
+// CHECK-NEXT: OpBranch %if_merge_2
+                val = 3;
+            }
+// CHECK-LABEL: %if_merge_2 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_1
+        }
+    }
+}

+ 69 - 0
tools/clang/test/CodeGenSPIRV/if-stmt.plain.hlsl

@@ -0,0 +1,69 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// Note: we need to consider the order of basic blocks. So CHECK-NEXT is used
+// extensively.
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+    bool c;
+    int val;
+
+    // Both then and else
+// CHECK:      [[c0:%\d+]] = OpLoad %bool %c
+// CHECK-NEXT: OpSelectionMerge %if_merge None
+// CHECK-NEXT: OpBranchConditional [[c0]] %if_true %if_false
+    if (c) {
+// CHECK-LABEL: %if_true = OpLabel
+// CHECK-NEXT: [[val0:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[val1:%\d+]] = OpIAdd %int [[val0]] %int_1
+// CHECK-NEXT: OpStore %val [[val1]]
+// CHECK-NEXT: OpBranch %if_merge
+        val = val + 1;
+    } else {
+// CHECK-LABEL: %if_false = OpLabel
+// CHECK-NEXT: [[val2:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[val3:%\d+]] = OpIAdd %int [[val2]] %int_2
+// CHECK-NEXT: OpStore %val [[val3]]
+// CHECK-NEXT: OpBranch %if_merge
+        val = val + 2;
+    }
+// CHECK-LABEL: %if_merge = OpLabel
+
+    // No else
+// CHECK-NEXT: [[c1:%\d+]] = OpLoad %bool %c
+// CHECK-NEXT: OpSelectionMerge %if_merge_0 None
+// CHECK-NEXT: OpBranchConditional [[c1]] %if_true_0 %if_merge_0
+    if (c)
+// CHECK-LABEL: %if_true_0 = OpLabel
+// CHECK-NEXT: OpStore %val %int_1
+// CHECK-NEXT: OpBranch %if_merge_0
+        val = 1;
+// CHECK-LABEL: %if_merge_0 = OpLabel
+
+    // Empty then
+// CHECK-NEXT: [[c2:%\d+]] = OpLoad %bool %c
+// CHECK-NEXT: OpSelectionMerge %if_merge_1 None
+// CHECK-NEXT: OpBranchConditional [[c2]] %if_true_1 %if_false_0
+    if (c) {
+// CHECK-LABEL: %if_true_1 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_1
+    } else {
+// CHECK-LABEL: %if_false_0 = OpLabel
+// CHECK-NEXT: OpStore %val %int_2
+// CHECK-NEXT: OpBranch %if_merge_1
+        val = 2;
+    }
+// CHECK-LABEL: %if_merge_1 = OpLabel
+
+    // Null body
+// CHECK-NEXT: [[c3:%\d+]] = OpLoad %bool %c
+// CHECK-NEXT: OpSelectionMerge %if_merge_2 None
+// CHECK-NEXT: OpBranchConditional [[c3]] %if_true_2 %if_merge_2
+    if (c)
+// CHECK-LABEL: %if_true_2 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_2
+        ;
+
+// CHECK-LABEL: %if_merge_2 = OpLabel
+// CHECK-NEXT: OpReturn
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -44,4 +44,8 @@ TEST_F(FileTest, BinaryOpScalarArithmetic) {
   runFileTest("binary-op.arithmetic.scalar.hlsl");
 }
 
+TEST_F(FileTest, IfStmtPlainAssign) { runFileTest("if-stmt.plain.hlsl"); }
+
+TEST_F(FileTest, IfStmtNestedIfStmt) { runFileTest("if-stmt.nested.hlsl"); }
+
 } // namespace