Преглед изворни кода

[spirv] Translate switch statements using OpSwitch (#505)

SPIR-V OpSwitch can be used to represent HLSL switch statements when all
case values are integer literals (or constant integer variables).

This CL also handles pass-through cases and nested switch statements.
Ehsan пре 8 година
родитељ
комит
bb303364a9

+ 6 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -144,6 +144,12 @@ public:
   uint32_t createSelect(uint32_t resultType, uint32_t condition,
                         uint32_t trueValue, uint32_t falseValue);
 
+  /// \brief Creates a switch statement for the given selector, default, and
+  /// branches. Results in OpSelectionMerge followed by OpSwitch.
+  void createSwitch(uint32_t mergeLabel, uint32_t selector,
+                    uint32_t defaultLabel,
+                    llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target);
+
   // \brief Creates an unconditional branch to the given target label.
   void createBranch(uint32_t targetLabel);
 

+ 249 - 1
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -20,6 +20,7 @@
 #include "clang/SPIRV/TypeTranslator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringExtras.h"
 
 namespace clang {
 namespace spirv {
@@ -325,7 +326,7 @@ public:
     }
   }
 
-  void doStmt(const Stmt *stmt) {
+  void doStmt(const Stmt *stmt, llvm::ArrayRef<const Attr *> attrs = {}) {
     if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
       for (auto *st : compoundStmt->body())
         doStmt(st);
@@ -337,6 +338,14 @@ public:
       }
     } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
       doIfStmt(ifStmt);
+    } else if (const auto *switchStmt = dyn_cast<SwitchStmt>(stmt)) {
+      doSwitchStmt(switchStmt, attrs);
+    } else if (const auto *caseStmt = dyn_cast<CaseStmt>(stmt)) {
+      processCaseStmtOrDefaultStmt(stmt);
+    } else if (const auto *defaultStmt = dyn_cast<DefaultStmt>(stmt)) {
+      processCaseStmtOrDefaultStmt(stmt);
+    } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
+      doBreakStmt(breakStmt);
     } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
       doForStmt(forStmt);
     } else if (const auto *nullStmt = dyn_cast<NullStmt>(stmt)) {
@@ -344,6 +353,8 @@ public:
     } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
       // All cases for expressions used as statements
       doExpr(expr);
+    } else if (const auto *attrStmt = dyn_cast<AttributedStmt>(stmt)) {
+      doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs());
     } else {
       emitError("Stmt '%0' is not supported yet.") << stmt->getStmtClassName();
     }
@@ -411,6 +422,225 @@ public:
     }
   }
 
