2
0
Эх сурвалжийг харах

Structurize control flow for functions which has multiple returns. (#2968)

* Structuize control flow for functions which has multiple returns.
Xiang Li 5 жил өмнө
parent
commit
136f2e7989

+ 1 - 0
include/dxc/Support/HLSLOptions.h

@@ -164,6 +164,7 @@ public:
   bool UseHexLiterals = false; // OPT_Lx
   bool UseInstructionByteOffsets = false; // OPT_No
   bool UseInstructionNumbers = false; // OPT_Ni
+  bool StructurizeReturns = false;      // OPT_structurize_returns
   bool NotUseLegacyCBufLoad = false;  // OPT_no_legacy_cbuf_layout
   bool PackPrefixStable = false;  // OPT_pack_prefix_stable
   bool PackOptimized = false;  // OPT_pack_optimized

+ 2 - 0
include/dxc/Support/HLSLOptions.td

@@ -230,6 +230,8 @@ def flegacy_macro_expansion : Flag<["-", "/"], "flegacy-macro-expansion">, Group
     HelpText<"Expand the operands before performing token-pasting operation (fxc behavior)">;
 def flegacy_resource_reservation : Flag<["-", "/"], "flegacy-resource-reservation">, Group<hlslcomp_Group>, Flags<[CoreOption, DriverOption]>,
     HelpText<"Reserve unused explicit register assignments for compatibility with shader model 5.0 and below">;
+def structurize_returns : Flag<["-", "/"], "structurize-returns">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
+  HelpText<"structurize return control flow for functions with multiple returns.">;
 def no_legacy_cbuf_layout : Flag<["-", "/"], "no-legacy-cbuf-layout">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
   HelpText<"Do not use legacy cbuffer load">;
 def not_use_legacy_cbuf_load_ : Flag<["-", "/"], "not_use_legacy_cbuf_load">, Group<hlslcomp_Group>, Flags<[CoreOption, HelpHidden]>,

+ 1 - 0
lib/DxcSupport/HLSLOptions.cpp

@@ -596,6 +596,7 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
   opts.DefaultRowMajor = Args.hasFlag(OPT_Zpr, OPT_INVALID, false);
   opts.DefaultColMajor = Args.hasFlag(OPT_Zpc, OPT_INVALID, false);
   opts.DumpBin = Args.hasFlag(OPT_dumpbin, OPT_INVALID, false);
+  opts.StructurizeReturns = Args.hasFlag(OPT_structurize_returns, OPT_INVALID, false);
   opts.NotUseLegacyCBufLoad = Args.hasFlag(OPT_no_legacy_cbuf_layout, OPT_INVALID, false);
   opts.NotUseLegacyCBufLoad = Args.hasFlag(OPT_not_use_legacy_cbuf_load_, OPT_INVALID, opts.NotUseLegacyCBufLoad);
   opts.PackPrefixStable = Args.hasFlag(OPT_pack_prefix_stable, OPT_INVALID, false);

+ 2 - 0
tools/clang/include/clang/Frontend/CodeGenOptions.h

@@ -214,6 +214,8 @@ public:
   std::vector<std::string> HLSLLibraryExports;
   /// ExportShadersOnly limits library export functions to shaders
   bool ExportShadersOnly = false;
+  /// Structurize control flow for function has multiple returns.
+  bool HLSLStructurizeReturns = false;
   /// DefaultLinkage Internal, External, or Default.  If Default, default
   /// function linkage is determined by library target.
   hlsl::DXIL::DefaultLinkage DefaultLinkage = hlsl::DXIL::DefaultLinkage::Default;

+ 60 - 2
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -211,6 +211,8 @@ private:
   std::unordered_map<Constant*, DxilFieldAnnotation> m_ConstVarAnnotationMap;
   StringSet<> m_PreciseOutputSet;
 
+  DenseMap<Function*, ScopeInfo> m_ScopeMap;
+  ScopeInfo *GetScopeInfo(Function *F);
 public:
   CGMSHLSLRuntime(CodeGenModule &CGM);
 
@@ -283,7 +285,14 @@ public:
   void MarkRetTemp(CodeGenFunction &CGF, llvm::Value *V,
                   clang::QualType QaulTy) override;
   void FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D, llvm::Value *V) override;
