Преглед изворни кода

[spirv] Translation of while loops (#532)

Also added the translation for `loop` and `unroll` attributes.

Similar to for loops, the current implementation does not support
early exits, or early returns.
Ehsan пре 8 година
родитељ
комит
ec9cfae3b6

+ 26 - 2
docs/SPIR-V.rst

@@ -353,7 +353,7 @@ Please note that "unlike short-circuit evaluation of ``&&``, ``||``, and ``?:``
 Unary operators
 +++++++++++++++
 
-FOr `unary operators <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509631(v=vs.85).aspx#Unary_Operators>`_:
+For `unary operators <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509631(v=vs.85).aspx#Unary_Operators>`_:
 
 - ``!`` is translated into ``OpLogicalNot``. Parsing will gurantee the operands are of boolean types by inserting necessary casts.
 - ``+`` requires no additional SPIR-V instructions.
@@ -384,7 +384,31 @@ The ``[]`` operator can also be used to access elements in a matrix or vector. A
 Control flows
 -------------
 
-[TODO]
+Switch Statements
++++++++++++++++++
+
+HLSL `switch statements <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509669(v=vs.85).aspx>`_ are translated into SPIR-V using:
+
+- **OpSwitch**: if (all case values are integer literals or constant integer variables) and (no attribute or the ``forcecase`` attribute is specified)
+- **A series of if statements**: for all other scenarios (e.g., when ``flatten``, ``branch``, or ``call`` attribute is specified)
+
+Loops
++++++
+
+HLSL `for statements <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509602(v=vs.85).aspx>`_ and `while statements <https://msdn.microsoft.com/en-us/library/windows/desktop/bb509708(v=vs.85).aspx>`_ are translated into SPIR-V by constructing all necessary basic blocks and using ``OpLoopMerge`` to organize as structured loops.
+The HLSL attributes for these statements are translated into SPIR-V loop control masks according to the following table:
+
++-------------------------+--------------------------------------------------+
+|   HLSL loop attribute   |            SPIR-V Loop Control Mask              |
++-------------------------+--------------------------------------------------+
+|        ``unroll(x)``    |                ``Unroll``                        |
++-------------------------+--------------------------------------------------+
+|         ``loop``        |              ``DontUnroll``                      |
++-------------------------+--------------------------------------------------+
+|        ``fastopt``      |              ``DontUnroll``                      |
++-------------------------+--------------------------------------------------+
+| ``allow_uav_condition`` |           Currently Unimplemented                |
++-------------------------+--------------------------------------------------+
 
 Functions
 ---------

+ 9 - 4
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -157,10 +157,15 @@ public:
   // will be created if mergeLabel is not 0 and continueLabel is 0.
   // An OpLoopMerge instruction will also be created if both continueLabel
   // and mergeLabel are not 0. For other cases, mergeLabel and continueLabel
-  // will be ignored.
-  void createConditionalBranch(uint32_t condition, uint32_t trueLabel,
-                               uint32_t falseLabel, uint32_t mergeLabel = 0,
-                               uint32_t continueLabel = 0);
+  // will be ignored. If selection control mask and/or loop control mask are
+  // provided, they will be applied to the corresponding SPIR-V instruction.
+  // Otherwise, MaskNone will be used.
+  void createConditionalBranch(
+      uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
+      uint32_t mergeLabel = 0, uint32_t continueLabel = 0,
+      spv::SelectionControlMask selectionControl =
+          spv::SelectionControlMask::MaskNone,
+      spv::LoopControlMask loopControl = spv::LoopControlMask::MaskNone);
 
   /// \brief Creates a return instruction.
   void createReturn();

+ 1 - 0
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -15,6 +15,7 @@
 #include "llvm/ADT/STLExtras.h"
 
 namespace clang {
+
 std::unique_ptr<ASTConsumer>
 EmitSPIRVAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
   return llvm::make_unique<spirv::SPIRVEmitter>(CI);

+ 7 - 12
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -241,24 +241,19 @@ void ModuleBuilder::createBranch(uint32_t targetLabel) {
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
-void ModuleBuilder::createConditionalBranch(uint32_t condition,
-                                            uint32_t trueLabel,
-                                            uint32_t falseLabel,
-                                            uint32_t mergeLabel,
-                                            uint32_t continueLabel) {
+void ModuleBuilder::createConditionalBranch(
+    uint32_t condition, uint32_t trueLabel, uint32_t falseLabel,
+    uint32_t mergeLabel, uint32_t continueLabel,
+    spv::SelectionControlMask selectionControl,
+    spv::LoopControlMask loopControl) {
   assert(insertPoint && "null insert point");
 
   if (mergeLabel) {
     if (continueLabel) {
-      instBuilder
-          .opLoopMerge(mergeLabel, continueLabel,
-                       spv::LoopControlMask::MaskNone)
-          .x();
+      instBuilder.opLoopMerge(mergeLabel, continueLabel, loopControl).x();
       insertPoint->appendInstruction(std::move(constructSite));
     } else {
-      instBuilder
-          .opSelectionMerge(mergeLabel, spv::SelectionControlMask::MaskNone)
-          .x();
+      instBuilder.opSelectionMerge(mergeLabel, selectionControl).x();
       insertPoint->appendInstruction(std::move(constructSite));
     }
   }

+ 118 - 3
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -223,8 +223,10 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
     processCaseStmtOrDefaultStmt(stmt);
   } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
     doBreakStmt(breakStmt);
+  } else if (const auto *whileStmt = dyn_cast<WhileStmt>(stmt)) {
+    doWhileStmt(whileStmt, attrs);
   } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
-    doForStmt(forStmt);
+    doForStmt(forStmt, attrs);
   } else if (const auto *nullStmt = dyn_cast<NullStmt>(stmt)) {
     // For the null statement ";". We don't need to do anything.
   } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
@@ -468,7 +470,115 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   }
 }
 