+  /// \brief Returns true iff *all* the case values in the given switch
+  /// statement are integer literals. In such cases OpSwitch can be used to
+  /// represent the switch statement.
+  /// We only care about the case values to be compared with the selector. They
+  /// may appear in the top level CaseStmt or be nested in a CompoundStmt.Fall
+  /// through cases will result in the second situation.
+  bool allSwitchCasesAreIntegerLiterals(const Stmt *root) {
+    if (!root)
+      return false;
+
+    const auto *caseStmt = dyn_cast<CaseStmt>(root);
+    const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
+    if (!caseStmt && !compoundStmt)
+      return true;
+
+    if (caseStmt) {
+      const Expr *caseExpr = caseStmt->getLHS();
+      return caseExpr && caseExpr->isEvaluatable(astContext);
+    }
+
+    // Recurse down if facing a compound statement.
+    for (auto *st : compoundStmt->body())
+      if (!allSwitchCasesAreIntegerLiterals(st))
+        return false;
+
+    return true;
+  }
+
+  /// \brief Recursively discovers all CaseStmt and DefaultStmt under the
+  /// sub-tree of the given root. Recursively goes down the tree iff it finds a
+  /// CaseStmt, DefaultStmt, or CompoundStmt. It does not recurse on other
+  /// statement types. For each discovered case, a basic block is created and
+  /// registered within the module, and added as a successor to the current
+  /// active basic block.
+  ///
+  /// Writes a vector of (integer, basic block label) pairs for all cases to the
+  /// given 'targets' argument. If a DefaultStmt is found, it also returns the
+  /// label for the default basic block through the defaultBB parameter. This
+  /// method panics if it finds a case value that is not an integer literal.
+  void discoverAllCaseStmtInSwitchStmt(
+      const Stmt *root, uint32_t *defaultBB,
+      std::vector<std::pair<uint32_t, uint32_t>> *targets) {
+    if (!root)
+      return;
+
+    // A switch case can only appear in DefaultStmt, CaseStmt, or
+    // CompoundStmt. For the rest, we can just return.
+    const auto *defaultStmt = dyn_cast<DefaultStmt>(root);
+    const auto *caseStmt = dyn_cast<CaseStmt>(root);
+    const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
+    if (!defaultStmt && !caseStmt && !compoundStmt)
+      return;
+
+    // Recurse down if facing a compound statement.
+    if (compoundStmt) {
+      for (auto *st : compoundStmt->body())
+        discoverAllCaseStmtInSwitchStmt(st, defaultBB, targets);
+      return;
+    }
+
+    std::string caseLabel;
+    uint32_t caseValue = 0;
+    if (defaultStmt) {
+      // This is the default branch.
+      caseLabel = "switch.default";
+    } else if (caseStmt) {
+      // This is a non-default case.
+      // When using OpSwitch, we only allow integer literal cases. e.g:
+      // case <literal_integer>: {...; break;}
+      const Expr *caseExpr = caseStmt->getLHS();
+      assert(caseExpr && caseExpr->isEvaluatable(astContext));
+      auto bitWidth = astContext.getIntWidth(caseExpr->getType());
+      if (bitWidth != 32)
+        emitError("Switch statement translation currently only supports 32-bit "
+                  "integer case values.");
+      Expr::EvalResult evalResult;
+      caseExpr->EvaluateAsRValue(evalResult, astContext);
+      const int64_t value = evalResult.Val.getInt().getSExtValue();
+      caseValue = static_cast<uint32_t>(value);
+      caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
+                  llvm::itostr(std::abs(value));
+    }
+    const uint32_t caseBB = theBuilder.createBasicBlock(caseLabel);
+    theBuilder.addSuccessor(caseBB);
+    stmtBasicBlock[root] = caseBB;
+
+    // Add all cases to the 'targets' vector.
+    if (caseStmt)
+      targets->emplace_back(caseValue, caseBB);
+
+    // The default label is not part of the 'targets' vector that is passed
+    // to the OpSwitch instruction.
+    // If default statement was discovered, return its label via defaultBB.
+    if (defaultStmt)
+      *defaultBB = caseBB;
+
+    // Process cases nested in other cases. It happens when we have fall through
+    // cases. For example:
+    // case 1: case 2: ...; break;
+    // will result in the CaseSmt for case 2 nested in the one for case 1.
+    discoverAllCaseStmtInSwitchStmt(caseStmt ? caseStmt->getSubStmt()
+                                             : defaultStmt->getSubStmt(),
+                                    defaultBB, targets);
+  }
+
+  void processSwitchStmtUsingSpirvOpSwitch(const SwitchStmt *switchStmt) {
+    // First handle the condition variable DeclStmt if one exists.
+    // For example: handle 'int a = b' in the following:
+    // switch (int a = b) {...}
+    const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt();
+    if (condVarDeclStmt)
+      doStmt(condVarDeclStmt);
+
+    const uint32_t selector = doExpr(switchStmt->getCond());
+
+    // We need a merge block regardless of the number of switch cases.
+    // Since OpSwitch always requires a default label, if the switch statement
+    // does not have a default branch, we use the merge block as the default
+    // target.
+    const uint32_t mergeBB = theBuilder.createBasicBlock("switch.merge");
+    theBuilder.setMergeTarget(mergeBB);
+    breakStack.push(mergeBB);
+    uint32_t defaultBB = mergeBB;
+
+    // (literal, labelId) pairs to pass to the OpSwitch instruction.
+    std::vector<std::pair<uint32_t, uint32_t>> targets;
+    discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB,
+                                    &targets);
+
+    // Create the OpSelectionMerge and OpSwitch.
+    theBuilder.createSwitch(mergeBB, selector, defaultBB, targets);
+
+    // Handle the switch body.
+    doStmt(switchStmt->getBody());
+
+    if (!theBuilder.isCurrentBasicBlockTerminated())
+      theBuilder.createBranch(mergeBB);
+    theBuilder.setInsertPoint(mergeBB);
+    breakStack.pop();
+  }
+
+  void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
+    emitError("Translating Switch statements using If statements is not "
+              "implemented yet.");
+  }
+
+  void doSwitchStmt(const SwitchStmt *switchStmt,
+                    llvm::ArrayRef<const Attr *> attrs = {}) {
+    // Switch statements are composed of:
+    //   switch (<condition variable>) {
+    //     <CaseStmt>
+    //     <CaseStmt>
+    //     <CaseStmt>
+    //     <DefaultStmt> (optional)
+    //   }
+    //
+    //                             +-------+
+    //                             | check |
+    //                             +-------+
+    //                                 |
+    //         +-------+-------+----------------+---------------+
+    //         | 1             | 2              | 3             | (others)
+    //         v               v                v               v
+    //     +-------+      +-------------+     +-------+     +------------+
+    //     | case1 |      | case2       |     | case3 | ... | default    |
+    //     |       |      |(fallthrough)|---->|       |     | (optional) |
+    //     +-------+      |+------------+     +-------+     +------------+
+    //         |                                  |                |
+    //         |                                  |                |
+    //         |   +-------+                      |                |
+    //         |   |       | <--------------------+                |
+    //         +-> | merge |                                       |
+    //             |       | <-------------------------------------+
+    //             +-------+
+
+    // If no attributes are given, or if "forcecase" attribute was provided,
+    // we'll do our best to use OpSwitch if possible.
+    // If any of the cases compares to a variable (rather than an integer
+    // literal), we cannot use OpSwitch because OpSwitch expects literal
+    // numbers as parameters.
+    const bool isAttrForceCase =
+        !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase;
+    const bool canUseSpirvOpSwitch =
+        (attrs.empty() || isAttrForceCase) &&
+        allSwitchCasesAreIntegerLiterals(switchStmt->getBody());
+
+    if (isAttrForceCase && !canUseSpirvOpSwitch)
+      emitWarning("Ignored 'forcecase' attribute for the switch statement "
+                  "since one or more case values are not integer literals.");
+
+    if (canUseSpirvOpSwitch)
+      processSwitchStmtUsingSpirvOpSwitch(switchStmt);
+    else
+      processSwitchStmtUsingIfStmts(switchStmt);
+  }
+
+  void processCaseStmtOrDefaultStmt(const Stmt *stmt) {
+    auto *caseStmt = dyn_cast<CaseStmt>(stmt);
+    auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
+    assert(caseStmt || defaultStmt);
+
+    uint32_t caseBB = stmtBasicBlock[stmt];
+    if (!theBuilder.isCurrentBasicBlockTerminated()) {
+      // We are about to handle the case passed in as parameter. If the current
+      // basic block is not terminated, it means the previous case is a fall
+      // through case. We need to link it to the case to be processed.
+      theBuilder.createBranch(caseBB);
+      theBuilder.addSuccessor(caseBB);
+    }
+    theBuilder.setInsertPoint(caseBB);
+    doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
+  }
+
+  void doBreakStmt(const BreakStmt *breakStmt) {
+    uint32_t breakTargetBB = breakStack.top();
+    theBuilder.addSuccessor(breakTargetBB);
+    theBuilder.createBranch(breakTargetBB);
+  }
+
   void doIfStmt(const IfStmt *ifStmt) {
     // if statements are composed of:
     //   if (<check>) { <then> } else { <else> }
@@ -1862,6 +2092,24 @@ private:
   uint32_t entryFunctionId;
   /// The current function under traversal.
   const FunctionDecl *curFunction;
+
+  /// For loops, while loops, and switch statements may encounter "break"
+  /// statements that alter their control flow. At any point the break statement
+  /// is observed, the control flow jumps to the inner-most scope's merge block.
+  /// For instance: the break in the following example should cause a branch to
+  /// the SwitchMergeBB, not ForLoopMergeBB:
+  /// for (...) {
+  ///   switch(...) {
+  ///     case 1: break;
+  ///   }
+  ///   <--- SwitchMergeBB ---->
+  /// }
+  /// <----- ForLoopMergeBB --->
+  /// This stack keeps track of the basic blocks to which branching could occur.
+  std::stack<uint32_t> breakStack;
+
+  /// Maps a given statement to the basic block that is associated with it.
+  llvm::DenseMap<const Stmt *, uint32_t> stmtBasicBlock;
 };
 
 } // end namespace spirv