-
+  void MarkThenStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) override;
+  void MarkElseStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) override;
+  void MarkSwitchStmt(CodeGenFunction &CGF, SwitchInst *switchInst,
+                      BasicBlock *endSwitch) override;
+  void MarkReturnStmt(CodeGenFunction &CGF, BasicBlock *bbWithRet) override;
+  void MarkLoopStmt(CodeGenFunction &CGF, BasicBlock *loopContinue,
+                     BasicBlock *loopExit) override;
+  void MarkScopeEnd(CodeGenFunction &CGF) override;
   /// Get or add constant to the program
   HLCBuffer &GetOrCreateCBuffer(HLSLBufferDecl *D);
 };
@@ -2156,6 +2165,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   for (const auto &Attr : FD->specific_attrs<HLSLExperimentalAttr>()) {
     F->addFnAttr(Twine("exp-", Attr->getName()).str(), Attr->getValue());
   }
+
+  m_ScopeMap[F] = ScopeInfo(F);
 }
 
 void CGMSHLSLRuntime::RemapObsoleteSemantic(DxilParameterAnnotation &paramInfo, bool isPatchConstantFunction) {
@@ -3279,10 +3290,14 @@ HLCBuffer &CGMSHLSLRuntime::GetOrCreateCBuffer(HLSLBufferDecl *D) {
 void CGMSHLSLRuntime::FinishCodeGen() {
   HLModule &HLM = *m_pHLModule;
   llvm::Module &M = TheModule;
-
   // Do this before CloneShaderEntry and TranslateRayQueryConstructor to avoid
   // update valToResPropertiesMap for cloned inst.
   FinishIntrinsics(HLM, m_IntrinsicMap, valToResPropertiesMap);
+  bool bWaveEnabledStage = m_pHLModule->GetShaderModel()->IsPS() ||
+                           m_pHLModule->GetShaderModel()->IsCS() ||
+                           m_pHLModule->GetShaderModel()->IsLib();
+  if (CGM.getCodeGenOpts().HLSLStructurizeReturns)
+    StructurizeMultiRet(M, m_ScopeMap, bWaveEnabledStage, m_DxBreaks);
 
   FinishEntries(HLM, Entry, CGM, entryFunctionMap, HSEntryPatchConstantFuncAttr,
                 patchConstantFunctionMap, patchConstantFunctionPropsMap);
@@ -5655,6 +5670,49 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
   }
 }
 
+ScopeInfo *CGMSHLSLRuntime::GetScopeInfo(Function *F) {
+  auto it = m_ScopeMap.find(F);
+  if (it == m_ScopeMap.end())
+    return nullptr;
+  return &it->second;
+}
+
+void CGMSHLSLRuntime::MarkThenStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) {
+  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
+    Scope->AddThen(endIfBB);
+}
+
+void CGMSHLSLRuntime::MarkElseStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) {
+  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
+    Scope->AddElse(endIfBB);
+}
+
+void CGMSHLSLRuntime::MarkSwitchStmt(CodeGenFunction &CGF,
+                                     SwitchInst *switchInst,
+                                     BasicBlock *endSwitch) {
+
+  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
+    Scope->AddSwitch(endSwitch);
+}
+
+void CGMSHLSLRuntime::MarkReturnStmt(CodeGenFunction &CGF,
+                                     BasicBlock *bbWithRet) {
+  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
+    Scope->AddRet(bbWithRet);
+}
+
+void CGMSHLSLRuntime::MarkLoopStmt(CodeGenFunction &CGF,
+                                   BasicBlock *loopContinue,
+                                   BasicBlock *loopExit) {
+  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
+    Scope->AddLoop(loopContinue, loopExit);
+}
+
+void CGMSHLSLRuntime::MarkScopeEnd(CodeGenFunction &CGF) {
+  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
+    Scope->EndScope();
+}
+
 CGHLSLRuntime *CodeGen::CreateMSHLSLRuntime(CodeGenModule &CGM) {
   return new CGMSHLSLRuntime(CGM);
 }

