Browse Source

Support case whole scope is return. (#2971)

0. Support case whole scope is return.
1. Update EndScopeBB when 2 scopes share same EndScopeBB.
2. Initialize return value to 0 when possible to avoid undef phi.
3. One scope only need to guard once.
4. For return inside loop/switch, just break, don't go thru nested levels.
5. Generate dxBreak when ret is in loop directly.
Xiang Li 5 years ago
parent
commit
de743930c2

+ 8 - 11
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -285,8 +285,7 @@ public:
   void MarkRetTemp(CodeGenFunction &CGF, llvm::Value *V,
                   clang::QualType QaulTy) override;
   void FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D, llvm::Value *V) override;
-  void MarkThenStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) override;
-  void MarkElseStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) override;
+  void MarkIfStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) override;
   void MarkSwitchStmt(CodeGenFunction &CGF, SwitchInst *switchInst,
                       BasicBlock *endSwitch) override;
   void MarkReturnStmt(CodeGenFunction &CGF, BasicBlock *bbWithRet) override;
@@ -5677,20 +5676,15 @@ ScopeInfo *CGMSHLSLRuntime::GetScopeInfo(Function *F) {
   return &it->second;
 }
 
-void CGMSHLSLRuntime::MarkThenStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) {
+void CGMSHLSLRuntime::MarkIfStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) {
   if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
-    Scope->AddThen(endIfBB);
+    Scope->AddIf(endIfBB);
 }
 
-void CGMSHLSLRuntime::MarkElseStmt(CodeGenFunction &CGF, BasicBlock *endIfBB) {
-  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
-    Scope->AddElse(endIfBB);
-}
 
 void CGMSHLSLRuntime::MarkSwitchStmt(CodeGenFunction &CGF,
                                      SwitchInst *switchInst,
                                      BasicBlock *endSwitch) {
-
   if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
     Scope->AddSwitch(endSwitch);
 }