+ 14 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -220,6 +220,20 @@ uint32_t ModuleBuilder::createSelect(uint32_t resultType, uint32_t condition,
   return id;
 }
 
+void ModuleBuilder::createSwitch(
+    uint32_t mergeLabel, uint32_t selector, uint32_t defaultLabel,
+    llvm::ArrayRef<std::pair<uint32_t, uint32_t>> target) {
+  assert(insertPoint && "null insert point");
+  // Create the OpSelectioMerege.
+  instBuilder.opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
+      .x();
+  insertPoint->appendInstruction(std::move(constructSite));
+
+  // Create the OpSwitch.
+  instBuilder.opSwitch(selector, defaultLabel, target).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+}
+
 void ModuleBuilder::createBranch(uint32_t targetLabel) {
   assert(insertPoint && "null insert point");
 

+ 347 - 0
tools/clang/test/CodeGenSPIRV/switch-stmt.opswitch.hlsl

@@ -0,0 +1,347 @@
+// Run: %dxc -T ps_6_0 -E main
+
+int foo() { return 200; }
+
+void main() {
+  int result;
+
+
+
+  ////////////////////////////
+  // The most basic case    //
+  // Has a 'default' case   //
+  // All cases have 'break' //
+  ////////////////////////////
+  
+  int a = 0;
+// CHECK: [[a:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge None
+// CHECK-NEXT: OpSwitch [[a]] %switch_default -3 %switch_n3 0 %switch_0 1 %switch_1 2 %switch_2
+  switch(a) {
+// CHECK-NEXT: %switch_n3 = OpLabel
+// CHECK-NEXT: OpStore %result %int_n300
+// CHECK-NEXT: OpBranch %switch_merge
+    case -3:
+      result = -300;
+      break;
+// CHECK-NEXT: %switch_0 = OpLabel
+// CHECK-NEXT: OpStore %result %int_0
+// CHECK-NEXT: OpBranch %switch_merge
+    case 0:
+      result = 0;
+      break;
+// CHECK-NEXT: %switch_1 = OpLabel
+// CHECK-NEXT: OpStore %result %int_100
+// CHECK-NEXT: OpBranch %switch_merge
+    case 1:
+      result = 100;
+      break;
+// CHECK-NEXT: %switch_2 = OpLabel
+// CHECK-NEXT: [[foo:%\d+]] = OpFunctionCall %int %foo
+// CHECK-NEXT: OpStore %result [[foo]]
+// CHECK-NEXT: OpBranch %switch_merge
+    case 2:
+      result = foo();
+      break;
+// CHECK-NEXT: %switch_default = OpLabel
+// CHECK-NEXT: OpStore %result %int_777
+// CHECK-NEXT: OpBranch %switch_merge
+    default:
+      result = 777;
+      break;
+  }
+// CHECK-NEXT: %switch_merge = OpLabel
+
+
+
+  ////////////////////////////////////
+  // The selector is a statement    //
+  // Does not have a 'default' case //
+  // All cases have 'break'         //  
+  ////////////////////////////////////
+
+// CHECK-NEXT: [[a1:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpStore %c [[a1]]
+// CHECK-NEXT: [[c:%\d+]] = OpLoad %int %c
+// CHECK-NEXT: OpSelectionMerge %switch_merge_0 None
+// CHECK-NEXT: OpSwitch [[c]] %switch_merge_0 -4 %switch_n4 4 %switch_4
+  switch(int c = a) {
+// CHECK-NEXT: %switch_n4 = OpLabel
+// CHECK-NEXT: OpStore %result %int_n400
+// CHECK-NEXT: OpBranch %switch_merge_0  
+    case -4:
+      result = -400;
+      break;
+// CHECK-NEXT: %switch_4 = OpLabel
+// CHECK-NEXT: OpStore %result %int_400
+// CHECK-NEXT: OpBranch %switch_merge_0
+    case 4:
+      result = 400;
+      break;
+  }
+// CHECK-NEXT: %switch_merge_0 = OpLabel
+
+
+
+  ///////////////////////////////////
+  // All cases are fall-through    //
+  // The last case is fall-through //
+  ///////////////////////////////////
+
+// CHECK-NEXT: [[a2:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge_1 None
+// CHECK-NEXT: OpSwitch [[a2]] %switch_merge_1 -5 %switch_n5 5 %switch_5
+  switch(a) {
+// CHECK-NEXT: %switch_n5 = OpLabel
+// CHECK-NEXT: OpStore %result %int_n500
+// CHECK-NEXT: OpBranch %switch_5
+    case -5:
+      result = -500;
+// CHECK-NEXT: %switch_5 = OpLabel
+// CHECK-NEXT: OpStore %result %int_500
+// CHECK-NEXT: OpBranch %switch_merge_1
+    case 5:
+      result = 500;
+  }
+// CHECK-NEXT: %switch_merge_1 = OpLabel
+
+
+
+  ///////////////////////////////////////
+  // Some cases are fall-through       //
+  // The last case is not fall-through //
+  ///////////////////////////////////////
+
+// CHECK-NEXT: [[a3:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge_2 None
+// CHECK-NEXT: OpSwitch [[a3]] %switch_default_0 6 %switch_6 7 %switch_7 8 %switch_8
+  switch(a) {
+// CHECK-NEXT: %switch_6 = OpLabel
+// CHECK-NEXT: OpStore %result %int_600
+// CHECK-NEXT: OpBranch %switch_7
+    case 6:
+      result = 600;
+    case 7:
+// CHECK-NEXT: %switch_7 = OpLabel
+// CHECK-NEXT: OpStore %result %int_700
+// CHECK-NEXT: OpBranch %switch_8
+      result = 700;
+// CHECK-NEXT: %switch_8 = OpLabel
+// CHECK-NEXT: OpStore %result %int_800
+// CHECK-NEXT: OpBranch %switch_merge_2
+    case 8:
+      result = 800;
+      break;
+// CHECK-NEXT: %switch_default_0 = OpLabel
+// CHECK-NEXT: OpStore %result %int_777
+// CHECK-NEXT: OpBranch %switch_merge_2
+    default:
+      result = 777;
+      break;
+  }
+// CHECK-NEXT: %switch_merge_2 = OpLabel
+
+
+
+  ///////////////////////////////////////
+  // Fall-through cases with no body   //
+  ///////////////////////////////////////
+
+// CHECK-NEXT: [[a4:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge_3 None
+// CHECK-NEXT: OpSwitch [[a4]] %switch_default_1 10 %switch_10 11 %switch_11 12 %switch_12
+  switch(a) {
+// CHECK-NEXT: %switch_10 = OpLabel
+// CHECK-NEXT: OpBranch %switch_11
+    case 10:
+// CHECK-NEXT: %switch_11 = OpLabel
+// CHECK-NEXT: OpBranch %switch_default_1
+    case 11:
+// CHECK-NEXT: %switch_default_1 = OpLabel
+// CHECK-NEXT: OpBranch %switch_12
+    default:
+// CHECK-NEXT: %switch_12 = OpLabel
+// CHECK-NEXT: OpStore %result %int_12
+// CHECK-NEXT: OpBranch %switch_merge_3
+    case 12:
+      result = 12;
+  }
+// CHECK-NEXT: %switch_merge_3 = OpLabel
+
+
+
+  ////////////////////////////////////////////////
+  // No-op. Two nested cases and a nested break //
+  ////////////////////////////////////////////////
+
+// CHECK-NEXT: [[a5:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge_4 None
+// CHECK-NEXT: OpSwitch [[a5]] %switch_merge_4 15 %switch_15 16 %switch_16
+  switch(a) {
+// CHECK-NEXT: %switch_15 = OpLabel
+// CHECK-NEXT: OpBranch %switch_16
+    case 15:
+// CHECK-NEXT: %switch_16 = OpLabel
+// CHECK-NEXT: OpBranch %switch_merge_4
+    case 16:
+      break;
+  }
+// CHECK-NEXT: %switch_merge_4 = OpLabel
+
+
+
+  ////////////////////////////////////////////////////////////////
+  // Using braces (compound statements) in various parts        //
+  // Using breaks such that each AST configuration is different //
+  // Also uses 'forcecase' attribute                            //
+  ////////////////////////////////////////////////////////////////
+
+// CHECK-NEXT: [[a6:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge_5 None
+// CHECK-NEXT: OpSwitch [[a6]] %switch_merge_5 20 %switch_20 21 %switch_21 22 %switch_22 23 %switch_23 24 %switch_24 25 %switch_25 26 %switch_26 27 %switch_27 28 %switch_28 29 %switch_29
+  [forcecase] switch(a) {
+// CHECK-NEXT: %switch_20 = OpLabel
+// CHECK-NEXT: OpStore %result %int_20
+// CHECK-NEXT: OpBranch %switch_merge_5
+    case 20: {
+      result = 20;
+      break;
+    }
+// CHECK-NEXT: %switch_21 = OpLabel
+// CHECK-NEXT: OpStore %result %int_21
+// CHECK-NEXT: OpBranch %switch_merge_5
+    case 21:
+      result = 21;
+      break;
+// CHECK-NEXT: %switch_22 = OpLabel
+// CHECK-NEXT: OpBranch %switch_23
+// CHECK-NEXT: %switch_23 = OpLabel
+// CHECK-NEXT: OpBranch %switch_merge_5
+    case 22:
+    case 23:
+      break;
+// CHECK-NEXT: %switch_24 = OpLabel
+// CHECK-NEXT: OpBranch %switch_25
+// CHECK-NEXT: %switch_25 = OpLabel
+// CHECK-NEXT: OpStore %result %int_25
+// CHECK-NEXT: OpBranch %switch_merge_5
+    case 24:
+    case 25: { result = 25; }
+      break;
+// CHECK-NEXT: %switch_26 = OpLabel
+// CHECK-NEXT: OpBranch %switch_27
+// CHECK-NEXT: %switch_27 = OpLabel
+// CHECK-NEXT: OpBranch %switch_merge_5
+    case 26:
+    case 27: {
+      break;
+    }
+// CHECK-NEXT: %switch_28 = OpLabel
+// CHECK-NEXT: OpStore %result %int_28
+// CHECK-NEXT: OpBranch %switch_merge_5
+    case 28: {
+      result = 28;
+      {{break;}}
+    }
+// CHECK-NEXT: %switch_29 = OpLabel
+// CHECK-NEXT: OpStore %result %int_29
+// CHECK-NEXT: OpBranch %switch_merge_5
+    case 29: {
+      {
+        result = 29;
+        {break;}
+      }
+    }
+  }
+// CHECK-NEXT: %switch_merge_5 = OpLabel    
+
+
+
+  ////////////////////////////////////////////////////////////////////////
+  // Nested Switch statements with mixed use of fall-through and braces //
+  ////////////////////////////////////////////////////////////////////////
+
+// CHECK-NEXT: [[a7:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge_6 None
+// CHECK-NEXT: OpSwitch [[a7]] %switch_merge_6 30 %switch_30
+  switch(a) {
+// CHECK-NEXT: %switch_30 = OpLabel
+    case 30: {
+// CHECK-NEXT: OpStore %result %int_30
+        result = 30;
+// CHECK-NEXT: [[result:%\d+]] = OpLoad %int %result
+// CHECK-NEXT: OpSelectionMerge %switch_merge_7 None
+// CHECK-NEXT: OpSwitch [[result]] %switch_default_2 50 %switch_50 51 %switch_51 52 %switch_52 53 %switch_53 54 %switch_54
+        switch(result) {
+// CHECK-NEXT: %switch_default_2 = OpLabel
+// CHECK-NEXT: OpStore %a %int_55
+// CHECK-NEXT: OpBranch %switch_50
+          default:
+            a = 55;
+// CHECK-NEXT: %switch_50 = OpLabel
+// CHECK-NEXT: OpStore %a %int_50
+// CHECK-NEXT: OpBranch %switch_merge_7
+          case 50:
+            a = 50;
+            break;
+// CHECK-NEXT: %switch_51 = OpLabel
+// CHECK-NEXT: OpBranch %switch_52
+          case 51:
+// CHECK-NEXT: %switch_52 = OpLabel
+// CHECK-NEXT: OpStore %a %int_52
+// CHECK-NEXT: OpBranch %switch_53
+          case 52:
+            a = 52;
+// CHECK-NEXT: %switch_53 = OpLabel
+// CHECK-NEXT: OpStore %a %int_53
+// CHECK-NEXT: OpBranch %switch_merge_7
+          case 53:
+            a = 53;
+            break;
+// CHECK-NEXT: %switch_54 = OpLabel
+// CHECK-NEXT: OpStore %a %int_54
+// CHECK-NEXT: OpBranch %switch_merge_7
+          case 54 : {
+            a = 54;
+            break;
+          }
+        }
+// CHECK-NEXT: %switch_merge_7 = OpLabel
+// CHECK-NEXT: OpBranch %switch_merge_6
+    }
+  }
+// CHECK-NEXT: %switch_merge_6 = OpLabel
+
+
+
+  ///////////////////////////////////////////////
+  // Constant integer variables as case values //
+  ///////////////////////////////////////////////
+
+  const int r = 35;
+  const int s = 45;
+  const int t = 2*r + s;  // evaluates to 115.
+
+// CHECK-NEXT: [[a8:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpSelectionMerge %switch_merge_8 None
+// CHECK-NEXT: OpSwitch [[a8]] %switch_merge_8 35 %switch_35 115 %switch_115
+  switch(a) {
+// CHECK-NEXT: %switch_35 = OpLabel
+// CHECK-NEXT: [[r:%\d+]] = OpLoad %int %r
+// CHECK-NEXT: OpStore %result [[r]]
+// CHECK-NEXT: OpBranch %switch_115
+    case r:
+      result = r;
+// CHECK-NEXT: %switch_115 = OpLabel
+// CHECK-NEXT: [[t:%\d+]] = OpLoad %int %t
+// CHECK-NEXT: OpStore %result [[t]]
+// CHECK-NEXT: OpBranch %switch_merge_8
+    case t:
+      result = t;
+      break;
+// CHECK-NEXT: %switch_merge_8 = OpLabel
+  }
+
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+}

+ 5 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -155,6 +155,11 @@ TEST_F(FileTest, CastSplatVector) { runFileTest("cast.vector.splat.hlsl"); }
 TEST_F(FileTest, IfStmtPlainAssign) { runFileTest("if-stmt.plain.hlsl"); }
 TEST_F(FileTest, IfStmtNestedIfStmt) { runFileTest("if-stmt.nested.hlsl"); }
 
+// For switch statements
+TEST_F(FileTest, SwitchStmtUsingOpSwitch) {
+  runFileTest("switch-stmt.opswitch.hlsl");
+}
+
 // For for statements
 TEST_F(FileTest, ForStmtPlainAssign) { runFileTest("for-stmt.plain.hlsl"); }
 TEST_F(FileTest, ForStmtNestedForStmt) { runFileTest("for-stmt.nested.hlsl"); }