+ 207 - 3
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -21,6 +21,7 @@
 #include "llvm/Analysis/DxilValueCache.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/IR/CFG.h"
 
 #include "CodeGenModule.h"
 #include "clang/Frontend/CodeGenOptions.h"
@@ -2214,9 +2215,9 @@ void ProcessCtorFunctions(llvm::Module &M, StringRef globalName,
   GV->eraseFromParent();
 }
 
-void FinishCBuffer(HLModule &HLM, llvm::Type *CBufferType,
-                   std::unordered_map<Constant *, DxilFieldAnnotation>
-                       &constVarAnnotationMap) {
+void FinishCBuffer(
+    HLModule &HLM, llvm::Type *CBufferType,
+    std::unordered_map<Constant *, DxilFieldAnnotation> &constVarAnnotationMap) {
   // Allocate constant buffers.
   AllocateDxilConstantBuffers(HLM, constVarAnnotationMap);
   // TODO: create temp variable for constant which has store use.
@@ -2661,3 +2662,206 @@ void AddDxBreak(Module &M, const SmallVector<llvm::BranchInst*, 16> &DxBreaks) {
 }
 
 }
+
+namespace CGHLSLMSHelper {
+ScopeInfo::ScopeInfo(Function *F) {
+  Scope FuncScope;
+  FuncScope.kind = Scope::ScopeKind::FunctionScope;
+  FuncScope.EndScopeBB = nullptr;
+  scopes.emplace_back(FuncScope);
+  scopeStack.emplace_back(0);
+}
+
+void ScopeInfo::AddScope(Scope::ScopeKind k, BasicBlock *endScopeBB) {
+  Scope Scope;
+  Scope.kind = k;
+  Scope.EndScopeBB = endScopeBB;
+  Scope.parentScopeIndex = scopeStack.back();
+  scopeStack.emplace_back(scopes.size());
+  scopes.emplace_back(Scope);
+}
+
+void ScopeInfo::AddThen(BasicBlock *endIfBB) {
+  AddScope(Scope::ScopeKind::ThenScope, endIfBB);
+}
+
+void ScopeInfo::AddElse(BasicBlock *endIfBB) {
+  AddScope(Scope::ScopeKind::ElseScope, endIfBB);
+}
+
+void ScopeInfo::AddSwitch(BasicBlock *endSwitch) {
+  AddScope(Scope::ScopeKind::SwitchScope, endSwitch);
+}
+
+void ScopeInfo::AddLoop(BasicBlock *loopContinue, BasicBlock *endLoop) {
+  AddScope(Scope::ScopeKind::LoopScope, endLoop);
+  scopes.back().loopContinueBB = loopContinue;
+}
+
+void ScopeInfo::AddRet(BasicBlock *bbWithRet) {
+  Scope RetScope;
+  RetScope.kind = Scope::ScopeKind::ReturnScope;
+  RetScope.EndScopeBB = bbWithRet;
+  RetScope.parentScopeIndex = scopeStack.back();
+  // save retScope to rets.
+  rets.emplace_back(scopes.size());
+  scopes.emplace_back(RetScope);
+  // Don't need to put retScope to stack since it cannot nested other scopes.
+}
+
+void ScopeInfo::EndScope() { scopeStack.pop_back(); }
+
+Scope &ScopeInfo::GetScope(unsigned i) { return scopes[i]; }
+
+} // namespace CGHLSLMSHelper
+
+namespace {
+BasicBlock *createBlockBefore(BasicBlock *BB) {
+  BasicBlock *InsertBefore =
+      std::next(Function::iterator(BB)).getNodePtrUnchecked();
+  BasicBlock *New =
+      BasicBlock::Create(BB->getContext(), "", BB->getParent(), InsertBefore);
+  return New;
+}
+// For functions has multiple returns like
+// float foo(float a, float b, float c) {
+//   float r = c;
+//   if (a > 0) {
+//      if (b > 0) {
+//        return -1;
+//      }
+//      ***
+//   }
+//   ...
+//   return r;
+// }
+// transform into
+// float foo(float a, float b, float c) {
+//   bool bRet = false;
+//   float retV;
+//   float r = c;
+//   if (a > 0) {
+//      if (b > 0) {
+//        bRet = true;
+//        retV = -1;
+//      }
+//      if (!bRet) {
+//        ***
+//      }
+//   }
+//   if (!bRet) {
+//     ...
+//     retV = r;
+//   }
+//   return vRet;
+// }
+void StructurizeMultiRetFunction(
+    Function *F, ScopeInfo &ScopeInfo, bool bWaveEnabledStage,
+    SmallVector<BranchInst *, 16> &DxBreaks) {
+  // Get bbWithRets.
+  auto &rets = ScopeInfo.GetRetScopes();
+  if (rets.size() < 2)
+    return;
+
+  IRBuilder<> B(F->getEntryBlock().begin());
+  Type *boolTy = Type::getInt1Ty(F->getContext());
+
+  Scope &FunctionScope = ScopeInfo.GetScope(0);
+
+  // bool bIsReturned = false;
+  AllocaInst *bIsReturned = B.CreateAlloca(boolTy, nullptr, "bReturned");
+  B.CreateStore(ConstantInt::get(boolTy, 0), bIsReturned);
+  Constant *cTrue = ConstantInt::get(boolTy, 1);
+
+  for (unsigned scopeIndex : rets) {
+    Scope &retScope = ScopeInfo.GetScope(scopeIndex);
+    Scope &curScope = ScopeInfo.GetScope(retScope.parentScopeIndex);
+    // skip ret not in nested control flow.
+    if (curScope.kind == Scope::ScopeKind::FunctionScope)
+      continue;
+
+    BasicBlock *retBB = retScope.EndScopeBB;
+    Instruction *RetInst = retBB->getTerminator();
+    IRBuilder<> B(retBB->begin());
+    // bIsReturned = true;
+    B.CreateStore(cTrue, bIsReturned);
+
+    BranchInst *Br = cast<BranchInst>(RetInst);
+    if (!FunctionScope.EndScopeBB) {
+      FunctionScope.EndScopeBB = Br->getSuccessor(0);
+    }
+    // First level just branch to parentScope.
+    BasicBlock *endScopeBB = curScope.EndScopeBB;
+    Br->setSuccessor(0, endScopeBB);
+
+    // Guard parent scopes with bReturned until finish function scope.
+    Scope &parentScope = ScopeInfo.GetScope(curScope.parentScopeIndex);
+    while (true) {
+      BasicBlock *BB = curScope.EndScopeBB;
+      BasicBlock *EndBB = parentScope.EndScopeBB;
+      switch (parentScope.kind) {
+      case Scope::ScopeKind::FunctionScope:
+      case Scope::ScopeKind::ThenScope:
+      case Scope::ScopeKind::ElseScope: {
+        // inside if.
+        // if (!bReturned) {
+        //   rest of if or else.
+        // }
+        BasicBlock *CmpBB =
+            BasicBlock::Create(BB->getContext(), "bReturned.cmp.false", F, BB);
+        // Make BB preds go to cmpBB.
+        BB->replaceAllUsesWith(CmpBB);
+        IRBuilder<> B(CmpBB);
+        Value *isRetured = B.CreateLoad(bIsReturned, "bReturned.load");
+        Value *notRetunred = B.CreateNot(isRetured, "bReturned.not");
+        B.CreateCondBr(notRetunred, BB, EndBB);
+      } break;
+      default: {
+        // inside switch/loop
+        // if (bReturned) {
+        //   br endOfScope.
+        // }
+        BasicBlock *CmpBB =
+            BasicBlock::Create(BB->getContext(), "bReturned.cmp.true", F, BB);
+        BasicBlock *BreakBB =
+            BasicBlock::Create(BB->getContext(), "bReturned.break", F, BB);
+        BB->replaceAllUsesWith(CmpBB);
+        IRBuilder<> B(CmpBB);
+        Value *isRetured = B.CreateLoad(bIsReturned, "bReturned.load");
+        B.CreateCondBr(isRetured, BreakBB, BB);
+
+        B.SetInsertPoint(BreakBB);
+        if (bWaveEnabledStage &&
+            parentScope.kind == Scope::ScopeKind::LoopScope) {
+          BranchInst *BI =
+              B.CreateCondBr(cTrue, EndBB, parentScope.loopContinueBB);
+          DxBreaks.emplace_back(BI);
+        } else {
+          B.CreateBr(EndBB);
+        }
+      } break;
+      }
+
+      curScope = ScopeInfo.GetScope(curScope.parentScopeIndex);
+      // break after done with function scope.
+      if (curScope.kind == Scope::ScopeKind::FunctionScope)
+        break;
+      parentScope = ScopeInfo.GetScope(parentScope.parentScopeIndex);
+    }
+  }
+}
+} // namespace
+
+namespace CGHLSLMSHelper {
+void StructurizeMultiRet(Module &M, DenseMap<Function *, ScopeInfo> &ScopeMap,
+                         bool bWaveEnabledStage,
+                         SmallVector<BranchInst *, 16> &DxBreaks) {
+  for (Function &F : M) {
+    if (F.isDeclaration())
+      continue;
+    auto it = ScopeMap.find(&F);
+    DXASSERT(it != ScopeMap.end(), "cannot find scope info");
+    StructurizeMultiRetFunction(&F, it->second, bWaveEnabledStage, DxBreaks);
+  }
+}
+} // namespace CGHLSLMSHelper

+ 43 - 0
tools/clang/lib/CodeGen/CGHLSLMSHelper.h

@@ -26,6 +26,7 @@ class DebugLoc;
 class Constant;
 class GlobalVariable;
 class CallInst;
+class Instruction;
 template <typename T, unsigned N> class SmallVector;
 }
 
