2
0
Эх сурвалжийг харах

[spirv] Translate switch statements using if statements (#517)

Ehsan 8 жил өмнө
parent
commit
de4bdebdc6

+ 154 - 4
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -562,8 +562,8 @@ public:
     // First handle the condition variable DeclStmt if one exists.
     // For example: handle 'int a = b' in the following:
     // switch (int a = b) {...}
-    const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt();
-    if (condVarDeclStmt)
+    if (const auto *condVarDeclStmt =
+            switchStmt->getConditionVariableDeclStmt())
       doStmt(condVarDeclStmt);
 
     const uint32_t selector = doExpr(switchStmt->getCond());
@@ -594,9 +594,159 @@ public:
     breakStack.pop();
   }
 
+  /// Flattens structured AST of the given switch statement into a vector of AST
+  /// nodes and stores into flatSwitch.
+  ///
+  /// The AST for a switch statement may look arbitrarily different based on
+  /// several factors such as placement of cases, placement of breaks, placement
+  /// of braces, and fallthrough cases.
+  ///
+  /// A CaseStmt for instance is the child node of a CompoundStmt for
+  /// regular cases and it is the child node of another CaseStmt for fallthrough
+  /// cases.
+  ///
+  /// A BreakStmt for instance could be the child node of a CompoundStmt
+  /// for regular cases, or the child node of a CaseStmt for some fallthrough
+  /// cases.
+  ///
+  /// This method flattens the AST representation of a switch statement to make
+  /// it easier to process for translation.
+  /// For example:
+  /// switch(a) {
+  ///   case 1:
+  ///     <Stmt1>
+  ///   case 2:
+  ///     <Stmt2>
+  ///     break;
+  ///   case 3:
+  ///   case 4:
+  ///     <Stmt4>
+  ///     break;
+  ///   deafult:
+  ///     <Stmt5>
+  /// }
+  ///
+  /// is flattened to the following vector:
+  ///
+  /// +-------------------------------------------------------------------+
+  /// |Case1|Stmt1|Case2|Stmt2|Break|Case3|Case4|Stmt4|Break|Default|Stmt5|
+  /// +-------------------------------------------------------------------+
+  ///
+  void flattenSwitchStmtAST(const Stmt *root,
+                            std::vector<const Stmt *> *flatSwitch) {
+    const auto *caseStmt = dyn_cast<CaseStmt>(root);
+    const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
+    const auto *defaultStmt = dyn_cast<DefaultStmt>(root);
+
+    if (!compoundStmt) {
+      flatSwitch->push_back(root);
+    }
+
+    if (compoundStmt) {
+      for (const auto *st : compoundStmt->body())
+        flattenSwitchStmtAST(st, flatSwitch);
+    } else if (caseStmt) {
+      flattenSwitchStmtAST(caseStmt->getSubStmt(), flatSwitch);
+    } else if (defaultStmt) {
+      flattenSwitchStmtAST(defaultStmt->getSubStmt(), flatSwitch);
+    }
+  }
+
+  /// Translates a switch statement into SPIR-V conditional branches.
+  ///
+  /// This is done by constructing AST if statements out of the cases using the
+  /// following pattern:
+  ///   if { ... } else if { ... } else if { ... } else { ... }
+  /// And then calling the SPIR-V codegen methods for these if statements.
+  ///
+  /// Each case comparison is turned into an if statement, and the "then" body
+  /// of the if statement will be the body of the case.
+  /// If a default statements exists, it becomes the body of the "else"
+  /// statement.
   void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
