Browse Source

[spirv] Handle basic block orders (#481)

SPIR-V spec requires basic blocks to appear in an order satisfying
the dominator-tree direction. Added a BlockReadableOrderVisitor
class for visiting basic blocks in a human-readable order following
the spec's requirements.

Also fixed a bug in Constant::operator==(), which causes hashing
weirdness.
Lei Zhang 8 years ago
parent
commit
e02ff7fa4e

+ 59 - 0
tools/clang/include/clang/SPIRV/BlockReadableOrder.h

@@ -0,0 +1,59 @@
+//===--- BlockReadableOrder.h - Visit blocks in human readable order ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// The SPIR-V spec requires code blocks to appear in an order satisfying the
+// dominator-tree direction (ie, dominator before the dominated).  This is,
+// actually, easy to achieve: any pre-order CFG traversal algorithm will do it.
+// Because such algorithms visit a block only after traversing some path to it
+// from the root, they necessarily visit the block's immediate dominator first.
+//
+// But not every graph-traversal algorithm outputs blocks in an order that
+// appears logical to human readers.  The problem is that unrelated branches may
+// be interspersed with each other, and merge blocks may come before some of the
+// branches being merged.
+//
+// A good, human-readable order of blocks may be achieved by performing
+// depth-first search but delaying continue and merge nodes until after all
+// their branches have been visited.  This is implemented below by the
+// BlockReadableOrderVisitor.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_SPIRV_BLOCKREADABLEORDER_H
+#define LLVM_CLANG_SPIRV_BLOCKREADABLEORDER_H
+
+#include "clang/SPIRV/Structure.h"
+#include "llvm/ADT/DenseSet.h"
+
+namespace clang {
+namespace spirv {
+
+/// \brief A basic block visitor traversing basic blocks in a human readable
+/// order and calling a pre-set callback on each basic block.
+class BlockReadableOrderVisitor {
+public:
+  explicit BlockReadableOrderVisitor(std::function<void(BasicBlock *)> cb)
+      : callback(cb) {}
+
+  /// \brief Recursively visits all blocks reachable from the given starting
+  /// basic block in a depth-first manner and calls the callback passed-in
+  /// during construction on each basic block.
+  void visit(BasicBlock *block);
+
+private:
+  std::function<void(BasicBlock *)> callback;
+
+  llvm::DenseSet<BasicBlock *> doneBlocks; ///< Blocks already visited
+  llvm::DenseSet<BasicBlock *> todoBlocks; ///< Blocks to be visited later
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif

+ 2 - 2
tools/clang/include/clang/SPIRV/Constant.h

@@ -95,8 +95,8 @@ public:
                                           DecorationSet dec = {});
 
   bool operator==(const Constant &other) const {
-    return opcode == other.opcode && args == other.args &&
-           decorations == other.decorations;
+    return opcode == other.opcode && typeId == other.typeId &&
+           args == other.args && decorations == other.decorations;
   }
 
   // \brief Construct the SPIR-V words for this constant with the given

+ 28 - 2
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -67,16 +67,35 @@ public:
   /// for the basic block. On failure, returns zero.
   uint32_t createBasicBlock(llvm::StringRef name = "");
 
+  /// \brief Adds the basic block with the given label as a successor to the
+  /// current basic block.
+  void addSuccessor(uint32_t successorLabel);
+
+  /// \brief Sets the merge target to the basic block with the given <label-id>.
+  /// The caller must make sure the current basic block contains an
+  /// OpSelectionMerge or OpLoopMerge instruction.
+  void setMergeTarget(uint32_t mergeLabel);
+
+  /// \brief Sets the continue target to the basic block with the given
+  /// <label-id>. The caller must make sure the current basic block contains an
+  /// OpLoopMerge instruction.
+  void setContinueTarget(uint32_t continueLabel);
+
   /// \brief Returns true if the current basic block inserting into is
   /// terminated.
   inline bool isCurrentBasicBlockTerminated() const;
 
   /// \brief Sets insertion point to the basic block with the given <label-id>.
-  /// Returns true on success, false on failure.
-  bool setInsertPoint(uint32_t labelId);
+  void setInsertPoint(uint32_t labelId);
 
   // === Instruction at the current Insertion Point ===
 
+  /// \brief Creates a composite construct instruction with the given
+  /// <result-type> and constituents and returns the <result-id> for the
+  /// composite.
+  uint32_t createCompositeConstruct(uint32_t resultType,
+                                    llvm::ArrayRef<uint32_t> constituents);
+
   /// \brief Creates a load instruction loading the value of the given
   /// <result-type> from the given pointer. Returns the <result-id> for the
   /// loaded value.
@@ -169,9 +188,16 @@ public:
 private:
   /// \brief Map from basic blocks' <label-id> to their structured
   /// representation.
+  ///
+  /// We need MapVector here to remember the order of insertion. Order matters
+  /// here since, for example, we'll know for sure the first basic block is the
+  /// entry block.
   using OrderedBasicBlockMap =
       llvm::MapVector<uint32_t, std::unique_ptr<BasicBlock>>;
 
+  /// \brief Returns the basic block with the given <label-id>.
+  BasicBlock *getBasicBlock(uint32_t label);
+
   SPIRVContext &theContext; ///< The SPIR-V context.
   SPIRVModule theModule;    ///< The module under building.
 

+ 48 - 1
tools/clang/include/clang/SPIRV/Structure.h

@@ -108,12 +108,40 @@ public:
   /// \brief Preprends an instruction to this basic block.
   inline void prependInstruction(Instruction &&);
 
+  /// \brief Adds the given basic block as a successsor to this basic block.
+  inline void addSuccessor(BasicBlock *);
+
+  /// \brief Gets all successor basic blocks.
+  inline const llvm::SmallVector<BasicBlock *, 2> &getSuccessors() const;
+
+  /// \brief Sets the merge target to the given basic block.
+  /// The caller must make sure this basic block contains an OpSelectionMerge or
+  /// OpLoopMerge instruction.
+  inline void setMergeTarget(BasicBlock *);
+
+  /// \brief Returns the merge target if this basic block contains an
+  /// OpSelectionMerge or OpLoopMerge instruction. Returns nullptr otherwise.
+  inline BasicBlock *getMergeTarget() const;
+
+  /// \brief Sets the continue target to the given basic block.
+  /// The caller must make sure this basic block contains an OpLoopMerge
+  /// instruction.
+  inline void setContinueTarget(BasicBlock *);
+
+  /// \brief Returns the continue target if this basic block contains an
+  /// OpLoopMerge instruction. Returns nullptr otherwise.
+  inline BasicBlock *getContinueTarget() const;
+
   /// \brief Returns true if this basic block is terminated.
   bool isTerminated() const;
 
 private:
   uint32_t labelId; ///< The label id for this basic block. Zero means invalid.
   std::deque<Instruction> instructions;
+
+  llvm::SmallVector<BasicBlock *, 2> successors;
+  BasicBlock *mergeTarget;
+  BasicBlock *continueTarget;
 };
 
 // === Function definition ===
@@ -327,7 +355,8 @@ std::vector<uint32_t> Instruction::take() { return std::move(words); }
 
 // === Basic block inline implementations ===
 
-BasicBlock::BasicBlock(uint32_t id) : labelId(id) {}
+BasicBlock::BasicBlock(uint32_t id)
+    : labelId(id), mergeTarget(nullptr), continueTarget(nullptr) {}
 
 bool BasicBlock::isEmpty() const {
   return labelId == 0 && instructions.empty();
@@ -341,6 +370,24 @@ void BasicBlock::prependInstruction(Instruction &&inst) {
   instructions.push_front(std::move(inst));
 }
 
+void BasicBlock::addSuccessor(BasicBlock *successor) {
+  successors.push_back(successor);
+}
+
+const llvm::SmallVector<BasicBlock *, 2> &BasicBlock::getSuccessors() const {
+  return successors;
+}
+
+void BasicBlock::setMergeTarget(BasicBlock *target) { mergeTarget = target; }
+
+BasicBlock *BasicBlock::getMergeTarget() const { return mergeTarget; }
+
+void BasicBlock::setContinueTarget(BasicBlock *target) {
+  continueTarget = target;
+}
+
+BasicBlock *BasicBlock::getContinueTarget() const { return continueTarget; }
+
 // === Function inline implementations ===
 
 Function::Function(uint32_t rType, uint32_t rId,

+ 52 - 0
tools/clang/lib/SPIRV/BlockReadableOrder.cpp

@@ -0,0 +1,52 @@
+//===--- BlockReadableOrder.cpp - BlockReadableOrderVisitor impl ----------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/BlockReadableOrder.h"
+
+namespace clang {
+namespace spirv {
+
+void BlockReadableOrderVisitor::visit(BasicBlock *block) {
+  if (doneBlocks.count(block) || todoBlocks.count(block))
+    return;
+
+  callback(block);
+
+  doneBlocks.insert(block);
+
+  // Check the continue and merge targets. If any one of them exists, we need
+  // to make sure visiting it is delayed until we've done the rest.
+
+  BasicBlock *continueBlock = block->getContinueTarget();
+  BasicBlock *mergeBlock = block->getMergeTarget();
+
+  if (continueBlock)
+    todoBlocks.insert(continueBlock);
+
+  if (mergeBlock)
+    todoBlocks.insert(mergeBlock);
+
+  for (BasicBlock *successor : block->getSuccessors())
+    visit(successor);
+
+  // Handle continue and merge targets now.
+
+  if (continueBlock) {
+    todoBlocks.erase(continueBlock);
+    visit(continueBlock);
+  }
+
+  if (mergeBlock) {
+    todoBlocks.erase(mergeBlock);
+    visit(mergeBlock);
+  }
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 1 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -3,6 +3,7 @@ set(LLVM_LINK_COMPONENTS
   )
 
 add_clang_library(clangSPIRV
+  BlockReadableOrder.cpp
   Constant.cpp
   DeclResultIdMapper.cpp
   Decoration.cpp

+ 38 - 13
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -348,19 +348,24 @@ public:
     // We'll need the <label-id> for the then/else/merge block to do so.
     const bool hasElse = ifStmt->getElse() != nullptr;
     const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
-    const uint32_t elseBB = hasElse ? theBuilder.createBasicBlock("if.false")
-                                    : theBuilder.createBasicBlock("if.merge");
-    const uint32_t mergeBB =
-        hasElse ? theBuilder.createBasicBlock("if.merge") : elseBB;
+    const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
+    const uint32_t elseBB =
+        hasElse ? theBuilder.createBasicBlock("if.false") : mergeBB;
 
     // Create the branch instruction. This will end the current basic block.
     theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB);
+    theBuilder.addSuccessor(thenBB);
+    theBuilder.addSuccessor(elseBB);
+    // The current basic block has the OpSelectionMerge instruction. We need
+    // to record its merge target.
+    theBuilder.setMergeTarget(mergeBB);
 
     // Handle the then branch
     theBuilder.setInsertPoint(thenBB);
     doStmt(ifStmt->getThen());
     if (!theBuilder.isCurrentBasicBlockTerminated())
       theBuilder.createBranch(mergeBB);
+    theBuilder.addSuccessor(mergeBB);
 
     // Handle the else branch (if exists)
     if (hasElse) {
@@ -368,6 +373,7 @@ public:
       doStmt(ifStmt->getElse());
       if (!theBuilder.isCurrentBasicBlockTerminated())
         theBuilder.createBranch(mergeBB);
+      theBuilder.addSuccessor(mergeBB);
     }
 
     // From now on, we'll emit instructions into the merge block.
@@ -420,6 +426,7 @@ public:
       doStmt(initStmt);
     }
     theBuilder.createBranch(checkBB);
+    theBuilder.addSuccessor(checkBB);
 
     // Process the <check> block
     theBuilder.setInsertPoint(checkBB);
@@ -432,6 +439,12 @@ public:
     theBuilder.createConditionalBranch(condition, bodyBB,
                                        /*false branch*/ mergeBB,
                                        /*merge*/ mergeBB, continueBB);
+    theBuilder.addSuccessor(bodyBB);
+    theBuilder.addSuccessor(mergeBB);
+    // The current basic block has OpLoopMerge instruction. We need to set its
+    // continue and merge target.
+    theBuilder.setContinueTarget(continueBB);
+    theBuilder.setMergeTarget(mergeBB);
 
     // Process the <body> block
     theBuilder.setInsertPoint(bodyBB);
@@ -439,6 +452,7 @@ public:
       doStmt(body);
     }
     theBuilder.createBranch(continueBB);
+    theBuilder.addSuccessor(continueBB);
 
     // Process the <continue> block
     theBuilder.setInsertPoint(continueBB);
@@ -446,6 +460,7 @@ public:
       doExpr(cont);
     }
     theBuilder.createBranch(checkBB); // <continue> should jump back to header
+    theBuilder.addSuccessor(checkBB);
 
     // Set insertion point to the <merge> block for subsequent statements
     theBuilder.setInsertPoint(mergeBB);
@@ -497,10 +512,9 @@ public:
 
       if (expr->isConstantInitializer(astContext, false)) {
         return theBuilder.getConstantComposite(resultType, constituents);
+      } else {
+        return theBuilder.createCompositeConstruct(resultType, constituents);
       }
-      // TODO: use OpCompositeConstruct for non-constant initializer lists.
-      emitError("Non-const initializer lists are currently not supported.");
-      return 0;
     } else if (auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
       const bool value = boolLiteral->getValue();
       return theBuilder.getConstantBool(value);
@@ -592,13 +606,13 @@ public:
     const QualType toType = expr->getType();
 
     switch (expr->getCastKind()) {
-    // Integer literals in the AST are represented using 64bit APInt
-    // themselves and then implicitly casted into the expected bitwidth.
-    // We need special treatment of integer literals here because generating
-    // a 64bit constant and then explicit casting in SPIR-V requires Int64
-    // capability. We should avoid introducing unnecessary capabilities to
-    // our best.
     case CastKind::CK_IntegralCast: {
+      // Integer literals in the AST are represented using 64bit APInt
+      // themselves and then implicitly casted into the expected bitwidth.
+      // We need special treatment of integer literals here because generating
+      // a 64bit constant and then explicit casting in SPIR-V requires Int64
+      // capability. We should avoid introducing unnecessary capabilities to
+      // our best.
       llvm::APSInt intValue;
       if (expr->EvaluateAsInt(intValue, astContext, Expr::SE_NoSideEffects)) {
         return translateAPInt(intValue, toType);
@@ -607,6 +621,17 @@ public:
         return 0;
       }
     }
+    case CastKind::CK_FloatingCast: {
+      // First try to see if we can do constant folding for floating point
+      // numbers like what we are doing for integers in the above.
+      Expr::EvalResult evalResult;
+      if (expr->EvaluateAsRValue(evalResult, astContext) &&
+          !evalResult.HasSideEffects) {
+        return translateAPFloat(evalResult.Val.getFloat(), toType);
+      }
+      emitError("floating cast unimplemented");
+      return 0;
+    }
     case CastKind::CK_LValueToRValue: {
       const uint32_t fromValue = doExpr(subExpr);
       // Using lvalue as rvalue means we need to OpLoad the contents from

+ 37 - 8
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -109,14 +109,33 @@ uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
   return labelId;
 }
 
-bool ModuleBuilder::setInsertPoint(uint32_t labelId) {
-  auto it = basicBlocks.find(labelId);
-  if (it == basicBlocks.end()) {
-    assert(false && "invalid <label-id>");
-    return false;
-  }
-  insertPoint = it->second.get();
-  return true;
+void ModuleBuilder::addSuccessor(uint32_t successorLabel) {
+  assert(insertPoint && "null insert point");
+  insertPoint->addSuccessor(getBasicBlock(successorLabel));
+}
+
+void ModuleBuilder::setMergeTarget(uint32_t mergeLabel) {
+  assert(insertPoint && "null insert point");
+  insertPoint->setMergeTarget(getBasicBlock(mergeLabel));
+}
+
+void ModuleBuilder::setContinueTarget(uint32_t continueLabel) {
+  assert(insertPoint && "null insert point");
+  insertPoint->setContinueTarget(getBasicBlock(continueLabel));
+}
+
+void ModuleBuilder::setInsertPoint(uint32_t labelId) {
+  insertPoint = getBasicBlock(labelId);
+}
+
+uint32_t
+ModuleBuilder::createCompositeConstruct(uint32_t resultType,
+                                        llvm::ArrayRef<uint32_t> constituents) {
+  assert(insertPoint && "null insert point");
+  const uint32_t resultId = theContext.takeNextId();
+  instBuilder.opCompositeConstruct(resultType, resultId, constituents).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return resultId;
 }
 
 uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
@@ -350,5 +369,15 @@ ModuleBuilder::getConstantComposite(uint32_t typeId,
   return constId;
 }
 
+BasicBlock *ModuleBuilder::getBasicBlock(uint32_t labelId) {
+  auto it = basicBlocks.find(labelId);
+  if (it == basicBlocks.end()) {
+    assert(false && "invalid <label-id>");
+    return nullptr;
+  }
+
+  return it->second.get();
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 16 - 1
tools/clang/lib/SPIRV/Structure.cpp

@@ -9,6 +9,8 @@
 
 #include "clang/SPIRV/Structure.h"
 
+#include "clang/SPIRV/BlockReadableOrder.h"
+
 namespace clang {
 namespace spirv {
 
@@ -122,6 +124,10 @@ void Function::take(InstBuilder *builder) {
     builder->opFunctionParameter(param.first, param.second).x();
   }
 
+  if (!variables.empty()) {
+    assert(!blocks.empty());
+  }
+
   // Preprend all local variables to the entry block.
   // This is necessary since SPIR-V requires all local variables to be defined
   // at the very begining of the entry block.
@@ -131,8 +137,17 @@ void Function::take(InstBuilder *builder) {
     blocks.front()->prependInstruction(std::move(*it));
   }
 
+  // Collect basic blocks in a human-readable order that satisfies SPIR-V
+  // validation rules.
+  std::vector<BasicBlock *> orderedBlocks;
+  if (!blocks.empty()) {
+    BlockReadableOrderVisitor([&orderedBlocks](BasicBlock *block) {
+      orderedBlocks.push_back(block);
+    }).visit(blocks.front().get());
+  }
+
   // Write out all basic blocks.
-  for (auto &block : blocks) {
+  for (auto *block : orderedBlocks) {
     block->take(builder);
   }
 

+ 164 - 0
tools/clang/test/CodeGenSPIRV/cf.if.for.hlsl

@@ -0,0 +1,164 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// Stage IO variables
+// CHECK-DAG: [[color:%\d+]] = OpVariable %_ptr_Input_float Input
+// CHECK-DAG: [[target:%\d+]] = OpVariable %_ptr_Output_v4float Output
+
+float4 main(float color: COLOR) : SV_TARGET {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+// CHECK-NEXT: %val = OpVariable %_ptr_Function_float Function %float_0
+    float val = 0.;
+// CHECK-NEXT: %i = OpVariable %_ptr_Function_int Function %int_0
+// CHECK-NEXT: %j = OpVariable %_ptr_Function_int Function %int_0
+// CHECK-NEXT: %k = OpVariable %_ptr_Function_int Function %int_0
+
+// CHECK-NEXT: [[color0:%\d+]] = OpLoad %float [[color]]
+// CHECK-NEXT: [[lt0:%\d+]] = OpFOrdLessThan %bool [[color0]] %float_0_3
+// CHECK-NEXT: OpSelectionMerge %if_merge None
+// CHECK-NEXT: OpBranchConditional [[lt0]] %if_true %if_merge
+    if (color < 0.3) {
+// CHECK-LABEL: %if_true = OpLabel
+// CHECK-NEXT: OpStore %val %float_1
+        val = 1.;
+// CHECK-NEXT: OpBranch %if_merge
+    }
+// CHECK-LABEL: %if_merge = OpLabel
+// CHECK-NEXT: OpBranch %for_check
+
+    // for-stmt following if-stmt
+// CHECK-LABEL: %for_check = OpLabel
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[lt1:%\d+]] = OpSLessThan %bool [[i0]] %int_10
+// CHECK-NEXT: OpLoopMerge %for_merge %for_continue None
+// CHECK-NEXT: OpBranchConditional [[lt1]] %for_body %for_merge
+    for (int i = 0; i < 10; ++i) {
+// CHECK-LABEL: %for_body = OpLabel
+// CHECK-NEXT: [[color1:%\d+]] = OpLoad %float [[color]]
+// CHECK-NEXT: [[lt2:%\d+]] = OpFOrdLessThan %bool [[color1]] %float_0_5
+// CHECK-NEXT: OpSelectionMerge %if_merge_0 None
+// CHECK-NEXT: OpBranchConditional [[lt2]] %if_true_0 %if_merge_0
+        if (color < 0.5) { // if-stmt nested in for-stmt
+// CHECK-LABEL: %if_true_0 = OpLabel
+// CHECK-NEXT: [[val0:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[add1:%\d+]] = OpFAdd %float [[val0]] %float_1
+// CHECK-NEXT: OpStore %val [[add1]]
+            val = val + 1.;
+// CHECK-NEXT: OpBranch %for_check_0
+
+// CHECK-LABEL: %for_check_0 = OpLabel
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[lt3:%\d+]] = OpSLessThan %bool [[j0]] %int_15
+// CHECK-NEXT: OpLoopMerge %for_merge_0 %for_continue_0 None
+// CHECK-NEXT: OpBranchConditional [[lt3]] %for_body_0 %for_merge_0
+            for (int j = 0; j < 15; ++j) { // for-stmt deeply nested in if-then
+// CHECK-LABEL: %for_body_0 = OpLabel
+// CHECK-NEXT: [[val1:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[mul2:%\d+]] = OpFMul %float [[val1]] %float_2
+// CHECK-NEXT: OpStore %val [[mul2]]
+                val = val * 2.;
+// CHECK-NEXT: OpBranch %for_continue_0
+
+// CHECK-LABEL: %for_continue_0 = OpLabel
+// CHECK-NEXT: [[j1:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[incj:%\d+]] = OpIAdd %int [[j1]] %int_1
+// CHECK-NEXT: OpStore %j [[incj]]
+// CHECK-NEXT: OpBranch %for_check_0
+            } // end for (int j
+// CHECK-LABEL: %for_merge_0 = OpLabel
+// CHECK-NEXT: [[val2:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[add3:%\d+]] = OpFAdd %float [[val2]] %float_3
+// CHECK-NEXT: OpStore %val [[add3]]
+
+            val = val + 3.;
+// CHECK-NEXT: OpBranch %if_merge_0
+        }
+// CHECK-LABEL: %if_merge_0 = OpLabel
+
+// CHECK-NEXT: [[color2:%\d+]] = OpLoad %float [[color]]
+// CHECK-NEXT: [[lt4:%\d+]] = OpFOrdLessThan %bool [[color2]] %float_0_8
+// CHECK-NEXT: OpSelectionMerge %if_merge_1 None
+// CHECK-NEXT: OpBranchConditional [[lt4]] %if_true_1 %if_false
+        if (color < 0.8) { // if-stmt following if-stmt
+// CHECK-LABEL: %if_true_1 = OpLabel
+// CHECK-NEXT: [[val3:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[mul4:%\d+]] = OpFMul %float [[val3]] %float_4
+// CHECK-NEXT: OpStore %val [[mul4]]
+            val = val * 4.;
+// CHECK-NEXT: OpBranch %if_merge_1
+        } else {
+// CHECK-LABEL: %if_false = OpLabel
+// CHECK-NEXT: OpBranch %for_check_1
+
+// CHECK-LABEL: %for_check_1 = OpLabel
+// CHECK-NEXT: [[k0:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[lt5:%\d+]] = OpSLessThan %bool [[k0]] %int_20
+// CHECK-NEXT: OpLoopMerge %for_merge_1 %for_continue_1 None
+// CHECK-NEXT: OpBranchConditional [[lt5]] %for_body_1 %for_merge_1
+            for (int k = 0; k < 20; ++k) { // for-stmt deeply nested in if-else
+// CHECK-LABEL: %for_body_1 = OpLabel
+// CHECK-NEXT: [[val4:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[sub5:%\d+]] = OpFSub %float [[val4]] %float_5
+// CHECK-NEXT: OpStore %val [[sub5]]
+                val = val - 5.;
+
+// CHECK-NEXT: [[val5:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[lt6:%\d+]] = OpFOrdLessThan %bool [[val5]] %float_0
+// CHECK-NEXT: OpSelectionMerge %if_merge_2 None
+// CHECK-NEXT: OpBranchConditional [[lt6]] %if_true_2 %if_merge_2
+                if (val < 0.) { // deeply nested if-stmt
+// CHECK-LABEL: %if_true_2 = OpLabel
+// CHECK-NEXT: [[val6:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[add100:%\d+]] = OpFAdd %float [[val6]] %float_100
+// CHECK-NEXT: OpStore %val [[add100]]
+                    val = val + 100.;
+// CHECK-NEXT: OpBranch %if_merge_2
+                }
+// CHECK-LABEL: %if_merge_2 = OpLabel
+// CHECK-NEXT: OpBranch %for_continue_1
+
+// CHECK-LABEL: %for_continue_1 = OpLabel
+// CHECK-NEXT: [[k1:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[inck:%\d+]] = OpIAdd %int [[k1]] %int_1
+// CHECK-NEXT: OpStore %k [[inck]]
+// CHECK-NEXT: OpBranch %for_check_1
+            } // end for (int k
+// CHECK-LABEL: %for_merge_1 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge_1
+        } // end elsek
+// CHECK-LABEL: %if_merge_1 = OpLabel
+// CHECK-NEXT: OpBranch %for_continue
+
+// CHECK-LABEL: %for_continue = OpLabel
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[inci:%\d+]] = OpIAdd %int [[i1]] %int_1
+// CHECK-NEXT: OpStore %i [[inci]]
+// CHECK-NEXT: OpBranch %for_check
+    } // end for (int i
+// CHECK-LABEL: %for_merge = OpLabel
+
+    // if-stmt following for-stmt
+// CHECK-NEXT: [[color3:%\d+]] = OpLoad %float [[color]]
+// CHECK-NEXT: [[lt7:%\d+]] = OpFOrdLessThan %bool [[color3]] %float_0_9
+// CHECK-NEXT: OpSelectionMerge %if_merge_3 None
+// CHECK-NEXT: OpBranchConditional [[lt7]] %if_true_3 %if_merge_3
+    if (color < 0.9) {
+// CHECK-LABEL: %if_true_3 = OpLabel
+// CHECK-NEXT: [[val7:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[add6:%\d+]] = OpFAdd %float [[val7]] %float_6
+// CHECK-NEXT: OpStore %val [[add6]]
+        val = val + 6.;
+// CHECK-NEXT: OpBranch %if_merge_3
+    }
+// CHECK-LABEL: %if_merge_3 = OpLabel
+
+// CHECK-NEXT: [[comp0:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[comp1:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[comp2:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[comp3:%\d+]] = OpLoad %float %val
+// CHECK-NEXT: [[ret:%\d+]] = OpCompositeConstruct %v4float [[comp0]] [[comp1]] [[comp2]] [[comp3]]
+// CHECK-NEXT: OpStore [[target]] [[ret]]
+// CHECK-NEXT: OpReturn
+    return float4(val, val, val, val);
+// CHECK-NEXT: OpFunctionEnd
+}

+ 82 - 0
tools/clang/test/CodeGenSPIRV/for-stmt.nested.hlsl

@@ -0,0 +1,82 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+// CHECK-NEXT: %val = OpVariable %_ptr_Function_int Function %int_0
+    int val = 0;
+
+// CHECK-NEXT: %i = OpVariable %_ptr_Function_int Function %int_0
+// CHECK-NEXT: %j = OpVariable %_ptr_Function_int Function %int_0
+// CHECK-NEXT: %k = OpVariable %_ptr_Function_int Function %int_0
+// CHECK-NEXT: OpBranch %for_check
+
+// CHECK-LABEL: %for_check = OpLabel
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[lt0:%\d+]] = OpSLessThan %bool [[i0]] %int_10
+// CHECK-NEXT: OpLoopMerge %for_merge %for_continue None
+// CHECK-NEXT: OpBranchConditional [[lt0]] %for_body %for_merge
+    for (int i = 0; i < 10; ++i) {
+// CHECK-LABEL: %for_body = OpLabel
+// CHECK-NEXT: [[val0:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[add0:%\d+]] = OpIAdd %int [[val0]] [[i1]]
+// CHECK-NEXT: OpStore %val [[add0]]
+        val = val + i;
+// CHECK-NEXT: OpBranch %for_check_0
+
+// CHECK-LABEL: %for_check_0 = OpLabel
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[lt1:%\d+]] = OpSLessThan %bool [[j0]] %int_10
+// CHECK-NEXT: OpLoopMerge %for_merge_0 %for_continue_0 None
+// CHECK-NEXT: OpBranchConditional [[lt1]] %for_body_0 %for_merge_0
+        for (int j = 0; j < 10; ++j) {
+// CHECK-LABEL: %for_body_0 = OpLabel
+// CHECK-NEXT: OpBranch %for_check_1
+
+// CHECK-LABEL: %for_check_1 = OpLabel
+// CHECK-NEXT: [[k0:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[lt2:%\d+]] = OpSLessThan %bool [[k0]] %int_10
+// CHECK-NEXT: OpLoopMerge %for_merge_1 %for_continue_1 None
+// CHECK-NEXT: OpBranchConditional [[lt2]] %for_body_1 %for_merge_1
+            for (int k = 0; k < 10; ++k) {
+// CHECK-LABEL: %for_body_1 = OpLabel
+// CHECK-NEXT: [[val1:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[k1:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[add1:%\d+]] = OpIAdd %int [[val1]] [[k1]]
+// CHECK-NEXT: OpStore %val [[add1]]
+// CHECK-NEXT: OpBranch %for_continue_1
+                val = val + k;
+
+// CHECK-LABEL: %for_continue_1 = OpLabel
+// CHECK-NEXT: [[k2:%\d+]] = OpLoad %int %k
+// CHECK-NEXT: [[add2:%\d+]] = OpIAdd %int [[k2]] %int_1
+// CHECK-NEXT: OpStore %k [[add2]]
+// CHECK-NEXT: OpBranch %for_check_1
+            }
+
+// CHECK-LABEL: %for_merge_1 = OpLabel
+// CHECK-NEXT: [[val2:%\d+]] = OpLoad %int %val
+// CHECK-NEXT: [[mul0:%\d+]] = OpIMul %int [[val2]] %int_2
+// CHECK-NEXT: OpStore %val [[mul0]]
+// CHECK-NEXT: OpBranch %for_continue_0
+            val = val * 2;
+
+// CHECK-LABEL: %for_continue_0 = OpLabel
+// CHECK-NEXT: [[j1:%\d+]] = OpLoad %int %j
+// CHECK-NEXT: [[add3:%\d+]] = OpIAdd %int [[j1]] %int_1
+// CHECK-NEXT: OpStore %j [[add3]]
+// CHECK-NEXT: OpBranch %for_check_0
+        }
+// CHECK-LABEL: %for_merge_0 = OpLabel
+// CHECK-NEXT: OpBranch %for_continue
+
+// CHECK-LABEL: %for_continue = OpLabel
+// CHECK-NEXT: [[i2:%\d+]] = OpLoad %int %i
+// CHECK-NEXT: [[add4:%\d+]] = OpIAdd %int [[i2]] %int_1
+// CHECK-NEXT: OpStore %i [[add4]]
+// CHECK-NEXT: OpBranch %for_check
+    }
+
+// CHECK-LABEL: %for_merge = OpLabel
+// CHECK-NEXT: OpReturn
+}

+ 15 - 15
tools/clang/test/CodeGenSPIRV/if-stmt.nested.hlsl

@@ -10,20 +10,10 @@ void main() {
 // CHECK-NEXT: OpBranchConditional [[c1]] %if_true %if_false
     if (c1) {
 // CHECK-LABEL: %if_true = OpLabel
+
 // CHECK-NEXT: [[c2:%\d+]] = OpLoad %bool %c2
 // CHECK-NEXT: OpSelectionMerge %if_merge_0 None
 // CHECK-NEXT: OpBranchConditional [[c2]] %if_true_0 %if_merge_0
-
-// TODO: Move this basic block to the else branch
-// CHECK-LABEL: %if_false = OpLabel
-// CHECK-NEXT: [[c3:%\d+]] = OpLoad %bool %c3
-// CHECK-NEXT: OpSelectionMerge %if_merge_1 None
-// CHECK-NEXT: OpBranchConditional [[c3]] %if_true_1 %if_false_0
-
-// TODO: Move this basic block to the end
-// CHECK-LABEL: %if_merge = OpLabel
-// CHECK-NEXT: OpReturn
-
         if (c2)
 // CHECK-LABEL: %if_true_0 = OpLabel
 // CHECK-NEXT: OpStore %val %int_1
@@ -33,28 +23,38 @@ void main() {
 // CHECK-LABEL: %if_merge_0 = OpLabel
 // CHECK-NEXT: OpBranch %if_merge
     } else {
+// CHECK-LABEL: %if_false = OpLabel
+
+// CHECK-NEXT: [[c3:%\d+]] = OpLoad %bool %c3
+// CHECK-NEXT: OpSelectionMerge %if_merge_1 None
+// CHECK-NEXT: OpBranchConditional [[c3]] %if_true_1 %if_false_0
         if (c3) {
 // CHECK-LABEL: %if_true_1 = OpLabel
+
 // CHECK-NEXT: OpStore %val %int_2
 // CHECK-NEXT: OpBranch %if_merge_1
             val = 2;
         } else {
 // CHECK-LABEL: %if_false_0 = OpLabel
+
 // CHECK-NEXT: [[c4:%\d+]] = OpLoad %bool %c4
 // CHECK-NEXT: OpSelectionMerge %if_merge_2 None
 // CHECK-NEXT: OpBranchConditional [[c4]] %if_true_2 %if_merge_2
-
-// TODO: Make this basic block the second to last one
-// CHECK-LABEL: %if_merge_1 = OpLabel
-// CHECK-NEXT: OpBranch %if_merge
             if (c4) {
 // CHECK-LABEL: %if_true_2 = OpLabel
 // CHECK-NEXT: OpStore %val %int_3
 // CHECK-NEXT: OpBranch %if_merge_2
                 val = 3;
             }
+
 // CHECK-LABEL: %if_merge_2 = OpLabel
 // CHECK-NEXT: OpBranch %if_merge_1
         }
+
+// CHECK-LABEL: %if_merge_1 = OpLabel
+// CHECK-NEXT: OpBranch %if_merge
     }
+
+// CHECK-LABEL: %if_merge = OpLabel
+// CHECK-NEXT: OpReturn
 }

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -58,4 +58,8 @@ TEST_F(FileTest, IfStmtNestedIfStmt) { runFileTest("if-stmt.nested.hlsl"); }
 
 TEST_F(FileTest, ForStmtPlainAssign) { runFileTest("for-stmt.plain.hlsl"); }
 
+TEST_F(FileTest, ForStmtNestedForStmt) { runFileTest("for-stmt.nested.hlsl"); }
+
+TEST_F(FileTest, ControlFlowNestedIfForStmt) { runFileTest("cf.if.for.hlsl"); }
+
 } // namespace

+ 14 - 0
tools/clang/unittests/SPIRV/ConstantTest.cpp

@@ -262,4 +262,18 @@ TEST(Constant, DecoratedSpecComposite) {
   EXPECT_THAT(c->getDecorations(), ElementsAre(d));
 }
 
+TEST(Constant, ConstantsWithSameBitPatternButDifferentTypeIdAreNotEqual) {
+  SPIRVContext ctx;
+
+  const Constant *int1 = Constant::getInt32(ctx, /*type_id*/ 1, 0);
+  const Constant *uint1 = Constant::getUint32(ctx, /*type_id*/ 2, 0);
+  const Constant *float1 = Constant::getFloat32(ctx, /*type_id*/ 3, 0);
+  const Constant *anotherInt1 = Constant::getInt32(ctx, /*type_id*/ 4, 0);
+
+  EXPECT_FALSE(*int1 == *uint1);
+  EXPECT_FALSE(*int1 == *float1);
+  EXPECT_FALSE(*uint1 == *float1);
+  EXPECT_FALSE(*int1 == *anotherInt1);
+}
+
 } // anonymous namespace

+ 18 - 0
tools/clang/unittests/SPIRV/SPIRVContextTest.cpp

@@ -75,6 +75,24 @@ TEST(SPIRVContext, UniqueIdForUniqueAggregateType) {
   EXPECT_EQ(struct_1_id, struct_2_id);
 }
 
+TEST(SPIRVContext, UniqueIdForUniqueConstants) {
+  SPIRVContext ctx;
+
+  const Constant *int1 = Constant::getInt32(ctx, /*type_id*/ 1, /*value*/ 0);
+  const Constant *uint1 = Constant::getUint32(ctx, 2, 0);
+  const Constant *float1 = Constant::getFloat32(ctx, 3, 0);
+  const Constant *anotherInt1 = Constant::getInt32(ctx, /*type_id*/ 4, 0);
+
+  const uint32_t int1Id = ctx.getResultIdForConstant(int1);
+  const uint32_t uint1Id = ctx.getResultIdForConstant(uint1);
+  const uint32_t float1Id = ctx.getResultIdForConstant(float1);
+  const uint32_t anotherInt1Id = ctx.getResultIdForConstant(anotherInt1);
+
+  EXPECT_NE(int1Id, uint1Id);
+  EXPECT_NE(int1Id, float1Id);
+  EXPECT_NE(uint1Id, float1Id);
+  EXPECT_NE(int1Id, anotherInt1Id);
+}
 // TODO: Add more SPIRVContext tests
 
 } // anonymous namespace