@@ -74,6 +75,43 @@ private:
       constants; // constants inside const buffer
 };
 
+// Scope to help transform multiple returns.
+struct Scope {
+ enum class ScopeKind {
+   ThenScope,
+   ElseScope,
+   SwitchScope,
+   LoopScope,
+   ReturnScope,
+   FunctionScope,
+ };
+ ScopeKind kind;
+ llvm::BasicBlock *EndScopeBB;
+ // Save loopContinueBB to create dxBreak.
+ llvm::BasicBlock *loopContinueBB;
+ unsigned parentScopeIndex;
+};
+
+class ScopeInfo {
+public:
+  ScopeInfo(){}
+  ScopeInfo(llvm::Function *F);
+  void AddThen(llvm::BasicBlock *endIfBB);
+  void AddElse(llvm::BasicBlock *endIfBB);
+  void AddSwitch(llvm::BasicBlock *endSwitchBB);
+  void AddLoop(llvm::BasicBlock *loopContinue, llvm::BasicBlock *endLoopBB);
+  void AddRet(llvm::BasicBlock *bbWithRet);
+  void EndScope();
+  Scope &GetScope(unsigned i);
+  const llvm::SmallVector<unsigned, 2> &GetRetScopes() { return rets; }
+private:
+  void AddScope(Scope::ScopeKind k, llvm::BasicBlock *endScopeBB);
+  llvm::SmallVector<unsigned, 2> rets;
+  llvm::SmallVector<unsigned, 8> scopeStack;
+  // save all scopes.
+  llvm::SmallVector<Scope, 16> scopes;
+};
+
 // Align cbuffer offset in legacy mode (16 bytes per row).
 unsigned AlignBufferOffsetInLegacy(unsigned offset, unsigned size,
                                    unsigned scalarSizeInBytes,
@@ -128,6 +166,11 @@ void UpdateLinkage(
     llvm::StringMap<EntryFunctionInfo> &entryFunctionMap,
     llvm::StringMap<PatchConstantInfo> &patchConstantFunctionMap);
 
+void StructurizeMultiRet(llvm::Module &M,
+                         llvm::DenseMap<llvm::Function *, ScopeInfo> &ScopeMap,
+                         bool bWaveEnabledStage,
+                         llvm::SmallVector<llvm::BranchInst *, 16> &DxBreaks);
+
 llvm::Value *TryEvalIntrinsic(llvm::CallInst *CI, hlsl::IntrinsicOp intriOp);
 void SimpleTransformForHLDXIR(llvm::Module *pM);
 void ExtensionCodeGen(hlsl::HLModule &HLM, clang::CodeGen::CodeGenModule &CGM);

+ 16 - 1
tools/clang/lib/CodeGen/CGHLSLRuntime.h

@@ -23,6 +23,7 @@ class GlobalVariable;
 class Type;
 class BasicBlock;
 class BranchInst;
+class SwitchInst;
 template <typename T> class ArrayRef;
 }
 
