Selaa lähdekoodia

Handle structured conditions in for and while loops (#4783)

Daniele Vettorel 2 vuotta sitten
vanhempi
commit
4349c3e1f7

+ 250 - 71
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2000,14 +2000,50 @@ void SpirvEmitter::doWhileStmt(const WhileStmt *whileStmt,
   //     | merge |
   //     +-------+
   //
+  // The only exception is when the condition cannot be expressed in a single
+  // block. Specifically, short-circuited operators end up producing multiple
+  // blocks. In that case, we cannot treat the <check> block as the header
+  // block, and must instead have a bespoke <header> block. The condition is
+  // then moved into the loop. For example, given a loop in the form
+  //   while (a && b) { <body> }
+  // we will generate instructions for the equivalent loop
+  //   while (true) { if (!(a && b)) { break }  <body> }
+  //            +----------+
+  //            |  header  | <------------------+
+  //            +----------+                    |
+  //                 |                          |
+  //                 v                          |
+  //            +----------+                    |
+  //            |  check   |                    |
+  //            +----------+                    |
+  //                 |                          |
+  //         +-------+-------+                  |
+  //         | false         | true             |
+  //         |               v                  |
+  //         |            +------+     +------------------+
+  //         |            | body | --> | continue (no-op) |
+  //         v            +------+     +------------------+
+  //     +-------+
+  //     | merge |
+  //     +-------+
+  // The reason we don't unconditionally apply this transformation, which is
+  // technically always legal, is because it prevents loop unrolling in SPIR-V
+  // Tools, which does not support unrolling loops with early breaks.
   // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
 
   const spv::LoopControlMask loopControl =
       attrs.empty() ? spv::LoopControlMask::MaskNone
                     : translateLoopAttribute(whileStmt, *attrs.front());
 
+  const Expr *check = whileStmt->getCond();
+  const Stmt *body = whileStmt->getBody();
+  bool checkHasShortcircuitedOp = stmtTreeContainsShortCircuitedOp(check);
+
   // Create basic blocks
   auto *checkBB = spvBuilder.createBasicBlock("while.check");
+  auto *headerBB = checkHasShortcircuitedOp
+                       ? spvBuilder.createBasicBlock("while.header")
+                       : checkBB;
   auto *bodyBB = spvBuilder.createBasicBlock("while.body");
   auto *continueBB = spvBuilder.createBasicBlock("while.continue");
   auto *mergeBB = spvBuilder.createBasicBlock("while.merge");
@@ -2017,42 +2053,80 @@ void SpirvEmitter::doWhileStmt(const WhileStmt *whileStmt,
   continueStack.push(continueBB);
   breakStack.push(mergeBB);
 
-  // Process the <check> block
-  spvBuilder.createBranch(checkBB, whileStmt->getLocStart());
-  spvBuilder.addSuccessor(checkBB);
-  spvBuilder.setInsertPoint(checkBB);
-
-  // If we have:
-  // while (int a = foo()) {...}
-  // we should evaluate 'a' by calling 'foo()' every single time the check has
-  // to occur.
-  if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
-    doStmt(condVarDecl);
+  spvBuilder.createBranch(headerBB, whileStmt->getLocStart());
+  spvBuilder.addSuccessor(headerBB);
+  spvBuilder.setInsertPoint(headerBB);
+  if (checkHasShortcircuitedOp) {
+    // Process the <header> block.
+    spvBuilder.setInsertPoint(headerBB);
+    spvBuilder.createBranch(
+        checkBB,
+        check ? check->getLocStart()
+              : (body ? body->getLocStart() : whileStmt->getLocStart()),
+        mergeBB, continueBB, loopControl,
+        check
+            ? check->getSourceRange()
+            : SourceRange(whileStmt->getLocStart(), whileStmt->getLocStart()));
+    spvBuilder.addSuccessor(checkBB);
+    // The current basic block has a OpLoopMerge instruction. We need to set
+    // its continue and merge target.
+    spvBuilder.setContinueTarget(continueBB);
+    spvBuilder.setMergeTarget(mergeBB);
 
-  SpirvInstruction *condition = nullptr;
-  const Expr *check = whileStmt->getCond();
-  if (check) {
-    condition = doExpr(check);
+    // Process the <check> block.
+    spvBuilder.setInsertPoint(checkBB);
+
+    // If we have:
+    //   while (int a = foo()) {...}
+    // we should evaluate 'a' by calling 'foo()' every single time the check has
+    // to occur.
+    if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
+      doStmt(condVarDecl);
+
+    SpirvInstruction *condition = doExpr(check);
+    spvBuilder.createConditionalBranch(
+        condition, bodyBB, mergeBB,
+        check ? check->getLocEnd()
+              : (body ? body->getLocStart() : whileStmt->getLocStart()),
+        nullptr, nullptr, spv::SelectionControlMask::MaskNone,
+        spv::LoopControlMask::MaskNone,
+        check
+            ? check->getSourceRange()
+            : SourceRange(whileStmt->getLocStart(), whileStmt->getLocStart()));
+    spvBuilder.addSuccessor(bodyBB);
+    spvBuilder.addSuccessor(mergeBB);
   } else {
-    condition = spvBuilder.getConstantBool(true);
+    // In the case of simple or empty conditions, we can use a
+    // single block for <check> and <header>.
+
+    // If we have:
+    //   while (int a = foo()) {...}
+    // we should evaluate 'a' by calling 'foo()' every single time the check has
+    // to occur.
+    if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
+      doStmt(condVarDecl);
+
+    SpirvInstruction *condition = nullptr;
+    if (check) {
+      condition = doExpr(check);
+    } else {
+      condition = spvBuilder.getConstantBool(true);
+    }
+    spvBuilder.createConditionalBranch(
+        condition, bodyBB, mergeBB, whileStmt->getLocStart(), mergeBB,
+        continueBB, spv::SelectionControlMask::MaskNone, loopControl,
+        check ? check->getSourceRange()
+              : SourceRange(whileStmt->getWhileLoc(), whileStmt->getLocEnd()));
+    spvBuilder.addSuccessor(bodyBB);
+    spvBuilder.addSuccessor(mergeBB);
+    // The current basic block has OpLoopMerge instruction. We need to set its
+    // continue and merge target.
+    spvBuilder.setContinueTarget(continueBB);
+    spvBuilder.setMergeTarget(mergeBB);
   }
-  spvBuilder.createConditionalBranch(
-      condition, bodyBB,
-      /*false branch*/ mergeBB, whileStmt->getLocStart(),
-      /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone,
-      loopControl,
-      check ? check->getSourceRange()
-            : SourceRange(whileStmt->getWhileLoc(), whileStmt->getLocEnd()));
-  spvBuilder.addSuccessor(bodyBB);
-  spvBuilder.addSuccessor(mergeBB);
-  // The current basic block has OpLoopMerge instruction. We need to set its
-  // continue and merge target.
-  spvBuilder.setContinueTarget(continueBB);
-  spvBuilder.setMergeTarget(mergeBB);
 
-  // Process the <body> block
+  // Process the <body> block.
   spvBuilder.setInsertPoint(bodyBB);
-  const Stmt *body = whileStmt->getBody();
   if (body) {
     doStmt(body);
   }
@@ -2061,12 +2135,12 @@ void SpirvEmitter::doWhileStmt(const WhileStmt *whileStmt,
   spvBuilder.addSuccessor(continueBB);
 
   // Process the <continue> block. While loops do not have an explicit
-  // continue block. The continue block just branches to the <check> block.
+  // continue block. The continue block just branches to the <header> block.
   spvBuilder.setInsertPoint(continueBB);
-  spvBuilder.createBranch(checkBB, whileStmt->getLocEnd());
-  spvBuilder.addSuccessor(checkBB);
+  spvBuilder.createBranch(headerBB, whileStmt->getLocEnd());
+  spvBuilder.addSuccessor(headerBB);
 
-  // Set insertion point to the <merge> block for subsequent statements
+  // Set insertion point to the <merge> block for subsequent statements.
   spvBuilder.setInsertPoint(mergeBB);
 
   // Done with the current scope's continue and merge blocks.
@@ -2108,13 +2182,56 @@ void SpirvEmitter::doForStmt(const ForStmt *forStmt,
   //     | merge |
   //     +-------+
   //
+  // The only exception is when the condition cannot be expressed in a single
+  // block. Specifically, short-circuited operators end up producing multiple
+  // blocks. In that case, we cannot treat the <check> block as the header
+  // block, and must instead have a bespoke <header> block. The condition is
+  // then moved into the loop. For example, given a loop in the form
+  //   for (<init>; a && b; <continue>) { <body> }
+  // we will generate instructions for the equivalent loop
+  //   for (<init>; ; <continue>) { if (!(a && b)) { break }  <body> }
+  //             +--------+
+  //             |  init  |
+  //             +--------+
+  //                 |
+  //                 v
+  //            +----------+
+  //            |  header  | <---------------+
+  //            +----------+                 |
+  //                 |                       |
+  //                 v                       |
+  //            +----------+                 |
+  //            |  check   |                 |
+  //            +----------+                 |
+  //                 |                       |
+  //         +-------+-------+               |
+  //         | false         | true          |
+  //         |               v               |
+  //         |            +------+     +----------+
+  //         |            | body | --> | continue |
+  //         v            +------+     +----------+
+  //     +-------+
+  //     | merge |
+  //     +-------+
+  // The reason we don't unconditionally apply this transformation, which is
+  // technically always legal, is because it prevents loop unrolling in SPIR-V
+  // Tools, which does not support unrolling loops with early breaks.
   // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
   const spv::LoopControlMask loopControl =
       attrs.empty() ? spv::LoopControlMask::MaskNone
                     : translateLoopAttribute(forStmt, *attrs.front());
 
-  // Create basic blocks
+  const Stmt *initStmt = forStmt->getInit();
+  const Stmt *body = forStmt->getBody();
+  const Expr *check = forStmt->getCond();
+
+  bool checkHasShortcircuitedOp = stmtTreeContainsShortCircuitedOp(check);
+
+  // Create basic blocks.
   auto *checkBB = spvBuilder.createBasicBlock("for.check");
+  auto *headerBB = checkHasShortcircuitedOp
+                       ? spvBuilder.createBasicBlock("for.header")
+                       : checkBB;
   auto *bodyBB = spvBuilder.createBasicBlock("for.body");
   auto *continueBB = spvBuilder.createBasicBlock("for.continue");
   auto *mergeBB = spvBuilder.createBasicBlock("for.merge");
@@ -2124,47 +2241,78 @@ void SpirvEmitter::doForStmt(const ForStmt *forStmt,
   continueStack.push(continueBB);
   breakStack.push(mergeBB);
 
-  // Process the <init> block
-  const Stmt *initStmt = forStmt->getInit();
+  // Process the <init> block.
   if (initStmt) {
     doStmt(initStmt);
   }
-  const Expr *check = forStmt->getCond();
   spvBuilder.createBranch(
-      checkBB, check ? check->getLocStart() : forStmt->getLocStart(), nullptr,
+      headerBB, check ? check->getLocStart() : forStmt->getLocStart(), nullptr,
       nullptr, spv::LoopControlMask::MaskNone,
       initStmt ? initStmt->getSourceRange()
                : SourceRange(forStmt->getLocStart(), forStmt->getLocStart()));
-  spvBuilder.addSuccessor(checkBB);
+  spvBuilder.addSuccessor(headerBB);
 
-  // Process the <check> block
-  spvBuilder.setInsertPoint(checkBB);
-  SpirvInstruction *condition = nullptr;
-  if (check) {
-    condition = doExpr(check);
+  if (checkHasShortcircuitedOp) {
+    // Process the <header> block.
+    spvBuilder.setInsertPoint(headerBB);
+    spvBuilder.createBranch(
+        checkBB,
+        check ? check->getLocStart()
+              : (body ? body->getLocStart() : forStmt->getLocStart()),
+        mergeBB, continueBB, loopControl,
+        check ? check->getSourceRange()
+              : (initStmt ? initStmt->getSourceRange()
+                          : SourceRange(forStmt->getLocStart(),
+                                        forStmt->getLocStart())));
+    spvBuilder.addSuccessor(checkBB);
+    // The current basic block has a OpLoopMerge instruction. We need to set
+    // its continue and merge target.
+    spvBuilder.setContinueTarget(continueBB);
+    spvBuilder.setMergeTarget(mergeBB);
+
+    // Process the <check> block.
+    spvBuilder.setInsertPoint(checkBB);
+    SpirvInstruction *condition = doExpr(check);
+    spvBuilder.createConditionalBranch(
+        condition, bodyBB, mergeBB,
+        check ? check->getLocEnd()
+              : (body ? body->getLocStart() : forStmt->getLocStart()),
+        nullptr, nullptr, spv::SelectionControlMask::MaskNone,
+        spv::LoopControlMask::MaskNone,
+        check ? check->getSourceRange()
+              : (initStmt ? initStmt->getSourceRange()
+                          : SourceRange(forStmt->getLocStart(),
+                                        forStmt->getLocStart())));
+    spvBuilder.addSuccessor(bodyBB);
+    spvBuilder.addSuccessor(mergeBB);
   } else {
-    condition = spvBuilder.getConstantBool(true);
+    // In the case of simple or empty conditions, we can use a
+    // single block for <check> and <header>.
+    spvBuilder.setInsertPoint(checkBB);
+    SpirvInstruction *condition = nullptr;
+    if (check) {
+      condition = doExpr(check);
+    } else {
+      condition = spvBuilder.getConstantBool(true);
+    }
+    spvBuilder.createConditionalBranch(
+        condition, bodyBB, mergeBB,
+        check ? check->getLocEnd()
+              : (body ? body->getLocStart() : forStmt->getLocStart()),
+        mergeBB, continueBB, spv::SelectionControlMask::MaskNone, loopControl,
+        check ? check->getSourceRange()
+              : (initStmt ? initStmt->getSourceRange()
+                          : SourceRange(forStmt->getLocStart(),
+                                        forStmt->getLocStart())));
+    spvBuilder.addSuccessor(bodyBB);
+    spvBuilder.addSuccessor(mergeBB);
+    // The current basic block has a OpLoopMerge instruction. We need to set
+    // its continue and merge target.
+    spvBuilder.setContinueTarget(continueBB);
+    spvBuilder.setMergeTarget(mergeBB);
   }
-  const Stmt *body = forStmt->getBody();
-  spvBuilder.createConditionalBranch(
-      condition, bodyBB,
-      /*false branch*/ mergeBB,
-      check ? check->getLocEnd()
-            : (body ? body->getLocStart() : forStmt->getLocStart()),
-      /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone,
-      loopControl,
-      check ? check->getSourceRange()
-            : (initStmt ? initStmt->getSourceRange()
-                        : SourceRange(forStmt->getLocStart(),
-                                      forStmt->getLocStart())));
-  spvBuilder.addSuccessor(bodyBB);
-  spvBuilder.addSuccessor(mergeBB);
-  // The current basic block has OpLoopMerge instruction. We need to set its
-  // continue and merge target.
-  spvBuilder.setContinueTarget(continueBB);
-  spvBuilder.setMergeTarget(mergeBB);
 
-  // Process the <body> block
+  // Process the <body> block.
   spvBuilder.setInsertPoint(bodyBB);
   if (body) {
     doStmt(body);
@@ -2178,20 +2326,19 @@ void SpirvEmitter::doForStmt(const ForStmt *forStmt,
              : SourceRange(forStmt->getLocStart(), forStmt->getLocStart()));
   spvBuilder.addSuccessor(continueBB);
 
-  // Process the <continue> block
+  // Process the <continue> block. It will jump back to the header.
   spvBuilder.setInsertPoint(continueBB);
   if (cont) {
     doExpr(cont);
   }
-  // <continue> should jump back to header
   spvBuilder.createBranch(
-      checkBB, forStmt->getLocEnd(), nullptr, nullptr,
+      headerBB, forStmt->getLocEnd(), nullptr, nullptr,
       spv::LoopControlMask::MaskNone,
       cont ? cont->getSourceRange()
            : SourceRange(forStmt->getLocStart(), forStmt->getLocStart()));
-  spvBuilder.addSuccessor(checkBB);
+  spvBuilder.addSuccessor(headerBB);
 
-  // Set insertion point to the <merge> block for subsequent statements
+  // Set insertion point to the <merge> block for subsequent statements.
   spvBuilder.setInsertPoint(mergeBB);
 
   // Done with the current scope's continue block and merge block.
@@ -6479,6 +6626,38 @@ bool SpirvEmitter::isVectorShuffle(const Expr *expr) {
   return false;
 }
 
+bool SpirvEmitter::isShortCircuitedOp(const Expr *expr) {
+  if (!expr || !getCompilerInstance().getLangOpts().EnableShortCircuit) {
+    return false;
+  }
+
+  const auto *binOp = dyn_cast<BinaryOperator>(expr->IgnoreParens());
+  if (binOp) {
+    return binOp->getOpcode() == BO_LAnd || binOp->getOpcode() == BO_LOr;
+  }
+
+  const auto *condOp = dyn_cast<ConditionalOperator>(expr->IgnoreParens());
+  return condOp;
+}
+
+bool SpirvEmitter::stmtTreeContainsShortCircuitedOp(const Stmt *stmt) {
+  if (!stmt) {
+    return false;
+  }
+
+  if (isShortCircuitedOp(dyn_cast<Expr>(stmt))) {
+    return true;
+  }
+
+  for (const auto *child : stmt->children()) {
+    if (stmtTreeContainsShortCircuitedOp(child)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 bool SpirvEmitter::isTextureMipsSampleIndexing(const CXXOperatorCallExpr *expr,
                                                const Expr **base,
                                                const Expr **location,

+ 7 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -226,6 +226,13 @@ private:
   ///   the original vector, no shuffling needed).
   bool isVectorShuffle(const Expr *expr);
 
+  /// Returns true if the given expression is a short-circuited operator.
+  bool isShortCircuitedOp(const Expr *expr);
+
+  /// Returns true if the given statement or any of its children are a
+  /// short-circuited operator.
+  bool stmtTreeContainsShortCircuitedOp(const Stmt *stmt);
+
   /// \brief Returns true if the given CXXOperatorCallExpr is indexing into a
   /// Buffer/RWBuffer/Texture/RWTexture using operator[].
   /// On success, writes the base buffer into *base if base is not nullptr, and

+ 75 - 0
tools/clang/test/CodeGenSPIRV/cf.for.short-circuited-cond.hlsl

@@ -0,0 +1,75 @@
+// RUN: %dxc -T ps_6_0 -E main -HV 2021
+
+void main() {
+  bool a, b;
+  // CHECK:      OpBranch %for_header
+  // CHECK-NEXT: %for_header = OpLabel
+  // CHECK-NEXT: OpLoopMerge %for_merge %for_continue None
+  // CHECK-NEXT: OpBranch %for_check
+  // CHECK-NEXT: %for_check = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %for_body %for_merge
+  for (int i = 0; a && b; ++i) {
+    // CHECK-NEXT: %for_body = OpLabel
+    // CHECK-NEXT: OpBranch %for_continue
+    // CHECK-NEXT: %for_continue = OpLabel
+    // CHECK:      OpBranch %for_header
+  }
+  // CHECK-NEXT: %for_merge = OpLabel
+
+  // CHECK:      OpBranch %for_header_0
+  // CHECK-NEXT: %for_header_0 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %for_merge_0 %for_continue_0 None
+  // CHECK-NEXT: OpBranch %for_check_0
+  // CHECK-NEXT: %for_check_0 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %for_body_0 %for_merge_0
+  for (int i = 0; a || b; ++i) {
+    // CHECK-NEXT: %for_body_0 = OpLabel
+    // CHECK-NEXT: OpBranch %for_continue_0
+    // CHECK-NEXT: %for_continue_0 = OpLabel
+    // CHECK:      OpBranch %for_header_0
+  }
+  // CHECK-NEXT: %for_merge_0 = OpLabel
+
+  // CHECK:      OpBranch %for_header_1
+  // CHECK-NEXT: %for_header_1 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %for_merge_1 %for_continue_1 None
+  // CHECK-NEXT: OpBranch %for_check_1
+  // CHECK-NEXT: %for_check_1 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %for_body_1 %for_merge_1
+  for (int i = 0; a && ((a || b) && b); ++i) {
+    // CHECK-NEXT: %for_body_1 = OpLabel
+    // CHECK-NEXT: OpBranch %for_continue_1
+    // CHECK-NEXT: %for_continue_1 = OpLabel
+    // CHECK:      OpBranch %for_header_1
+  }
+  // CHECK-NEXT: %for_merge_1 = OpLabel
+
+  // CHECK:      OpBranch %for_header_2
+  // CHECK-NEXT: %for_header_2 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %for_merge_2 %for_continue_2 None
+  // CHECK-NEXT: OpBranch %for_check_2
+  // CHECK-NEXT: %for_check_2 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %for_body_2 %for_merge_2
+  for (int i = 0; a ? a : b; ++i) {
+    // CHECK-NEXT: %for_body_2 = OpLabel
+    // CHECK-NEXT: OpBranch %for_continue_2
+    // CHECK-NEXT: %for_continue_2 = OpLabel
+    // CHECK:      OpBranch %for_header_2
+  }
+  // CHECK-NEXT: %for_merge_2 = OpLabel
+
+  int x, y;
+  // CHECK:      OpBranch %for_header_3
+  // CHECK-NEXT: %for_header_3 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %for_merge_3 %for_continue_3 None
+  // CHECK-NEXT: OpBranch %for_check_3
+  // CHECK-NEXT: %for_check_3 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %for_body_3 %for_merge_3
+  for (int i = 0; x + (x && y); ++i) {
+    // CHECK-NEXT: %for_body_3 = OpLabel
+    // CHECK-NEXT: OpBranch %for_continue_3
+    // CHECK-NEXT: %for_continue_3 = OpLabel
+    // CHECK:      OpBranch %for_header_3
+  }
+  // CHECK-NEXT: %for_merge_3 = OpLabel
+}

+ 75 - 0
tools/clang/test/CodeGenSPIRV/cf.while.short-circuited-cond.hlsl

@@ -0,0 +1,75 @@
+// RUN: %dxc -T ps_6_0 -E main -HV 2021
+
+void main() {
+  bool a, b;
+  // CHECK:      OpBranch %while_header
+  // CHECK-NEXT: %while_header = OpLabel
+  // CHECK-NEXT: OpLoopMerge %while_merge %while_continue None
+  // CHECK-NEXT: OpBranch %while_check
+  // CHECK-NEXT: %while_check = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %while_body %while_merge
+  while (a && b) {
+    // CHECK-NEXT: %while_body = OpLabel
+    // CHECK-NEXT: OpBranch %while_continue
+    // CHECK-NEXT: %while_continue = OpLabel
+    // CHECK:      OpBranch %while_header
+  }
+  // CHECK-NEXT: %while_merge = OpLabel
+
+  // CHECK:      OpBranch %while_header_0
+  // CHECK-NEXT: %while_header_0 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %while_merge_0 %while_continue_0 None
+  // CHECK-NEXT: OpBranch %while_check_0
+  // CHECK-NEXT: %while_check_0 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %while_body_0 %while_merge_0
+  while (a || b) {
+    // CHECK-NEXT: %while_body_0 = OpLabel
+    // CHECK-NEXT: OpBranch %while_continue_0
+    // CHECK-NEXT: %while_continue_0 = OpLabel
+    // CHECK:      OpBranch %while_header_0
+  }
+  // CHECK-NEXT: %while_merge_0 = OpLabel
+
+  // CHECK:      OpBranch %while_header_1
+  // CHECK-NEXT: %while_header_1 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %while_merge_1 %while_continue_1 None
+  // CHECK-NEXT: OpBranch %while_check_1
+  // CHECK-NEXT: %while_check_1 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %while_body_1 %while_merge_1
+  while (a && ((a || b) && b)) {
+    // CHECK-NEXT: %while_body_1 = OpLabel
+    // CHECK-NEXT: OpBranch %while_continue_1
+    // CHECK-NEXT: %while_continue_1 = OpLabel
+    // CHECK:      OpBranch %while_header_1
+  }
+  // CHECK-NEXT: %while_merge_1 = OpLabel
+
+  // CHECK:      OpBranch %while_header_2
+  // CHECK-NEXT: %while_header_2 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %while_merge_2 %while_continue_2 None
+  // CHECK-NEXT: OpBranch %while_check_2
+  // CHECK-NEXT: %while_check_2 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %while_body_2 %while_merge_2
+  while (a ? a : b) {
+    // CHECK-NEXT: %while_body_2 = OpLabel
+    // CHECK-NEXT: OpBranch %while_continue_2
+    // CHECK-NEXT: %while_continue_2 = OpLabel
+    // CHECK:      OpBranch %while_header_2
+  }
+  // CHECK-NEXT: %while_merge_2 = OpLabel
+
+  int x, y;
+  // CHECK:      OpBranch %while_header_3
+  // CHECK-NEXT: %while_header_3 = OpLabel
+  // CHECK-NEXT: OpLoopMerge %while_merge_3 %while_continue_3 None
+  // CHECK-NEXT: OpBranch %while_check_3
+  // CHECK-NEXT: %while_check_3 = OpLabel
+  // CHECK:      OpBranchConditional {{%\d+}} %while_body_3 %while_merge_3
+  while (x + (x && y)) {
+    // CHECK-NEXT: %while_body_3 = OpLabel
+    // CHECK-NEXT: OpBranch %while_continue_3
+    // CHECK-NEXT: %while_continue_3 = OpLabel
+    // CHECK:      OpBranch %while_header_3
+  }
+  // CHECK-NEXT: %while_merge_3 = OpLabel
+}

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -559,12 +559,14 @@ TEST_F(FileTest, ForStmtPlainAssign) { runFileTest("cf.for.plain.hlsl"); }
 TEST_F(FileTest, ForStmtNestedForStmt) { runFileTest("cf.for.nested.hlsl"); }
 TEST_F(FileTest, ForStmtContinue) { runFileTest("cf.for.continue.hlsl"); }
 TEST_F(FileTest, ForStmtBreak) { runFileTest("cf.for.break.hlsl"); }
+TEST_F(FileTest, ForStmtShortCircuitedCond) { runFileTest("cf.for.short-circuited-cond.hlsl"); }
 
 // For while statements
 TEST_F(FileTest, WhileStmtPlain) { runFileTest("cf.while.plain.hlsl"); }
 TEST_F(FileTest, WhileStmtNested) { runFileTest("cf.while.nested.hlsl"); }
 TEST_F(FileTest, WhileStmtContinue) { runFileTest("cf.while.continue.hlsl"); }
 TEST_F(FileTest, WhileStmtBreak) { runFileTest("cf.while.break.hlsl"); }
+TEST_F(FileTest, WhileStmtShortCircuitedCond) { runFileTest("cf.while.short-circuited-cond.hlsl"); }
 
 // For do statements
 TEST_F(FileTest, DoStmtPlain) { runFileTest("cf.do.plain.hlsl"); }