-void SPIRVEmitter::doForStmt(const ForStmt *forStmt) {
+spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {
+  switch (attr.getKind()) {
+  case attr::HLSLLoop:
+  case attr::HLSLFastOpt:
+    return spv::LoopControlMask::DontUnroll;
+  case attr::HLSLUnroll:
+    return spv::LoopControlMask::Unroll;
+  case attr::HLSLAllowUAVCondition:
+    emitWarning("Unsupported allow_uav_condition attribute ignored.");
+    break;
+  default:
+    emitError("Found unknown loop attribute.");
+  }
+  return spv::LoopControlMask::MaskNone;
+}
+
+void SPIRVEmitter::doWhileStmt(const WhileStmt *whileStmt,
+                               llvm::ArrayRef<const Attr *> attrs) {
+  // While loops are composed of:
+  //   while (<check>)  { <body> }
+  //
+  // SPIR-V requires loops to have a merge basic block as well as a continue
+  // basic block. Even though while loops do not have an explicit continue
+  // block as in for-loops, we still do need to create a continue block.
+  //
+  // Since SPIR-V requires structured control flow, we need two more basic
+  // blocks, <header> and <merge>. <header> is the block before control flow
+  // diverges, and <merge> is the block where control flow subsequently
+  // converges. The <check> block can take the responsibility of the <header>
+  // block. The final CFG should normally be like the following. Exceptions
+  // will occur with non-local exits like loop breaks or early returns.
+  //
+  //            +----------+
+  //            |  header  | <------------------+
+  //            | (check)  |                    |
+  //            +----------+                    |
+  //                 |                          |
+  //         +-------+-------+                  |
+  //         | false         | true             |
+  //         |               v                  |
+  //         |            +------+     +------------------+
+  //         |            | body | --> | continue (no-op) |
+  //         v            +------+     +------------------+
+  //     +-------+
+  //     | merge |
+  //     +-------+
+  //
+  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
+
+  const spv::LoopControlMask loopControl =
+      attrs.empty() ? spv::LoopControlMask::MaskNone
+                    : translateLoopAttribute(*attrs.front());
+
+  // Create basic blocks
+  const uint32_t checkBB = theBuilder.createBasicBlock("while.check");
+  const uint32_t bodyBB = theBuilder.createBasicBlock("while.body");
+  const uint32_t continueBB = theBuilder.createBasicBlock("while.continue");
+  const uint32_t mergeBB = theBuilder.createBasicBlock("while.merge");
+
+  // Process the <check> block
+  theBuilder.createBranch(checkBB);
+  theBuilder.addSuccessor(checkBB);
+  theBuilder.setInsertPoint(checkBB);
+
+  // If we have:
+  // while (int a = foo()) {...}
+  // we should evaluate 'a' by calling 'foo()' every single time the check has
+  // to occur.
+  if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
+    doStmt(condVarDecl);
+
+  uint32_t condition = 0;
+  if (const Expr *check = whileStmt->getCond()) {
+    condition = doExpr(check);
+  } else {
+    condition = theBuilder.getConstantBool(true);
+  }
+  theBuilder.createConditionalBranch(condition, bodyBB,
+                                     /*false branch*/ mergeBB,
+                                     /*merge*/ mergeBB, continueBB,
+                                     spv::SelectionControlMask::MaskNone,
+                                     loopControl);
+  theBuilder.addSuccessor(bodyBB);
+  theBuilder.addSuccessor(mergeBB);
+  // The current basic block has OpLoopMerge instruction. We need to set its
+  // continue and merge target.
+  theBuilder.setContinueTarget(continueBB);
+  theBuilder.setMergeTarget(mergeBB);
+
+  // Process the <body> block
+  theBuilder.setInsertPoint(bodyBB);
+  if (const Stmt *body = whileStmt->getBody()) {
+    doStmt(body);
+  }
+  theBuilder.createBranch(continueBB);
+  theBuilder.addSuccessor(continueBB);
+
+  // Process the <continue> block. While loops do not have an explicit
+  // continue block. The continue block just branches to the <check> block.
+  theBuilder.setInsertPoint(continueBB);
+  theBuilder.createBranch(checkBB);
+  theBuilder.addSuccessor(checkBB);
+
+  // Set insertion point to the <merge> block for subsequent statements
+  theBuilder.setInsertPoint(mergeBB);
+}
+
+void SPIRVEmitter::doForStmt(const ForStmt *forStmt,
+                             llvm::ArrayRef<const Attr *> attrs) {
   // for loops are composed of:
   //   for (<init>; <check>; <continue>) <body>
   //
@@ -502,6 +612,9 @@ void SPIRVEmitter::doForStmt(const ForStmt *forStmt) {
   //     +-------+
   //
   // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
+  const spv::LoopControlMask loopControl =
+      attrs.empty() ? spv::LoopControlMask::MaskNone
+                    : translateLoopAttribute(*attrs.front());
 
   // Create basic blocks
   const uint32_t checkBB = theBuilder.createBasicBlock("for.check");
@@ -526,7 +639,9 @@ void SPIRVEmitter::doForStmt(const ForStmt *forStmt) {
   }
   theBuilder.createConditionalBranch(condition, bodyBB,
                                      /*false branch*/ mergeBB,
-                                     /*merge*/ mergeBB, continueBB);
+                                     /*merge*/ mergeBB, continueBB,
+                                     spv::SelectionControlMask::MaskNone,
+                                     loopControl);
   theBuilder.addSuccessor(bodyBB);
   theBuilder.addSuccessor(mergeBB);
   // The current basic block has OpLoopMerge instruction. We need to set its

+ 6 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -72,7 +72,8 @@ private:
   void doVarDecl(const VarDecl *decl);
 
   void doBreakStmt(const BreakStmt *stmt);
-  void doForStmt(const ForStmt *forStmt);
+  void doWhileStmt(const WhileStmt *, llvm::ArrayRef<const Attr *> attrs = {});
+  void doForStmt(const ForStmt *, llvm::ArrayRef<const Attr *> attrs = {});
   void doIfStmt(const IfStmt *ifStmt);
   void doReturnStmt(const ReturnStmt *stmt);
   void doSwitchStmt(const SwitchStmt *stmt,
@@ -254,6 +255,10 @@ private:
   uint32_t translateAPFloat(const llvm::APFloat &floatValue,
                             QualType targetType);
 
+  /// Translates the given HLSL loop attribute into SPIR-V loop control mask.
+  /// Emits an error if the given attribute is not a loop attribute.
+  spv::LoopControlMask translateLoopAttribute(const Attr &);
+
 private:
   static spv::ExecutionModel
   getSpirvShaderStage(const hlsl::ShaderModel &model);

+ 6 - 6
tools/clang/test/CodeGenSPIRV/for-stmt.nested.hlsl

@@ -13,9 +13,9 @@ void main() {
 // CHECK-LABEL: %for_check = OpLabel
 // CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
 // CHECK-NEXT: [[lt0:%\d+]] = OpSLessThan %bool [[i0]] %int_10
-// CHECK-NEXT: OpLoopMerge %for_merge %for_continue None
+// CHECK-NEXT: OpLoopMerge %for_merge %for_continue Unroll
 // CHECK-NEXT: OpBranchConditional [[lt0]] %for_body %for_merge
-    for (int i = 0; i < 10; ++i) {
+    [unroll] for (int i = 0; i < 10; ++i) {
 // CHECK-LABEL: %for_body = OpLabel
 // CHECK-NEXT: [[val0:%\d+]] = OpLoad %int %val
 // CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
@@ -27,18 +27,18 @@ void main() {
 // CHECK-LABEL: %for_check_0 = OpLabel
 // CHECK-NEXT: [[j0:%\d+]] = OpLoad %int %j
 // CHECK-NEXT: [[lt1:%\d+]] = OpSLessThan %bool [[j0]] %int_10
-// CHECK-NEXT: OpLoopMerge %for_merge_0 %for_continue_0 None
+// CHECK-NEXT: OpLoopMerge %for_merge_0 %for_continue_0 DontUnroll
 // CHECK-NEXT: OpBranchConditional [[lt1]] %for_body_0 %for_merge_0
-        for (int j = 0; j < 10; ++j) {
+        [loop] for (int j = 0; j < 10; ++j) {
 // CHECK-LABEL: %for_body_0 = OpLabel
 // CHECK-NEXT: OpBranch %for_check_1
 
 // CHECK-LABEL: %for_check_1 = OpLabel
 // CHECK-NEXT: [[k0:%\d+]] = OpLoad %int %k
 // CHECK-NEXT: [[lt2:%\d+]] = OpSLessThan %bool [[k0]] %int_10
-// CHECK-NEXT: OpLoopMerge %for_merge_1 %for_continue_1 None
+// CHECK-NEXT: OpLoopMerge %for_merge_1 %for_continue_1 DontUnroll
 // CHECK-NEXT: OpBranchConditional [[lt2]] %for_body_1 %for_merge_1
-            for (int k = 0; k < 10; ++k) {
+            [fastopt] for (int k = 0; k < 10; ++k) {
 // CHECK-LABEL: %for_body_1 = OpLabel
 // CHECK-NEXT: [[val1:%\d+]] = OpLoad %int %val
 // CHECK-NEXT: [[k1:%\d+]] = OpLoad %int %k

+ 78 - 0
tools/clang/test/CodeGenSPIRV/while-stmt.nested.hlsl

@@ -0,0 +1,78 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+  int val=0, i=0, j=0, k=0;
+
+// CHECK:      OpBranch %while_check
+// CHECK-NEXT: %while_check = OpLabel
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[i_lt_10:%\d+]] = OpSLessThan %bool [[i0]] %int_10
+// CHECK-NEXT: OpLoopMerge %while_merge %while_continue DontUnroll
+// CHECK-NEXT: OpBranchConditional [[i_lt_10]] %while_body %while_merge
+  [loop] while (i < 10) {
+// CHECK-NEXT: %while_body = OpLabel
+// CHECK-NEXT: [[val1:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[val_plus_i:%\d+]] = OpIAdd %int [[val1]] [[i1]]
+// CHECK-NEXT: OpStore %val [[val_plus_i]]
+// CHECK-NEXT: OpBranch %while_check_0
+    val = val + i;
+// CHECK-NEXT: %while_check_0 = OpLabel
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[j_lt_20:%\d+]] = OpSLessThan %bool [[j0]] %int_20
+// CHECK-NEXT: OpLoopMerge %while_merge_0 %while_continue_0 Unroll
+// CHECK-NEXT: OpBranchConditional [[j_lt_20]] %while_body_0 %while_merge_0
+    [unroll(20)] while (j < 20) {
+// CHECK-NEXT: %while_body_0 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_1
+
+// CHECK-NEXT: %while_check_1 = OpLabel
+// CHECK-NEXT: [[k0:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[k_lt_30:%\d+]] = OpSLessThan %bool [[k0]] %int_30
+// CHECK-NEXT: OpLoopMerge %while_merge_1 %while_continue_1 DontUnroll
+// CHECK-NEXT: OpBranchConditional [[k_lt_30]] %while_body_1 %while_merge_1
+      [fastopt] while (k < 30) {
+// CHECK-NEXT: %while_body_1 = OpLabel
+// CHECK-NEXT: [[val2:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[k2:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[val_plus_k:%\d+]] = OpIAdd %int [[val2]] [[k2]]
+// CHECK-NEXT: OpStore %val [[val_plus_k]]
+        val = val + k;
+// CHECK-NEXT: [[k3:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[k_plus_1:%\d+]] = OpIAdd %int [[k3]] %int_1
+// CHECK-NEXT: OpStore %k [[k_plus_1]]
+        ++k;
+// CHECK-NEXT: OpBranch %while_continue_1
+// CHECK-NEXT: %while_continue_1 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_1
+      }
+// CHECK-NEXT: %while_merge_1 = OpLabel
+
+// CHECK-NEXT: [[val3:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[val_mult_2:%\d+]] = OpIMul %int [[val3]] %int_2
+// CHECK-NEXT: OpStore %val [[val_mult_2]]
+      val = val * 2;
+// CHECK-NEXT: [[j1:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[j_plus_1:%\d+]] = OpIAdd %int [[j1]] %int_1
+// CHECK-NEXT: OpStore %j [[j_plus_1]]
+      ++j;
+// CHECK-NEXT: OpBranch %while_continue_0
+// CHECK-NEXT: %while_continue_0 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_0
+    }
+// CHECK-NEXT: %while_merge_0 = OpLabel
+
+// CHECK-NEXT: [[i2:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[i_plus_1:%\d+]] = OpIAdd %int [[i2]] %int_1
+// CHECK-NEXT: OpStore %i [[i_plus_1]]
+    ++i;
+// CHECK-NEXT: OpBranch %while_continue
+// CHECK-NEXT: %while_continue = OpLabel
+// CHECK-NEXT: OpBranch %while_check
+  }
+// CHECK-NEXT: %while_merge = OpLabel
+
+
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+}

+ 100 - 0
tools/clang/test/CodeGenSPIRV/while-stmt.plain.hlsl

@@ -0,0 +1,100 @@
+// Run: %dxc -T ps_6_0 -E main
+
+int foo() { return true; }
+
+void main() {
+  int val = 0;
+  int i = 0;
+
+    //////////////////////////
+    //// Basic while loop ////
+    //////////////////////////
+
+// CHECK:      OpBranch %while_check
+// CHECK-NEXT: %while_check = OpLabel
+
+// CHECK-NEXT: [[i:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[i_lt_10:%\d+]] = OpSLessThan %bool [[i]] %int_10
+// CHECK-NEXT: OpLoopMerge %while_merge %while_continue None
+// CHECK-NEXT: OpBranchConditional [[i_lt_10]] %while_body %while_merge
+  while (i < 10) {
+// CHECK-NEXT: %while_body = OpLabel
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: OpStore %val [[i1]]
+      val = i;
+// CHECK-NEXT: OpBranch %while_continue
+// CHECK-NEXT: %while_continue = OpLabel
+// CHECK-NEXT: OpBranch %while_check
+  }
+// CHECK-NEXT: %while_merge = OpLabel
+
+
+
+    //////////////////////////
+    ////  infinite loop   ////
+    //////////////////////////
+
+// CHECK-NEXT: OpBranch %while_check_0
+// CHECK-NEXT: %while_check_0 = OpLabel
+// CHECK-NEXT: OpLoopMerge %while_merge_0 %while_continue_0 None
+// CHECK-NEXT: OpBranchConditional %true %while_body_0 %while_merge_0
+  while (true) {
+// CHECK-NEXT: %while_body_0 = OpLabel
+// CHECK-NEXT: OpStore %val %int_0
+      val = 0;
+// CHECK-NEXT: OpBranch %while_continue_0
+// CHECK-NEXT: %while_continue_0 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_0
+  }
+// CHECK-NEXT: %while_merge_0 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_1
+
+
+
+    //////////////////////////
+    ////    Null Body     ////
+    //////////////////////////
+
+// CHECK-NEXT: %while_check_1 = OpLabel
+// CHECK-NEXT: [[val1:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[val_lt_20:%\d+]] = OpSLessThan %bool [[val1]] %int_20
+// CHECK-NEXT: OpLoopMerge %while_merge_1 %while_continue_1 None
+// CHECK-NEXT: OpBranchConditional [[val_lt_20]] %while_body_1 %while_merge_1
+  while (val < 20)
+// CHECK-NEXT: %while_body_1 = OpLabel
+// CHECK-NEXT: OpBranch %while_continue_1
+// CHECK-NEXT: %while_continue_1 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_1
+    ;
+// CHECK-NEXT: %while_merge_1 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_2
+
+
+
+    ////////////////////////////////////////////////////////////////
+    //// Condition variable has VarDecl                         ////
+    //// foo() returns an integer which must be cast to boolean ////
+    ////////////////////////////////////////////////////////////////
+
+// CHECK-NEXT: %while_check_2 = OpLabel
+// CHECK-NEXT: [[foo:%\d+]] = OpFunctionCall %int %foo
+// CHECK-NEXT: OpStore %a [[foo]]
+// CHECK-NEXT: [[a:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[is_a_true:%\d+]] = OpINotEqual %bool [[a]] %int_0
+// CHECK-NEXT: OpLoopMerge %while_merge_2 %while_continue_2 None
+// CHECK-NEXT: OpBranchConditional [[is_a_true]] %while_body_2 %while_merge_2
+  while (int a = foo()) {
+// CHECK-NEXT: %while_body_2 = OpLabel
+// CHECK-NEXT: [[a1:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: OpStore %val [[a1]]
+    val = a;
+// CHECK-NEXT: OpBranch %while_continue_2
+// CHECK-NEXT: %while_continue_2 = OpLabel
+// CHECK-NEXT: OpBranch %while_check_2
+  }
+// CHECK-NEXT: %while_merge_2 = OpLabel
+
+
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -204,6 +204,10 @@ TEST_F(FileTest, SwitchStmtUsingIfStmt) {
 TEST_F(FileTest, ForStmtPlainAssign) { runFileTest("for-stmt.plain.hlsl"); }
 TEST_F(FileTest, ForStmtNestedForStmt) { runFileTest("for-stmt.nested.hlsl"); }
 
+// For while statements
+TEST_F(FileTest, WhileStmtPlain) { runFileTest("while-stmt.plain.hlsl"); }
+TEST_F(FileTest, WhileStmtNested) { runFileTest("while-stmt.nested.hlsl"); }
+
 // For control flows
 TEST_F(FileTest, ControlFlowNestedIfForStmt) { runFileTest("cf.if.for.hlsl"); }
 TEST_F(FileTest, ControlFlowLogicalAnd) { runFileTest("cf.logical-and.hlsl"); }