@@ -127,7 +128,21 @@ public:
   
   virtual void AddControlFlowHint(CodeGenFunction &CGF, const Stmt &S, llvm::TerminatorInst *TI, llvm::ArrayRef<const Attr *> Attrs) = 0;
 
-  virtual void FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D, llvm::Value *V) = 0;
+  virtual void FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D,
+                             llvm::Value *V) = 0;
+  virtual void MarkThenStmt(CodeGenFunction &CGF,
+                            llvm::BasicBlock *endIfBB) = 0;
+  virtual void MarkElseStmt(CodeGenFunction &CGF,
+                            llvm::BasicBlock *endIfBB) = 0;
+  virtual void MarkSwitchStmt(CodeGenFunction &CGF,
+                              llvm::SwitchInst *switchInst,
+                              llvm::BasicBlock *endSwitch) = 0;
+  virtual void MarkReturnStmt(CodeGenFunction &CGF,
+                              llvm::BasicBlock *bbWithRet) = 0;
+  virtual void MarkLoopStmt(CodeGenFunction &CGF,
+                             llvm::BasicBlock *loopContinue,
+                             llvm::BasicBlock *loopExit) = 0;
+  virtual void MarkScopeEnd(CodeGenFunction &CGF) = 0;
 };
 
 /// Create an instance of a HLSL runtime class.