@@ -5709,8 +5703,11 @@ void CGMSHLSLRuntime::MarkLoopStmt(CodeGenFunction &CGF,
 }
 
 void CGMSHLSLRuntime::MarkScopeEnd(CodeGenFunction &CGF) {
-  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn))
-    Scope->EndScope();
+  if (ScopeInfo *Scope = GetScopeInfo(CGF.CurFn)) {
+    llvm::BasicBlock *CurBB = CGF.Builder.GetInsertBlock();
+    bool bScopeFinishedWithRet = !CurBB || CurBB->getTerminator();
+    Scope->EndScope(bScopeFinishedWithRet);
+  }
 }
 
 CGHLSLRuntime *CodeGen::CreateMSHLSLRuntime(CodeGenModule &CGM) {

+ 231 - 92
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -2664,29 +2664,37 @@ void AddDxBreak(Module &M, const SmallVector<llvm::BranchInst*, 16> &DxBreaks) {
 }
 
 namespace CGHLSLMSHelper {
-ScopeInfo::ScopeInfo(Function *F) {
+
+ScopeInfo::ScopeInfo(Function *F) : maxRetLevel(0), bAllReturnsInIf(true) {
   Scope FuncScope;
   FuncScope.kind = Scope::ScopeKind::FunctionScope;
   FuncScope.EndScopeBB = nullptr;
+  FuncScope.bWholeScopeReturned = false;
+  // Make it 0 to avoid check when get parent.
+  // All loop on scopes should check kind != FunctionScope.
+  FuncScope.parentScopeIndex = 0;
   scopes.emplace_back(FuncScope);
   scopeStack.emplace_back(0);
 }
 
+// True when no return sits under a loop/switch (bAllReturnsInIf) and every
+// return has nesting level < 2; such flow is already structured, so skip it.
+bool ScopeInfo::CanSkipStructurize() {
+  return bAllReturnsInIf && maxRetLevel < 2;
+}
+
 void ScopeInfo::AddScope(Scope::ScopeKind k, BasicBlock *endScopeBB) {
   Scope Scope;
   Scope.kind = k;
+  Scope.bWholeScopeReturned = false;
   Scope.EndScopeBB = endScopeBB;
   Scope.parentScopeIndex = scopeStack.back();
   scopeStack.emplace_back(scopes.size());
   scopes.emplace_back(Scope);
 }
 
-void ScopeInfo::AddThen(BasicBlock *endIfBB) {
-  AddScope(Scope::ScopeKind::ThenScope, endIfBB);
-}
-
-void ScopeInfo::AddElse(BasicBlock *endIfBB) {
-  AddScope(Scope::ScopeKind::ElseScope, endIfBB);
+void ScopeInfo::AddIf(BasicBlock *endIfBB) {
+  AddScope(Scope::ScopeKind::IfScope, endIfBB);
 }
 
 void ScopeInfo::AddSwitch(BasicBlock *endSwitch) {
@@ -2703,26 +2711,108 @@ void ScopeInfo::AddRet(BasicBlock *bbWithRet) {
   RetScope.kind = Scope::ScopeKind::ReturnScope;
   RetScope.EndScopeBB = bbWithRet;
   RetScope.parentScopeIndex = scopeStack.back();
+  // - 1 for function scope which is at scopeStack[0].
+  unsigned retLevel = scopeStack.size() - 1;
+  // save max nested level for ret.
+  maxRetLevel = std::max<unsigned>(maxRetLevel, retLevel);
+  bool bGotLoopOrSwitch = false;
+  for (auto it = scopeStack.rbegin(); it != scopeStack.rend(); it++) {
+    unsigned idx = *it;
+    Scope &S = scopes[idx];
+    switch (S.kind) {
+    default:
+      break;
+    case Scope::ScopeKind::LoopScope:
+    case Scope::ScopeKind::SwitchScope:
+      bGotLoopOrSwitch = true;
+      // For return inside loop and switch, can just break.
+      RetScope.parentScopeIndex = idx;
+      break;
+    }
+    if (bGotLoopOrSwitch)
+      break;
+  }
+  bAllReturnsInIf &= !bGotLoopOrSwitch;
+  // return finish current scope.
+  RetScope.bWholeScopeReturned = true;
   // save retScope to rets.
   rets.emplace_back(scopes.size());
   scopes.emplace_back(RetScope);
   // Don't need to put retScope to stack since it cannot nested other scopes.
 }
 
-void ScopeInfo::EndScope() { scopeStack.pop_back(); }
+void ScopeInfo::EndScope(bool bScopeFinishedWithRet) {
+  unsigned idx = scopeStack.pop_back_val(); // index of the scope being closed
+  Scope &Scope = GetScope(idx);
+  // The scope counts as wholly returned only if the caller reports it ended
+  // with a return AND nothing branches to its end block (it has no users).
+  Scope.bWholeScopeReturned =
+      bScopeFinishedWithRet && Scope.EndScopeBB->user_empty();
+}
 
 Scope &ScopeInfo::GetScope(unsigned i) { return scopes[i]; }
 
+void ScopeInfo::LegalizeWholeReturnedScope() {
+  // Fix up the scopes that are marked as wholly returned.
+  // When whole scope is returned, the endScopeBB will be deleted in codeGen,
+  // so retarget EndScopeBB to the parent scope's end block instead.
+  // The scopes are visited in order, so the update automatically reaches the
+  // final target: A->B->C will just get A->C.
+  for (auto &S : scopes) {
+    if (S.bWholeScopeReturned && S.kind != Scope::ScopeKind::ReturnScope) {
+      S.EndScopeBB = scopes[S.parentScopeIndex].EndScopeBB;
+    }
+  }
+}
+
 } // namespace CGHLSLMSHelper
 
 namespace {
-BasicBlock *createBlockBefore(BasicBlock *BB) {
-  BasicBlock *InsertBefore =
-      std::next(Function::iterator(BB)).getNodePtrUnchecked();
-  BasicBlock *New =
-      BasicBlock::Create(BB->getContext(), "", BB->getParent(), InsertBefore);
-  return New;
+
+// Retargets every non-return scope whose end block is oldEndScope so that it
+// ends at newEndScope, and registers those scopes under newEndScope in
+// EndBBToScopeIndexMap. Nothing is changed when oldEndScope is the end block
+// of fewer than two scopes.
+void updateEndScope(
+    ScopeInfo &ScopeInfo,
+    DenseMap<BasicBlock *, SmallVector<unsigned, 2>> &EndBBToScopeIndexMap,
+    BasicBlock *oldEndScope, BasicBlock *newEndScope) {
+  auto it = EndBBToScopeIndexMap.find(oldEndScope);
+  DXASSERT(it != EndBBToScopeIndexMap.end(),
+           "fail to find endScopeBB in EndBBToScopeIndexMap");
+  // Take a copy instead of a reference: inserting newEndScope below may grow
+  // the DenseMap, which would leave a reference into it->second dangling
+  // before the assignment reads it.
+  const SmallVector<unsigned, 2> scopeList = it->second;
+  // Don't need to update when not share endBB with other scope.
+  if (scopeList.size() < 2)
+    return;
+  for (unsigned i : scopeList) {
+    Scope &S = ScopeInfo.GetScope(i);
+    // Don't update return endBB, because that is the Block has return branch.
+    if (S.kind != Scope::ScopeKind::ReturnScope)
+      S.EndScopeBB = newEndScope;
+  }
+  EndBBToScopeIndexMap[newEndScope] = scopeList;
+}
+
+void InitRetValue(BasicBlock *exitBB) {
+  Value *RetValPtr = nullptr; // pointer the returned value is loaded from
+  if (ReturnInst *RI = dyn_cast<ReturnInst>(exitBB->getTerminator())) {
+    if (Value *RetV = RI->getReturnValue()) {
+      if (LoadInst *LI = dyn_cast<LoadInst>(RetV)) {
+        RetValPtr = LI->getPointerOperand();
+      }
+    }
+  }
+  if (!RetValPtr)
+    return;
+  if (AllocaInst *RetVAlloc = dyn_cast<AllocaInst>(RetValPtr)) {
+    IRBuilder<> B(RetVAlloc->getNextNode()); // insert right after the alloca
+    Type *Ty = RetVAlloc->getAllocatedType();
+    Value *Init = Constant::getNullValue(Ty);
+    if (Ty->isAggregateType()) {
+      // TODO: support aggregate type and out parameters.
+      // Skipping leaves undef on a phi whose incoming path is never hit.
+    } else {
+      B.CreateStore(Init, RetVAlloc);
+    }
+  }
 }
+
 // For functions has multiple returns like
 // float foo(float a, float b, float c) {
 //   float r = c;
@@ -2755,99 +2845,147 @@ BasicBlock *createBlockBefore(BasicBlock *BB) {
 //   }
 //   return vRet;
 // }
-void StructurizeMultiRetFunction(
-    Function *F, ScopeInfo &ScopeInfo, bool bWaveEnabledStage,
-    SmallVector<BranchInst *, 16> &DxBreaks) {
+void StructurizeMultiRetFunction(Function *F, ScopeInfo &ScopeInfo,
+                                 bool bWaveEnabledStage,
+                                 SmallVector<BranchInst *, 16> &DxBreaks) {
+  if (ScopeInfo.CanSkipStructurize())
+    return;
   // Get bbWithRets.
   auto &rets = ScopeInfo.GetRetScopes();
-  if (rets.size() < 2)
-    return;
 
   IRBuilder<> B(F->getEntryBlock().begin());
-  Type *boolTy = Type::getInt1Ty(F->getContext());
 
   Scope &FunctionScope = ScopeInfo.GetScope(0);
 
+  Type *boolTy = Type::getInt1Ty(F->getContext());
+  Constant *cTrue = ConstantInt::get(boolTy, 1);
+  Constant *cFalse = ConstantInt::get(boolTy, 0);
   // bool bIsReturned = false;
   AllocaInst *bIsReturned = B.CreateAlloca(boolTy, nullptr, "bReturned");
-  B.CreateStore(ConstantInt::get(boolTy, 0), bIsReturned);
-  Constant *cTrue = ConstantInt::get(boolTy, 1);
+  B.CreateStore(cFalse, bIsReturned);
 
-  for (unsigned scopeIndex : rets) {
-    Scope &retScope = ScopeInfo.GetScope(scopeIndex);
-    Scope &curScope = ScopeInfo.GetScope(retScope.parentScopeIndex);
+  Scope &RetScope = ScopeInfo.GetScope(rets[0]);
+  BasicBlock *exitBB = RetScope.EndScopeBB->getTerminator()->getSuccessor(0);
+  FunctionScope.EndScopeBB = exitBB;
+  // Find the alloca for the return value and initialize it to avoid undef
+  // after the guard code with bIsReturned.
+  InitRetValue(exitBB);
+
+  ScopeInfo.LegalizeWholeReturnedScope();
+
+  // Map from endScopeBB to scope index.
+  // When 2 scopes share same endScopeBB, need to update endScopeBB after
+  // structurize.
+  DenseMap<BasicBlock *, SmallVector<unsigned, 2>> EndBBToScopeIndexMap;
+  auto &scopes = ScopeInfo.GetScopes();
+  for (unsigned i = 0; i < scopes.size(); i++) {
+    Scope &S = scopes[i];
+    EndBBToScopeIndexMap[S.EndScopeBB].emplace_back(i);
+  }
+
+  DenseSet<unsigned> guardedSet;
+
+  for (auto it = rets.begin(); it != rets.end(); it++) {
+    unsigned scopeIndex = *it;
+    Scope *pCurScope = &ScopeInfo.GetScope(scopeIndex);
+    Scope *pRetParentScope = &ScopeInfo.GetScope(pCurScope->parentScopeIndex);
     // skip ret not in nested control flow.
-    if (curScope.kind == Scope::ScopeKind::FunctionScope)
+    if (pRetParentScope->kind == Scope::ScopeKind::FunctionScope)
       continue;
 
-    BasicBlock *retBB = retScope.EndScopeBB;
-    Instruction *RetInst = retBB->getTerminator();
-    IRBuilder<> B(retBB->begin());
-    // bIsReturned = true;
-    B.CreateStore(cTrue, bIsReturned);
-
-    BranchInst *Br = cast<BranchInst>(RetInst);
-    if (!FunctionScope.EndScopeBB) {
-      FunctionScope.EndScopeBB = Br->getSuccessor(0);
-    }
-    // First level just branch to parentScope.
-    BasicBlock *endScopeBB = curScope.EndScopeBB;
-    Br->setSuccessor(0, endScopeBB);
-
-    // Guard parent scopes with bReturned until finish function scope.
-    Scope &parentScope = ScopeInfo.GetScope(curScope.parentScopeIndex);
-    while (true) {
-      BasicBlock *BB = curScope.EndScopeBB;
-      BasicBlock *EndBB = parentScope.EndScopeBB;
-      switch (parentScope.kind) {
-      case Scope::ScopeKind::FunctionScope:
-      case Scope::ScopeKind::ThenScope:
-      case Scope::ScopeKind::ElseScope: {
-        // inside if.
-        // if (!bReturned) {
-        //   rest of if or else.
-        // }
-        BasicBlock *CmpBB =
-            BasicBlock::Create(BB->getContext(), "bReturned.cmp.false", F, BB);
-        // Make BB preds go to cmpBB.
-        BB->replaceAllUsesWith(CmpBB);
-        IRBuilder<> B(CmpBB);
-        Value *isRetured = B.CreateLoad(bIsReturned, "bReturned.load");
-        Value *notRetunred = B.CreateNot(isRetured, "bReturned.not");
-        B.CreateCondBr(notRetunred, BB, EndBB);
-      } break;
-      default: {
-        // inside switch/loop
-        // if (bReturned) {
-        //   br endOfScope.
-        // }
-        BasicBlock *CmpBB =
-            BasicBlock::Create(BB->getContext(), "bReturned.cmp.true", F, BB);
-        BasicBlock *BreakBB =
-            BasicBlock::Create(BB->getContext(), "bReturned.break", F, BB);
-        BB->replaceAllUsesWith(CmpBB);
-        IRBuilder<> B(CmpBB);
-        Value *isRetured = B.CreateLoad(bIsReturned, "bReturned.load");
-        B.CreateCondBr(isRetured, BreakBB, BB);
-
-        B.SetInsertPoint(BreakBB);
-        if (bWaveEnabledStage &&
-            parentScope.kind == Scope::ScopeKind::LoopScope) {
-          BranchInst *BI =
-              B.CreateCondBr(cTrue, EndBB, parentScope.loopContinueBB);
-          DxBreaks.emplace_back(BI);
-        } else {
-          B.CreateBr(EndBB);
+    do {
+      BasicBlock *BB = pCurScope->EndScopeBB;
+      // exit when scope is processed.
+      if (guardedSet.count(scopeIndex))
+        break;
+      guardedSet.insert(scopeIndex);
+
+      Scope *pParentScope = &ScopeInfo.GetScope(pCurScope->parentScopeIndex);
+      BasicBlock *EndBB = pParentScope->EndScopeBB;
+      // When whole scope returned, just branch to endScope of parent.
+      if (pCurScope->bWholeScopeReturned) {
+        // For ret, just branch to endScope of parent.
+        if (pCurScope->kind == Scope::ScopeKind::ReturnScope) {
+          BasicBlock *retBB = pCurScope->EndScopeBB;
+          TerminatorInst *retBr = retBB->getTerminator();
+          IRBuilder<> B(retBr);
+          // Set bReturned to true.
+          B.CreateStore(cTrue, bIsReturned);
+          if (bWaveEnabledStage &&
+              pParentScope->kind == Scope::ScopeKind::LoopScope) {
+            BranchInst *BI =
+                B.CreateCondBr(cTrue, EndBB, pParentScope->loopContinueBB);
+            DxBreaks.emplace_back(BI);
+            retBr->eraseFromParent();
+          } else {
+            // Update branch target.
+            retBr->setSuccessor(0, EndBB);
+          }
+        }
+        // For other scope, do nothing. Since whole scope is returned.
+        // Just flow naturally to parent scope.
+      } else {
+        // When only part scope returned.
+        // Use bIsReturned to guard to part which not returned.
+        switch (pParentScope->kind) {
+        case Scope::ScopeKind::ReturnScope:
+          DXASSERT(0, "return scope must get whole scope returned.");
+          break;
+        case Scope::ScopeKind::FunctionScope:
+        case Scope::ScopeKind::IfScope: {
+          // inside if.
+          // if (!bReturned) {
+          //   rest of if or else.
+          // }
+          BasicBlock *CmpBB = BasicBlock::Create(BB->getContext(),
+                                                 "bReturned.cmp.false", F, BB);
+
+          // Make BB preds go to cmpBB.
+          BB->replaceAllUsesWith(CmpBB);
+          // Update endscopeBB to CmpBB for scopes which has BB as endscope.
+          updateEndScope(ScopeInfo, EndBBToScopeIndexMap, BB, CmpBB);
+
+          IRBuilder<> B(CmpBB);
+          Value *isRetured = B.CreateLoad(bIsReturned, "bReturned.load");
+          Value *notReturned =
+              B.CreateICmpNE(isRetured, cFalse, "bReturned.not");
+          B.CreateCondBr(notReturned, EndBB, BB);
+        } break;
+        default: {
+          // inside switch/loop
+          // if (bReturned) {
+          //   br endOfScope.
+          // }
+          BasicBlock *CmpBB =
+              BasicBlock::Create(BB->getContext(), "bReturned.cmp.true", F, BB);
+          BasicBlock *BreakBB =
+              BasicBlock::Create(BB->getContext(), "bReturned.break", F, BB);
+          BB->replaceAllUsesWith(CmpBB);
+          // Update endscopeBB to CmpBB for scopes which has BB as endscope.
+          updateEndScope(ScopeInfo, EndBBToScopeIndexMap, BB, CmpBB);
+
+          IRBuilder<> B(CmpBB);
+          Value *isReturned = B.CreateLoad(bIsReturned, "bReturned.load");
+          isReturned = B.CreateICmpEQ(isReturned, cTrue, "bReturned.true");
+          B.CreateCondBr(isReturned, BreakBB, BB);
+
+          B.SetInsertPoint(BreakBB);
+          if (bWaveEnabledStage &&
+              pParentScope->kind == Scope::ScopeKind::LoopScope) {
+            BranchInst *BI =
+                B.CreateCondBr(cTrue, EndBB, pParentScope->loopContinueBB);
+            DxBreaks.emplace_back(BI);
+          } else {
+            B.CreateBr(EndBB);
+          }
+        } break;
         }
-      } break;
       }
 
-      curScope = ScopeInfo.GetScope(curScope.parentScopeIndex);
-      // break after done with function scope.
-      if (curScope.kind == Scope::ScopeKind::FunctionScope)
-        break;
-      parentScope = ScopeInfo.GetScope(parentScope.parentScopeIndex);
-    }
+      scopeIndex = pCurScope->parentScopeIndex;
+      pCurScope = &ScopeInfo.GetScope(scopeIndex);
+      // done when reach function scope.
+    } while (pCurScope->kind != Scope::ScopeKind::FunctionScope);
   }
 }
 } // namespace
@@ -2860,7 +2998,8 @@ void StructurizeMultiRet(Module &M, DenseMap<Function *, ScopeInfo> &ScopeMap,
     if (F.isDeclaration())
       continue;
     auto it = ScopeMap.find(&F);
-    DXASSERT(it != ScopeMap.end(), "cannot find scope info");
+    if (it == ScopeMap.end())
+      continue;
     StructurizeMultiRetFunction(&F, it->second, bWaveEnabledStage, DxBreaks);
   }
 }

+ 22 - 5
tools/clang/lib/CodeGen/CGHLSLMSHelper.h

@@ -78,8 +78,7 @@ private:
 // Scope to help transform multiple returns.
 struct Scope {
  enum class ScopeKind {
-   ThenScope,
-   ElseScope,
+   IfScope,
    SwitchScope,
    LoopScope,
    ReturnScope,
@@ -89,6 +88,19 @@ struct Scope {
  llvm::BasicBlock *EndScopeBB;
  // Save loopContinueBB to create dxBreak.
  llvm::BasicBlock *loopContinueBB;
+ // For case like
+ // if () {
+ //   ...
+ //   return;
+ // } else {
+ //   ...
+ //   return;
+ // }
+ //
+ // both path is returned.
+ // When whole scope is returned, go to parent scope directly.
+ // Anything after it is unreachable.
+ bool bWholeScopeReturned;
  unsigned parentScopeIndex;
 };
 
@@ -96,17 +108,22 @@ class ScopeInfo {
 public:
   ScopeInfo(){}
   ScopeInfo(llvm::Function *F);
-  void AddThen(llvm::BasicBlock *endIfBB);
-  void AddElse(llvm::BasicBlock *endIfBB);
+  void AddIf(llvm::BasicBlock *endIfBB);
   void AddSwitch(llvm::BasicBlock *endSwitchBB);
   void AddLoop(llvm::BasicBlock *loopContinue, llvm::BasicBlock *endLoopBB);
   void AddRet(llvm::BasicBlock *bbWithRet);
-  void EndScope();
+  void EndScope(bool bScopeFinishedWithRet);
   Scope &GetScope(unsigned i);
   const llvm::SmallVector<unsigned, 2> &GetRetScopes() { return rets; }
+  void LegalizeWholeReturnedScope();
+  llvm::SmallVector<Scope, 16> &GetScopes() { return scopes; }
+  bool CanSkipStructurize();
+
 private:
   void AddScope(Scope::ScopeKind k, llvm::BasicBlock *endScopeBB);
   llvm::SmallVector<unsigned, 2> rets;
+  unsigned maxRetLevel;
+  bool bAllReturnsInIf;
   llvm::SmallVector<unsigned, 8> scopeStack;
   // save all scopes.
   llvm::SmallVector<Scope, 16> scopes;

+ 2 - 4
tools/clang/lib/CodeGen/CGHLSLRuntime.h

@@ -130,10 +130,7 @@ public:
 
   virtual void FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D,
                              llvm::Value *V) = 0;
-  virtual void MarkThenStmt(CodeGenFunction &CGF,
-                            llvm::BasicBlock *endIfBB) = 0;
-  virtual void MarkElseStmt(CodeGenFunction &CGF,
-                            llvm::BasicBlock *endIfBB) = 0;
+  virtual void MarkIfStmt(CodeGenFunction &CGF, llvm::BasicBlock *endIfBB) = 0;
   virtual void MarkSwitchStmt(CodeGenFunction &CGF,
                               llvm::SwitchInst *switchInst,
                               llvm::BasicBlock *endSwitch) = 0;
@@ -142,6 +139,7 @@ public:
   virtual void MarkLoopStmt(CodeGenFunction &CGF,
                              llvm::BasicBlock *loopContinue,
                              llvm::BasicBlock *loopExit) = 0;
+
   virtual void MarkScopeEnd(CodeGenFunction &CGF) = 0;
 };
 

+ 19 - 22
tools/clang/lib/CodeGen/CGStmt.cpp

@@ -600,7 +600,7 @@ void CodeGenFunction::EmitIfStmt(const IfStmt &S,
   llvm::TerminatorInst *TI =
       cast<llvm::TerminatorInst>(*ThenBlock->user_begin());
   CGM.getHLSLRuntime().AddControlFlowHint(*this, S, TI, Attrs);
-  CGM.getHLSLRuntime().MarkThenStmt(*this, ContBlock);
+  CGM.getHLSLRuntime().MarkIfStmt(*this, ContBlock);
   // HLSL Change Ends
 
   // Emit the 'then' code.
@@ -612,15 +612,8 @@ void CodeGenFunction::EmitIfStmt(const IfStmt &S,
   }
   EmitBranch(ContBlock);
 
-  // HLSL Change Begin.
-  CGM.getHLSLRuntime().MarkScopeEnd(*this);
-  // HLSL Change End.
-
   // Emit the 'else' code if present.
   if (const Stmt *Else = S.getElse()) {
-    // HLSL Change Begin.
-    CGM.getHLSLRuntime().MarkElseStmt(*this, ContBlock);
-    // HLSL Change End.
     {
       // There is no need to emit line number for an unconditional branch.
       auto NL = ApplyDebugLocation::CreateEmpty(*this);
@@ -636,10 +629,10 @@ void CodeGenFunction::EmitIfStmt(const IfStmt &S,
       EmitBranch(ContBlock);
     }
 
-    // HLSL Change Begin.
-    CGM.getHLSLRuntime().MarkScopeEnd(*this);
-    // HLSL Change End.
   }
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
   // Emit the continuation block for code after the if.
   EmitBlock(ContBlock, true);
 }
@@ -827,6 +820,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
 
   LoopStack.pop();
 
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
+
   // Emit the exit block.
   EmitBlock(LoopExit.getBlock(), true);
 
@@ -834,9 +831,6 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // a branch, try to erase it.
   if (!EmitBoolCondBranch)
     SimplifyForwardingBlocks(LoopHeader.getBlock());
-  // HLSL Change Begin.
-  CGM.getHLSLRuntime().MarkScopeEnd(*this);
-  // HLSL Change End.
 }
 
 void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -896,6 +890,10 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
 
   LoopStack.pop();
 
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
+
   // Emit the exit block.
   EmitBlock(LoopExit.getBlock());
 
@@ -903,9 +901,6 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   // emitting a branch, try to erase it.
   if (!EmitBoolCondBranch)
     SimplifyForwardingBlocks(LoopCond.getBlock());
-  // HLSL Change Begin.
-  CGM.getHLSLRuntime().MarkScopeEnd(*this);
-  // HLSL Change End.
 }
 
 void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1004,11 +999,12 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
 
   LoopStack.pop();
 
-  // Emit the fall-through block.
-  EmitBlock(LoopExit.getBlock(), true);
   // HLSL Change Begin.
   CGM.getHLSLRuntime().MarkScopeEnd(*this);
   // HLSL Change End.
+
+  // Emit the fall-through block.
+  EmitBlock(LoopExit.getBlock(), true);
 }
 
 void
@@ -1728,6 +1724,10 @@ void CodeGenFunction::EmitSwitchStmt(const SwitchStmt &S,
 
   ConditionScope.ForceCleanup();
 
+  // HLSL Change Begin.
+  CGM.getHLSLRuntime().MarkScopeEnd(*this);
+  // HLSL Change End.
+
   // Emit continuation.
   EmitBlock(SwitchExit.getBlock(), true);
   incrementProfileCounter(&S);
@@ -1744,9 +1744,6 @@ void CodeGenFunction::EmitSwitchStmt(const SwitchStmt &S,
   SwitchInsn = SavedSwitchInsn;
   SwitchWeights = SavedSwitchWeights;
   CaseRangeBlock = SavedCRBlock;
-  // HLSL Change Begin.
-  CGM.getHLSLRuntime().MarkScopeEnd(*this);
-  // HLSL Change End.
 }
 
 static std::string

+ 11 - 22
tools/clang/test/HLSLFileCheck/hlsl/control_flow/return/multi_ret.hlsl

@@ -6,6 +6,8 @@ float main(float4 a:A) : SV_Target {
 // Init bReturned.
 // CHECK:%[[bReturned:.*]] = alloca i1
 // CHECK:store i1 false, i1* %[[bReturned]]
+// Init retVal to 0.
+// CHECK:store float 0.000000e+00
 
   float c = 0;
 
@@ -20,7 +22,7 @@ float main(float4 a:A) : SV_Target {
 // guard rest of scope with !bReturned
 // CHECK: [[label_bRet_cmp_false:.*]] ; preds =
 // CHECK:%[[RET:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:%[[NRET:.*]] = xor i1 %[[RET]], true
+// CHECK:%[[NRET:.*]] = icmp ne i1 %[[RET]], false
 // CHECK:br i1 %[[NRET]],
 
 // CHECK: [[label3:.*]]  ; preds =
@@ -35,7 +37,7 @@ float main(float4 a:A) : SV_Target {
       return -5;
 // CHECK: [[label_bRet_cmp_false2:.*]] ; preds =
 // CHECK:%[[RET2:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:%[[NRET2:.*]] = xor i1 %[[RET2]], true
+// CHECK:%[[NRET2:.*]] = icmp ne i1 %[[RET2]], false
 // CHECK:br i1 %[[NRET2]],
 
 // CHECK: [[label4:.*]] ; preds =
@@ -43,13 +45,9 @@ float main(float4 a:A) : SV_Target {
 // guard after endif.
 // CHECK: [[label_bRet_cmp_false3:.*]] ; preds =
 // CHECK:%[[RET3:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:%[[NRET3:.*]] = xor i1 %[[RET3]], true
+// CHECK:%[[NRET3:.*]] = icmp ne i1 %[[RET3]], false
 // CHECK:br i1 %[[NRET3]],
-// guard after endif for else.
-// CHECK: [[label_bRet_cmp_false4:.*]] ; preds =
-// CHECK:%[[RET4:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:%[[NRET4:.*]] = xor i1 %[[RET4]], true
-// CHECK:br i1 %[[NRET4]],
+
   }
 // CHECK: [[endif:.*]] ; preds =
 
@@ -63,14 +61,10 @@ float main(float4 a:A) : SV_Target {
     if (c > 10)
 // set bIsReturn to true
 // CHECK:store i1 true, i1* %[[bReturned]]
+// dxBreak.
+// CHECK:br i1 true,
       return -2;
 
-// CHECK: [[label_bRet_cmp_true:.*]] ; preds =
-// CHECK:%[[RET5:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:br i1 %[[RET5]],
-// CHECK: [[label_bRet_break:.*]] ; preds =
-// dxBreak
-// CHECK:br i1 true,
 // CHECK: [[endif_in_loop:.*]] ; preds =
 
 // CHECK: [[for_inc:.*]] ; preds =
@@ -78,7 +72,7 @@ float main(float4 a:A) : SV_Target {
 // Guard after loop.
 // CHECK: [[label_bRet_cmp_false5:.*]] ; preds =
 // CHECK:%[[RET6:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:%[[NRET6:.*]] = xor i1 %[[RET6]], true
+// CHECK:%[[NRET6:.*]] = icmp ne i1 %[[RET6]], false
 // CHECK:br i1 %[[NRET6]],
 
   }
@@ -103,13 +97,8 @@ float main(float4 a:A) : SV_Target {
          if (c < 10)
 // set bIsReturn to true
 // CHECK:store i1 true, i1* %[[bReturned]]
+// return just change to branch out of switch.
          return -3;
-// CHECK: [[label_bRet_cmp_true2:.*]] ; preds =
-// CHECK:%[[RET7:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:br i1 %[[RET7]],
-// CHECK: [[label_bRet_break2:.*]] ; preds =
-// normal break
-// CHECK:br label
 
 // CHECK: [[endif_in_switch:.*]] ; preds =
        c += sin(a.x);
@@ -118,7 +107,7 @@ float main(float4 a:A) : SV_Target {
 // guard code after switch.
 // CHECK: [[label_bRet_cmp_false6:.*]] ; preds =
 // CHECK:%[[RET8:.*]] = load i1, i1* %[[bReturned]]
-// CHECK:%[[NRET8:.*]] = xor i1 %[[RET8]], true
+// CHECK:%[[NRET8:.*]] = icmp ne i1 %[[RET8]], false
 // CHECK:br i1 %[[NRET8]]
 
 // CHECK: [[end_switch:.*]]; preds =

+ 37 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/return/whole_scope_returned_if.hlsl

@@ -0,0 +1,37 @@
+// RUN: %dxc -E main -fcgl -structurize-returns -T ps_6_0 %s | FileCheck %s
+
+float main(float4 a:A) : SV_Target {
+// Init bReturned.
+// CHECK:%[[bReturned:.*]] = alloca i1
+// CHECK:store i1 false, i1* %[[bReturned]]
+
+// Init retVal to 0.
+// CHECK:store float 0.000000e+00
+
+// CHECK: [[label:.*]] ; preds =
+  if (a.w < 0) {
+// CHECK: [[label2:.*]] ; preds =
+   if (floor(a.x) > 1) {
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+     return sin(a.y);
+   } else {
+// CHECK: [[else:.*]] ; preds =
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+     float r = log(a.z);
+     return r;
+   }
+  }
+// guard rest of scope with !bReturned
+// CHECK: [[label_bRet_cmp_false:.*]] ; preds =
+// CHECK:%[[RET:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET:.*]] = icmp ne i1 %[[RET]], false
+// CHECK:br i1 %[[NRET]],
+
+// CHECK: [[label3:.*]]  ; preds =
+  return cos(a.x+a.y);
+// CHECK: [[exit:.*]]  ; preds =
+// CHECK-NOT:preds
+// CHECK:ret float
+}

+ 50 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/return/whole_scope_returned_loop.hlsl

@@ -0,0 +1,50 @@
+// RUN: %dxc -E main -fcgl -structurize-returns -T ps_6_0 %s | FileCheck %s
+
+int i;
+// CHECK:define float @main
+float main(float4 a:A) : SV_Target {
+// Init bReturned.
+// CHECK:%[[bReturned:.*]] = alloca i1
+// CHECK-NEXT:store i1 false, i1* %[[bReturned]]
+// Init retVal to 0.
+// CHECK:store float 0.000000e+00
+  float r = a.w;
+
+// CHECK: [[if_then:.*]] ; preds =
+  if (a.z > 0) {
+// CHECK: [[for_cond:.*]] ; preds =
+    for (int j=0;j<i;j++) {
+// CHECK: [[for_body:.*]] ; preds =
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+// dxBreak
+// CHECK:br i1 true
+       return log(i);
+
+// CHECK: [[for_inc:.*]] ; preds =
+    }
+// guard rest of scope with !bReturned
+// CHECK: [[bRet_cmp_false:.*]] ; preds =
+// CHECK:%[[RET:.*]] = load i1, i1* %[[bReturned]]
+// CHECK:%[[NRET:.*]] = icmp ne i1 %[[RET]], false
+// CHECK:br i1 %[[NRET]],
+
+// CHECK: [[for_end:.*]]  ; preds =
+    r += sin(a.y);
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+    return sin(a.x * a.z + r);
+  } else {
+// CHECK: [[else:.*]]  ; preds =
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+    return cos(r + a.z);
+  }
+
+// dead code which not has code generated.
+  return a.x + a.y;
+
+// CHECK: [[exit:.*]]  ; preds =
+// CHECK-NOT:preds
+// CHECK:ret float
+}

+ 26 - 0
tools/clang/test/HLSLFileCheck/hlsl/control_flow/return/whole_scope_returned_switch.hlsl

@@ -0,0 +1,26 @@
+// RUN: %dxc -E main -fcgl -structurize-returns -T ps_6_0 %s | FileCheck %s
+
+int i;
+
+float main(float4 a:A) : SV_Target {
+// Init bReturned.
+// CHECK:%[[bReturned:.*]] = alloca i1
+// CHECK:store i1 false, i1* %[[bReturned]]
+
+// CHECK:switch
+  switch (i) {
+  default:
+// CHECK: [[default:.*]] ; preds =
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+    return sin(a.y);
+  case 1:
+// CHECK: [[case1:.*]] ; preds =
+// set bReturned to true.
+// CHECK:store i1 true, i1* %[[bReturned]]
+    return cos(a.x);
+  }
+// CHECK: [[exit:.*]]  ; preds =
+// CHECK-NOT:preds
+// CHECK:ret float
+}