-    emitError("Translating Switch statements using If statements is not "
-              "implemented yet.");
+    std::vector<const Stmt *> flatSwitch;
+    flattenSwitchStmtAST(switchStmt->getBody(), &flatSwitch);
+
+    // First handle the condition variable DeclStmt if one exists.
+    // For example: handle 'int a = b' in the following:
+    // switch (int a = b) {...}
+    if (const auto *condVarDeclStmt =
+            switchStmt->getConditionVariableDeclStmt())
+      doStmt(condVarDeclStmt);
+
+    // Figure out the indexes of CaseStmts (and DefaultStmt if it exists) in
+    // the flattened switch AST.
+    // For instance, for the following flat vector, the indexes are:
+    // {0, 2, 5, 6, 9}
+    // +-------------------------------------------------------------------+
+    // |Case1|Stmt1|Case2|Stmt2|Break|Case3|Case4|Stmt4|Break|Default|Stmt5|
+    // +-------------------------------------------------------------------+
+    std::vector<uint32_t> caseStmtLocs;
+    for (uint32_t i = 0; i < flatSwitch.size(); ++i)
+      if (isa<CaseStmt>(flatSwitch[i]) || isa<DefaultStmt>(flatSwitch[i]))
+        caseStmtLocs.push_back(i);
+
+    IfStmt *prevIfStmt = nullptr;
+    IfStmt *rootIfStmt = nullptr;
+    CompoundStmt *defaultBody = nullptr;
+
+    // For each case, start at its index in the vector, and go forward
+    // accumulating statements until BreakStmt or end of vector is reached.
+    for (auto curCaseIndex : caseStmtLocs) {
+      const Stmt *curCase = flatSwitch[curCaseIndex];
+
+      // CompoundStmt to hold all statements for this case.
+      CompoundStmt *cs = new (astContext) CompoundStmt(Stmt::EmptyShell());
+
+      // Accumulate all non-case/default/break statements as the body for the
+      // current case.
+      std::vector<Stmt *> statements;
+      for (int i = curCaseIndex + 1;
+           i < flatSwitch.size() && !isa<BreakStmt>(flatSwitch[i]); ++i) {
+        if (!isa<CaseStmt>(flatSwitch[i]) && !isa<DefaultStmt>(flatSwitch[i]))
+          statements.push_back(const_cast<Stmt *>(flatSwitch[i]));
+      }
+      if (!statements.empty())
+        cs->setStmts(astContext, statements.data(), statements.size());
+
+      // For non-default cases, generate the IfStmt that compares the switch
+      // value to the case value.
+      if (auto *caseStmt = dyn_cast<CaseStmt>(curCase)) {
+        IfStmt *curIf = new (astContext) IfStmt(Stmt::EmptyShell());
+        BinaryOperator *bo =
+            new (astContext) BinaryOperator(Stmt::EmptyShell());
+        bo->setLHS(const_cast<Expr *>(switchStmt->getCond()));
+        bo->setRHS(const_cast<Expr *>(caseStmt->getLHS()));
+        bo->setOpcode(BO_EQ);
+        bo->setType(astContext.getLogicalOperationType());
+        curIf->setCond(bo);
+        curIf->setThen(cs);
+        // Each If statement is the "else" of the previous if statement.
+        if (prevIfStmt)
+          prevIfStmt->setElse(curIf);
+        else
+          rootIfStmt = curIf;
+        prevIfStmt = curIf;
+      } else {
+        // Record the DefaultStmt body as it will be used as the body of the
+        // "else" block in the if-elseif-...-else pattern.
+        defaultBody = cs;
+      }
+    }
+
+    // If a default case exists, it is the "else" of the last if statement.
+    if (prevIfStmt)
+      prevIfStmt->setElse(defaultBody);
+
+    // Since all else-if and else statements are the child nodes of the first
+    // IfStmt, we only need to call doStmt for the first IfStmt.
+    if (rootIfStmt)
+      doStmt(rootIfStmt);
+    // If there are no CaseStmt and there is only 1 DefaultStmt, there will be
+    // no if statements. The switch in that case only executes the body of the
+    // default case.
+    else if (defaultBody)
+      doStmt(defaultBody);
   }
 
   void doSwitchStmt(const SwitchStmt *switchStmt,

+ 325 - 0
tools/clang/test/CodeGenSPIRV/switch-stmt.ifstmt.hlsl

@@ -0,0 +1,325 @@
+// Run: %dxc -T ps_6_0 -E main
+
+int foo() { return 200; }
+
+void main() {
+
+// CHECK:      %a = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %b = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %c = OpVariable %_ptr_Function_int Function
+
+// TODO: We should try not to emit OpVariable for constant variables.
+// CHECK-NEXT: %r = OpVariable %_ptr_Function_int Function %int_20
+// CHECK-NEXT: %s = OpVariable %_ptr_Function_int Function %int_40
+// CHECK-NEXT: %t = OpVariable %_ptr_Function_int Function %int_140
+// CHECK-NEXT: %d = OpVariable %_ptr_Function_int Function %int_5
+  int a,b,c;
+  const int r = 20;
+  const int s = 40;
+  const int t = 3*r+2*s;
+
+
+  ////////////////////////////////////////
+  // DefaultStmt is the first statement //
+  ////////////////////////////////////////
+
+// CHECK-NEXT: [[a0:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_1:%\d+]] = OpIEqual %bool [[a0]] %int_1
+// CHECK-NEXT: OpSelectionMerge %if_merge None
+// CHECK-NEXT: OpBranchConditional [[is_a_1]] %if_true %if_false
+// CHECK-NEXT: %if_true = OpLabel
+// CHECK-NEXT: OpStore %b %int_1
+// CHECK-NEXT: OpBranch %if_merge
+// CHECK-NEXT: %if_false = OpLabel
+// CHECK-NEXT: [[a1:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_2:%\d+]] = OpIEqual %bool [[a1]] %int_2
+// CHECK-NEXT: OpSelectionMerge %if_merge_0 None
+// CHECK-NEXT: OpBranchConditional [[is_a_2]] %if_true_0 %if_false_0
+// CHECK-NEXT: %if_true_0 = OpLabel
+// CHECK-NEXT: OpStore %b %int_2
+// CHECK-NEXT: OpBranch %if_merge_0
+// CHECK-NEXT: %if_false_0 = OpLabel
+// CHECK-NEXT: OpStore %b %int_0
+// CHECK-NEXT: OpStore %b %int_1
+// CHECK-NEXT: OpBranch %if_merge_0
+// CHECK-NEXT: %if_merge_0 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge
+// CHECK-NEXT: %if_merge = OpLabel
+  [branch] switch(a) {
+    default:
+      b=0;
+    case 1:
+      b=1;
+      break;
+    case 2:
+      b=2;
+  }
+
+
+  //////////////////////////////////////////////
+  // DefaultStmt in the middle of other cases //
+  //////////////////////////////////////////////
+
+// CHECK-NEXT: [[a2:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_10:%\d+]] = OpIEqual %bool [[a2]] %int_10
+// CHECK-NEXT: OpSelectionMerge %if_merge_1 None
+// CHECK-NEXT: OpBranchConditional [[is_a_10]] %if_true_1 %if_false_1
+// CHECK-NEXT: %if_true_1 = OpLabel
+// CHECK-NEXT: OpStore %b %int_1
+// CHECK-NEXT: OpStore %b %int_0
+// CHECK-NEXT: OpStore %b %int_2
+// CHECK-NEXT: OpBranch %if_merge_1
+// CHECK-NEXT: %if_false_1 = OpLabel
+// CHECK-NEXT: [[a3:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_20:%\d+]] = OpIEqual %bool [[a3]] %int_20
+// CHECK-NEXT: OpSelectionMerge %if_merge_2 None
+// CHECK-NEXT: OpBranchConditional [[is_a_20]] %if_true_2 %if_false_2
+// CHECK-NEXT: %if_true_2 = OpLabel
+// CHECK-NEXT: OpStore %b %int_2
+// CHECK-NEXT: OpBranch %if_merge_2
+// CHECK-NEXT: %if_false_2 = OpLabel
+// CHECK-NEXT: OpStore %b %int_0
+// CHECK-NEXT: OpStore %b %int_2
+// CHECK-NEXT: OpBranch %if_merge_2
+// CHECK-NEXT: %if_merge_2 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_1
+// CHECK-NEXT: %if_merge_1 = OpLabel
+  [branch] switch(a) {
+    case 10:
+      b=1;
+    default:
+      b=0;
+    case 20:
+      b=2;
+      break;
+  }
+
+  ///////////////////////////////////////////////
+  // Various CaseStmt and BreakStmt topologies //
+  // DefaultStmt is the last statement         //
+  ///////////////////////////////////////////////
+
+// CHECK-NEXT: [[d0:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[is_d_1:%\d+]] = OpIEqual %bool [[d0]] %int_1
+// CHECK-NEXT: OpSelectionMerge %if_merge_3 None
+// CHECK-NEXT: OpBranchConditional [[is_d_1]] %if_true_3 %if_false_3
+// CHECK-NEXT: %if_true_3 = OpLabel
+// CHECK-NEXT: OpStore %b %int_1
+// CHECK-NEXT: [[foo:%\d+]] = OpFunctionCall %int %foo
+// CHECK-NEXT: OpStore %c [[foo]]
+// CHECK-NEXT: OpStore %b %int_2
+// CHECK-NEXT: OpBranch %if_merge_3
+// CHECK-NEXT: %if_false_3 = OpLabel
+// CHECK-NEXT: [[d1:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[is_d_2:%\d+]] = OpIEqual %bool [[d1]] %int_2
+// CHECK-NEXT: OpSelectionMerge %if_merge_4 None
+// CHECK-NEXT: OpBranchConditional [[is_d_2]] %if_true_4 %if_false_4
+// CHECK-NEXT: %if_true_4 = OpLabel
+// CHECK-NEXT: OpStore %b %int_2
+// CHECK-NEXT: OpBranch %if_merge_4
+// CHECK-NEXT: %if_false_4 = OpLabel
+// CHECK-NEXT: [[d2:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[is_d_3:%\d+]] = OpIEqual %bool [[d2]] %int_3
+// CHECK-NEXT: OpSelectionMerge %if_merge_5 None
+// CHECK-NEXT: OpBranchConditional [[is_d_3]] %if_true_5 %if_false_5
+// CHECK-NEXT: %if_true_5 = OpLabel
+// CHECK-NEXT: OpStore %b %int_3
+// CHECK-NEXT: OpBranch %if_merge_5
+// CHECK-NEXT: %if_false_5 = OpLabel
+// CHECK-NEXT: [[d3:%\d+]] = OpLoad %int %d
+// TODO: We should try to const fold `t` and avoid the following OpLoad:
+// CHECK-NEXT: [[t:%\d+]] = OpLoad %int %t
+// CHECK-NEXT: [[is_d_eq_t:%\d+]] = OpIEqual %bool [[d3]] [[t]]
+// CHECK-NEXT: OpSelectionMerge %if_merge_6 None
+// CHECK-NEXT: OpBranchConditional [[is_d_eq_t]] %if_true_6 %if_false_6
+// CHECK-NEXT: %if_true_6 = OpLabel
+// CHECK-NEXT: [[t1:%\d+]] = OpLoad %int %t
+// CHECK-NEXT: OpStore %b [[t1]]
+// CHECK-NEXT: OpStore %b %int_5
+// CHECK-NEXT: OpBranch %if_merge_6
+// CHECK-NEXT: %if_false_6 = OpLabel
+// CHECK-NEXT: [[d4:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[is_d_4:%\d+]] = OpIEqual %bool [[d4]] %int_4
+// CHECK-NEXT: OpSelectionMerge %if_merge_7 None
+// CHECK-NEXT: OpBranchConditional [[is_d_4]] %if_true_7 %if_false_7
+// CHECK-NEXT: %if_true_7 = OpLabel
+// CHECK-NEXT: OpStore %b %int_5
+// CHECK-NEXT: OpBranch %if_merge_7
+// CHECK-NEXT: %if_false_7 = OpLabel
+// CHECK-NEXT: [[d5:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[is_d_5:%\d+]] = OpIEqual %bool [[d5]] %int_5
+// CHECK-NEXT: OpSelectionMerge %if_merge_8 None
+// CHECK-NEXT: OpBranchConditional [[is_d_5]] %if_true_8 %if_false_8
+// CHECK-NEXT: %if_true_8 = OpLabel
+// CHECK-NEXT: OpStore %b %int_5
+// CHECK-NEXT: OpBranch %if_merge_8
+// CHECK-NEXT: %if_false_8 = OpLabel
+// CHECK-NEXT: [[d6:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[is_d_6:%\d+]] = OpIEqual %bool [[d6]] %int_6
+// CHECK-NEXT: OpSelectionMerge %if_merge_9 None
+// CHECK-NEXT: OpBranchConditional [[is_d_6]] %if_true_9 %if_false_9
+// CHECK-NEXT: %if_true_9 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_9
+// CHECK-NEXT: %if_false_9 = OpLabel
+// CHECK-NEXT: [[d7:%\d+]] = OpLoad %int %d
+// CHECK-NEXT: [[is_d_7:%\d+]] = OpIEqual %bool [[d7]] %int_7
+// CHECK-NEXT: OpSelectionMerge %if_merge_10 None
+// CHECK-NEXT: OpBranchConditional [[is_d_7]] %if_true_10 %if_false_10
+// CHECK-NEXT: %if_true_10 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_10
+// CHECK-NEXT: %if_false_10 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_10
+// CHECK-NEXT: %if_merge_10 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_9
+// CHECK-NEXT: %if_merge_9 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_8
+// CHECK-NEXT: %if_merge_8 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_7
+// CHECK-NEXT: %if_merge_7 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_6
+// CHECK-NEXT: %if_merge_6 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_5
+// CHECK-NEXT: %if_merge_5 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_4
+// CHECK-NEXT: %if_merge_4 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_3
+// CHECK-NEXT: %if_merge_3 = OpLabel
+  [branch] switch(int d = 5) {
+    case 1:
+      b=1;
+      c=foo();
+    case 2:
+      b=2;
+      break;
+    case 3:
+    {
+      b=3;
+      break;
+    }
+    case t:
+      b=t;
+    case 4:
+    case 5:
+      b=5;
+      break;
+    case 6: {
+    case 7:
+      break;}
+    default:
+      break;
+  }
+
+
+  //////////////////////////
+  // No Default statement //
+  //////////////////////////
+
+// CHECK-NEXT: [[a4:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_100:%\d+]] = OpIEqual %bool [[a4]] %int_100
+// CHECK-NEXT: OpSelectionMerge %if_merge_11 None
+// CHECK-NEXT: OpBranchConditional [[is_a_100]] %if_true_11 %if_merge_11
+// CHECK-NEXT: %if_true_11 = OpLabel
+// CHECK-NEXT: OpStore %b %int_100
+// CHECK-NEXT: OpBranch %if_merge_11
+// CHECK-NEXT: %if_merge_11 = OpLabel
+  [branch] switch(a) {
+    case 100:
+      b=100;
+      break;
+  }
+
+
+  /////////////////////////////////////////////////////////
+  // No cases. Only a default                            //
+  // This means the default body will always be executed //
+  /////////////////////////////////////////////////////////
+
+// CHECK-NEXT: OpStore %b %int_100
+// CHECK-NEXT: OpStore %c %int_200
+  [branch] switch(a) {
+    default:
+      b=100;
+      c=200;
+      break;
+  }
+
+
+  ////////////////////////////////////////////////////////////
+  // Nested Switch with branching                           //
+  // The two inner switch statements should be executed for //
+  // both cases of the outer switch (case 300 and case 400) //
+  ////////////////////////////////////////////////////////////
+
+// CHECK-NEXT: [[a5:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_300:%\d+]] = OpIEqual %bool [[a5]] %int_300
+// CHECK-NEXT: OpSelectionMerge %if_merge_12 None
+// CHECK-NEXT: OpBranchConditional [[is_a_300]] %if_true_12 %if_false_11
+// CHECK-NEXT: %if_true_12 = OpLabel
+// CHECK-NEXT: OpStore %b %int_300
+// CHECK-NEXT: [[c0:%\d+]] = OpLoad %int %c
+// CHECK-NEXT: [[is_c_500:%\d+]] = OpIEqual %bool [[c0]] %int_500
+// CHECK-NEXT: OpSelectionMerge %if_merge_13 None
+// CHECK-NEXT: OpBranchConditional [[is_c_500]] %if_true_13 %if_false_12
+// CHECK-NEXT: %if_true_13 = OpLabel
+// CHECK-NEXT: OpStore %b %int_500
+// CHECK-NEXT: OpBranch %if_merge_13
+// CHECK-NEXT: %if_false_12 = OpLabel
+// CHECK-NEXT: [[c1:%\d+]] = OpLoad %int %c
+// CHECK-NEXT: [[is_c_600:%\d+]] = OpIEqual %bool [[c1]] %int_600
+// CHECK-NEXT: OpSelectionMerge %if_merge_14 None
+// CHECK-NEXT: OpBranchConditional [[is_c_600]] %if_true_14 %if_merge_14
+// CHECK-NEXT: %if_true_14 = OpLabel
+// CHECK-NEXT: OpStore %a %int_600
+// CHECK-NEXT: OpStore %b %int_600
+// CHECK-NEXT: OpBranch %if_merge_14
+// CHECK-NEXT: %if_merge_14 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_13
+// CHECK-NEXT: %if_merge_13 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_12
+// CHECK-NEXT: %if_false_11 = OpLabel
+// CHECK-NEXT: [[a6:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_400:%\d+]] = OpIEqual %bool [[a6]] %int_400
+// CHECK-NEXT: OpSelectionMerge %if_merge_15 None
+// CHECK-NEXT: OpBranchConditional [[is_a_400]] %if_true_15 %if_merge_15
+// CHECK-NEXT: %if_true_15 = OpLabel
+// CHECK-NEXT: [[c2:%\d+]] = OpLoad %int %c
+// CHECK-NEXT: [[is_c_500_again:%\d+]] = OpIEqual %bool [[c2]] %int_500
+// CHECK-NEXT: OpSelectionMerge %if_merge_16 None
+// CHECK-NEXT: OpBranchConditional [[is_c_500_again]] %if_true_16 %if_false_13
+// CHECK-NEXT: %if_true_16 = OpLabel
+// CHECK-NEXT: OpStore %b %int_500
+// CHECK-NEXT: OpBranch %if_merge_16
+// CHECK-NEXT: %if_false_13 = OpLabel
+// CHECK-NEXT: [[c3:%\d+]] = OpLoad %int %c
+// CHECK-NEXT: [[is_c_600_again:%\d+]] = OpIEqual %bool [[c3]] %int_600
+// CHECK-NEXT: OpSelectionMerge %if_merge_17 None
+// CHECK-NEXT: OpBranchConditional [[is_c_600_again]] %if_true_17 %if_merge_17
+// CHECK-NEXT: %if_true_17 = OpLabel
+// CHECK-NEXT: OpStore %a %int_600
+// CHECK-NEXT: OpStore %b %int_600
+// CHECK-NEXT: OpBranch %if_merge_17
+// CHECK-NEXT: %if_merge_17 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_16
+// CHECK-NEXT: %if_merge_16 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_15
+// CHECK-NEXT: %if_merge_15 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_12
+// CHECK-NEXT: %if_merge_12 = OpLabel
+  [branch] switch(a) {
+    case 300:
+      b=300;
+    case 400:
+      [branch] switch(c) {
+        case 500:
+          b=500;
+          break;
+        case 600:
+          [branch] switch(b) {
+            default:
+            a=600;
+            b=600;
+          }
+      }
+  }
+
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -191,6 +191,9 @@ TEST_F(FileTest, IfStmtNestedIfStmt) { runFileTest("if-stmt.nested.hlsl"); }
 TEST_F(FileTest, SwitchStmtUsingOpSwitch) {
   runFileTest("switch-stmt.opswitch.hlsl");
 }
+TEST_F(FileTest, SwitchStmtUsingIfStmt) {
+  runFileTest("switch-stmt.ifstmt.hlsl");
+}
 
 // For for statements
 TEST_F(FileTest, ForStmtPlainAssign) { runFileTest("for-stmt.plain.hlsl"); }