+ 40 - 3
tools/clang/lib/CodeGen/CGStmt.cpp

@@ -600,6 +600,7 @@ void CodeGenFunction::EmitIfStmt(const IfStmt &S,
   llvm::TerminatorInst *TI =
       cast<llvm::TerminatorInst>(*ThenBlock->user_begin());
   CGM.getHLSLRuntime().AddControlFlowHint(*this, S, TI, Attrs);
+  CGM.getHLSLRuntime().MarkThenStmt(*this, ContBlock);
   // HLSL Change Ends
 
   // Emit the 'then' code.
@@ -611,8 +612,15 @@ void CodeGenFunction::EmitIfStmt(const IfStmt &S,
   }
   EmitBranch(ContBlock);
 
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
+
   // Emit the 'else' code if present.
   if (const Stmt *Else = S.getElse()) {
+    // HLSL Change Begin.
+    CGM.getHLSLRuntime().MarkElseStmt(*this, ContBlock);
+    // HLSL Change End.
     {
       // There is no need to emit line number for an unconditional branch.
       auto NL = ApplyDebugLocation::CreateEmpty(*this);
@@ -627,8 +635,11 @@ void CodeGenFunction::EmitIfStmt(const IfStmt &S,
       auto NL = ApplyDebugLocation::CreateEmpty(*this);
       EmitBranch(ContBlock);
     }
-  }
 
+    // HLSL Change Begin.
+    CGM.getHLSLRuntime().MarkScopeEnd(*this);
+    // HLSL Change End.
+  }
   // Emit the continuation block for code after the if.
   EmitBlock(ContBlock, true);
 }
@@ -748,6 +759,11 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // Store the blocks to use for break and continue.
   BreakContinueStack.push_back(BreakContinue(LoopExit, LoopHeader));
 
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkLoopStmt(*this, LoopHeader.getBlock(),
+                                     LoopExit.getBlock());
+  // HLSL Change End.
+
   // C++ [stmt.while]p2:
   //   When the condition of a while statement is a declaration, the
   //   scope of the variable that is declared extends from its point
@@ -818,6 +834,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // a branch, try to erase it.
   if (!EmitBoolCondBranch)
     SimplifyForwardingBlocks(LoopHeader.getBlock());
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
 }
 
 void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -829,6 +848,10 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
 
   // Store the blocks to use for break and continue.
   BreakContinueStack.push_back(BreakContinue(LoopExit, LoopCond));
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkLoopStmt(*this, LoopCond.getBlock(),
+                                  LoopExit.getBlock());
+  // HLSL Change End.
 
   // Emit the body of the loop.
   llvm::BasicBlock *LoopBody = createBasicBlock("do.body");
@@ -880,6 +903,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   // emitting a branch, try to erase it.
   if (!EmitBoolCondBranch)
     SimplifyForwardingBlocks(LoopCond.getBlock());
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
 }
 
 void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -891,7 +917,6 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   // Evaluate the first part before the loop.
   if (S.getInit())
     EmitStmt(S.getInit());
-
   // Start the loop with a block that tests the condition.
   // If there's an increment, the continue scope will be overwritten
   // later.
@@ -911,6 +936,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   // Store the blocks to use for break and continue.
   BreakContinueStack.push_back(BreakContinue(LoopExit, Continue));
 
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkLoopStmt(*this, Continue.getBlock(), LoopExit.getBlock());
+  // HLSL Change End.
   // Create a cleanup scope for the condition variable cleanups.
   LexicalScope ConditionScope(*this, S.getSourceRange());
 
@@ -978,6 +1006,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
 
   // Emit the fall-through block.
   EmitBlock(LoopExit.getBlock(), true);
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
 }
 
 void
@@ -1147,8 +1178,10 @@ void CodeGenFunction::EmitReturnStmt(const ReturnStmt &S) {
   ++NumReturnExprs;
   if (!RV || RV->isEvaluatable(getContext()))
     ++NumSimpleReturnExprs;
-
   cleanupScope.ForceCleanup();
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkReturnStmt(*this, Builder.GetInsertBlock());
+  // HLSL Change End.
   EmitBranchThroughCleanup(ReturnBlock);
 }
 
@@ -1637,6 +1670,7 @@ void CodeGenFunction::EmitSwitchStmt(const SwitchStmt &S,
   // HLSL Change Begins
   llvm::TerminatorInst *TI = cast<llvm::TerminatorInst>(SwitchInsn);
   CGM.getHLSLRuntime().AddControlFlowHint(*this, S, TI, Attrs);
+  CGM.getHLSLRuntime().MarkSwitchStmt(*this, SwitchInsn, SwitchExit.getBlock());
   // HLSL Change Ends
 
   if (PGO.haveRegionCounts()) {
@@ -1710,6 +1744,9 @@ void CodeGenFunction::EmitSwitchStmt(const SwitchStmt &S,
   SwitchInsn = SavedSwitchInsn;
   SwitchWeights = SavedSwitchWeights;
   CaseRangeBlock = SavedCRBlock;
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
 }
 
 static std::string

+ 131 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/return/multi_ret.hlsl

@@ -0,0 +1,131 @@
+// RUN: %dxc -E main -fcgl -structurize-returns -T ps_6_0 %s | FileCheck %s
+
+int i;
+
+float main(float4 a:A) : SV_Target {
+// Init bReturned.
+// CHECK:%[[bReturned:.*]] = alloca i1
+// CHECK:store i1 false, i1* %[[bReturned]]
+
+  float c = 0;
+
+// CHECK: [[label:.*]] ; preds =
+  if (i < 0) {
+// CHECK: [[label2:.*]] ; preds =
+    if (a.w > 2)
+// return inside if.
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+      return -1;
+// guard rest of scope with !bReturned
+// CHECK: [[label_bRet_cmp_false:.*]] ; preds =
+// CHECK:%[[RET:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET:.*]] = xor i1 %[[RET]], true
+// CHECK:br i1 %[[NRET]],
+
+// CHECK: [[label3:.*]]  ; preds =
+    c += sin(a.z);
+  }
+  else {
+// CHECK: [[else:.*]] ; preds =
+    if (a.z > 3)
+// return inside else
+// set bIsReturn to true
+// CHECK:store i1 true, i1* %[[bReturned]]
+      return -5;
+// CHECK: [[label_bRet_cmp_false2:.*]] ; preds =
+// CHECK:%[[RET2:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET2:.*]] = xor i1 %[[RET2]], true
+// CHECK:br i1 %[[NRET2]],
+
+// CHECK: [[label4:.*]] ; preds =
+    c *= cos(a.w);
+// guard after endif.
+// CHECK: [[label_bRet_cmp_false3:.*]] ; preds =
+// CHECK:%[[RET3:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET3:.*]] = xor i1 %[[RET3]], true
+// CHECK:br i1 %[[NRET3]],
+// guard after endif for else.
+// CHECK: [[label_bRet_cmp_false4:.*]] ; preds =
+// CHECK:%[[RET4:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET4:.*]] = xor i1 %[[RET4]], true
+// CHECK:br i1 %[[NRET4]],
+  }
+// CHECK: [[endif:.*]] ; preds =
+
+// CHECK: [[forCond:.*]]; preds =
+
+
+// CHECK: [[forBody:.*]] ; preds =
+  for (int j=0;j<i;j++) {
+    c += pow(2,j);
+// CHECK: [[if_in_loop:.*]] ; preds =
+    if (c > 10)
+// set bIsReturn to true
+// CHECK:store i1 true, i1* %[[bReturned]]
+      return -2;
+
+// CHECK: [[label_bRet_cmp_true:.*]] ; preds =
+// CHECK:%[[RET5:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:br i1 %[[RET5]],
+// CHECK: [[label_bRet_break:.*]] ; preds =
+// dxBreak
+// CHECK:br i1 true,
+// CHECK: [[endif_in_loop:.*]] ; preds =
+
+// CHECK: [[for_inc:.*]] ; preds =
+
+// Guard after loop.
+// CHECK: [[label_bRet_cmp_false5:.*]] ; preds =
+// CHECK:%[[RET6:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET6:.*]] = xor i1 %[[RET6]], true
+// CHECK:br i1 %[[NRET6]],
+
+  }
+// CHECK: [[for_end:.*]] ; preds =
+// CHECK:switch i32
+  switch (i) {
+// CHECK: [[case1:.*]] ; preds =
+    case 1:
+     c += log(a.x);
+     break;
+
+// CHECK: [[case2:.*]] ; preds =
+    case 2:
+
+        c += cos(a.y);
+     break;
+
+// CHECK: [[case3:.*]] ; preds =
+    case 3:
+
+// CHECK: [[if_in_switch:.*]]  ; preds =
+         if (c < 10)
+// set bIsReturn to true
+// CHECK:store i1 true, i1* %[[bReturned]]
+         return -3;
+// CHECK: [[label_bRet_cmp_true2:.*]] ; preds =
+// CHECK:%[[RET7:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:br i1 %[[RET7]],
+// CHECK: [[label_bRet_break2:.*]] ; preds =
+// normal break
+// CHECK:br label
+
+// CHECK: [[endif_in_switch:.*]] ; preds =
+       c += sin(a.x);
+     break;
+  }
+// guard code after switch.
+// CHECK: [[label_bRet_cmp_false6:.*]] ; preds =
+// CHECK:%[[RET8:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET8:.*]] = xor i1 %[[RET8]], true
+// CHECK:br i1 %[[NRET8]]
+
+// CHECK: [[end_switch:.*]]; preds =
+
+// CHECK: [[return:.*]] ; preds =
+// CHECK-NOT:preds
+// CHECK:ret
+
+  return c;
+}

+ 1 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -1096,6 +1096,7 @@ public:
       compiler.getCodeGenOpts().HLSLFloat32DenormMode = DXIL::Float32DenormMode::Preserve;
     }
 
+    compiler.getCodeGenOpts().HLSLStructurizeReturns = Opts.StructurizeReturns;
     if (Opts.DisableOptimizations)
       compiler.getCodeGenOpts().DisableLLVMOpts = true;