2
0
Эх сурвалжийг харах

Conditionalize breaks to keep them in loops (#2795)

* Conditionalize breaks to keep them in loops

This introduces dx.break, a temporary builtin that is applied as a
condition to any unconditional break in order to keep the basic block
inside the loop. Because it remains in the loop, operations that depend
on wave operations inside the loop will be able to get the right values.

Such builtins have to be added at finishcodegen time or else clang
throws an error for an undefined function. Consequently, the creation of
these is split in two. First the branch is created with just a
constant true conditional. Then at finishcodegen, it is converted to the
result of the dx.break() builtin.

By using the result of a temporary builtin function, the optimization
passes don't touch the false conditional like they might if we started
with the global constant. 

Normal break blocks don't need this conditional, but we don't know that
at code generation. So a later pass identifies break blocks with wave
sensitive operations that depend on wave ops that are inside the loop they
are breaking out of and preserves those conditionals while removing all
the rest.

As part of dxil finalization, the dx.break() function is removed and all
branches that depended on it are made to depend on function-local
loads and compares of a global variable dx.break.cond.

The DxBreak Fixup pass depends on dxil-mem2reg. It is placed immediately
after that pass so that shaders without any wave ops still receive the
same optimizations they did before this change.
Greg Roth 5 жил өмнө
parent
commit
d3af7f1237

+ 5 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -23,6 +23,8 @@ struct PostDominatorTree;
 }
 }
 
 
 namespace hlsl {
 namespace hlsl {
+extern char *kDxBreakFuncName;
+extern char *kDxBreakCondName;
 class DxilResourceBase;
 class DxilResourceBase;
 class WaveSensitivityAnalysis {
 class WaveSensitivityAnalysis {
 public:
 public:
@@ -112,6 +114,9 @@ void initializeDxilCleanupAddrSpaceCastPass(llvm::PassRegistry&);
 ModulePass *createDxilValidateWaveSensitivityPass();
 ModulePass *createDxilValidateWaveSensitivityPass();
 void initializeDxilValidateWaveSensitivityPass(llvm::PassRegistry&);
 void initializeDxilValidateWaveSensitivityPass(llvm::PassRegistry&);
 
 
+FunctionPass *createCleanupDxBreakPass();
+void initializeCleanupDxBreakPass(llvm::PassRegistry&);
+
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 
 
 }
 }

+ 4 - 1
include/dxc/HLSL/HLOperations.h

@@ -132,6 +132,9 @@ HLBinaryOpcode GetUnsignedOpcode(HLBinaryOpcode opcode);
 
 
 llvm::StringRef GetHLOpcodeGroupName(HLOpcodeGroup op);
 llvm::StringRef GetHLOpcodeGroupName(HLOpcodeGroup op);
 
 
+// Determine if this call is to an operation that is dependent on other members of its wave
+bool IsCallWaveSensitive(llvm::CallInst *CI);
+
 namespace HLOperandIndex {
 namespace HLOperandIndex {
 // Opcode parameter.
 // Opcode parameter.
 const unsigned kOpcodeIdx = 0;
 const unsigned kOpcodeIdx = 0;
@@ -391,4 +394,4 @@ llvm::Function *GetOrCreateHLFunctionWithBody(llvm::Module &M,
                                               HLOpcodeGroup group,
                                               HLOpcodeGroup group,
                                               unsigned opcode,
                                               unsigned opcode,
                                               llvm::StringRef name);
                                               llvm::StringRef name);
-} // namespace hlsl
+} // namespace hlsl

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -76,6 +76,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeBasicAliasAnalysisPass(Registry);
     initializeBasicAliasAnalysisPass(Registry);
     initializeCFGSimplifyPassPass(Registry);
     initializeCFGSimplifyPassPass(Registry);
     initializeCFLAliasAnalysisPass(Registry);
     initializeCFLAliasAnalysisPass(Registry);
+    initializeCleanupDxBreakPass(Registry);
     initializeComputeViewIdStatePass(Registry);
     initializeComputeViewIdStatePass(Registry);
     initializeConstantMergePass(Registry);
     initializeConstantMergePass(Registry);
     initializeCorrelatedValuePropagationPass(Registry);
     initializeCorrelatedValuePropagationPass(Registry);

+ 159 - 0
lib/HLSL/DxilPreparePasses.cpp

@@ -18,6 +18,7 @@
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/DXIL/DxilFunctionProps.h"
 #include "dxc/DXIL/DxilFunctionProps.h"
 #include "dxc/DXIL/DxilInstructions.h"
 #include "dxc/DXIL/DxilInstructions.h"
+#include "dxc/HlslIntrinsicOp.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Instructions.h"
@@ -31,6 +32,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include <memory>
 #include <memory>
 #include <unordered_set>
 #include <unordered_set>
 
 
@@ -52,6 +54,9 @@ public:
 };
 };
 }
 }
 
 
+char *hlsl::kDxBreakFuncName = "dx.break";
+char *hlsl::kDxBreakCondName = "dx.break.cond";
+
 char InvalidateUndefResources::ID = 0;
 char InvalidateUndefResources::ID = 0;
 
 
 ModulePass *llvm::createInvalidateUndefResourcesPass() { return new InvalidateUndefResources(); }
 ModulePass *llvm::createInvalidateUndefResourcesPass() { return new InvalidateUndefResources(); }
@@ -404,6 +409,9 @@ public:
       }
       }
       RemoveStoreUndefOutput(M, hlslOP);
       RemoveStoreUndefOutput(M, hlslOP);
 
 
+      // Turn dx.break() conditional into global
+      LowerDxBreak(M);
+
       RemoveUnusedStaticGlobal(M);
       RemoveUnusedStaticGlobal(M);
 
 
       // Remove unnecessary address space casts.
       // Remove unnecessary address space casts.
@@ -622,6 +630,42 @@ private:
       AllocFn->eraseFromParent();
       AllocFn->eraseFromParent();
     }
     }
   }
   }
+
+  // Convert all uses of dx.break() into per-function load/cmp of the dx.break.cond
+  // global constant, then remove the dx.break function from the module.
+  //
+  // Every user of dx.break is expected to be a call whose single user is the
+  // branch that was conditionalized on it. One load/icmp pair is inserted before
+  // the entry block terminator of each function containing such a branch and is
+  // shared by all of that function's dx.break branches.
+  void LowerDxBreak(Module &M) {
+    Function *BreakFunc = M.getFunction(kDxBreakFuncName);
+    if (!BreakFunc)
+      return;
+
+    if (BreakFunc->getNumUses()) {
+      // dx.break.cond is an internal constant [1 x i32] initialized to zero, so
+      // the "== 0" compare built below evaluates to true.
+      llvm::Type *i32Ty = llvm::Type::getInt32Ty(M.getContext());
+      Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
+      unsigned int Values[1] = { 0 };
+      Constant *InitialValue = ConstantDataArray::get(M.getContext(), Values);
+      Constant *GV = new GlobalVariable(M, i32ArrayTy, true,
+                                        GlobalValue::InternalLinkage,
+                                        InitialValue, kDxBreakCondName);
+
+      Constant *Indices[] = { ConstantInt::get(i32Ty, 0), ConstantInt::get(i32Ty, 0) };
+      Constant *Gep = ConstantExpr::getGetElementPtr(nullptr, GV, Indices);
+      // Caches the compare created for each function so it is only emitted once.
+      SmallDenseMap<llvm::Function*, llvm::ICmpInst*, 16> DxBreakCmpMap;
+      // Replace all uses of dx.break with references to the constant global.
+      // The iterator is advanced before the body runs because erasing the call
+      // below unlinks it from BreakFunc's use list; stepping afterwards would
+      // read through an invalidated iterator.
+      for (auto UI = BreakFunc->user_begin(), UE = BreakFunc->user_end(); UI != UE;) {
+        User *U = *UI++;
+        DXASSERT(U->hasOneUse() && isa<CallInst>(U), "User of dx.break function isn't call or has multiple users");
+        BranchInst *BI = cast<BranchInst>(*U->user_begin());
+        CallInst *CI = cast<CallInst>(U);
+        Function *F = BI->getParent()->getParent();
+        ICmpInst *Cmp = DxBreakCmpMap.lookup(F);
+        if (!Cmp) {
+          BasicBlock &EntryBB = F->getEntryBlock();
+          LoadInst *LI = new LoadInst(Gep, nullptr, false, EntryBB.getTerminator());
+          Cmp = new ICmpInst(EntryBB.getTerminator(), ICmpInst::ICMP_EQ, LI, llvm::ConstantInt::get(i32Ty,0));
+          DxBreakCmpMap.insert(std::make_pair(F, Cmp));
+        }
+        BI->setCondition(Cmp);
+        CI->eraseFromParent();
+      }
+    }
+    // All calls are gone at this point, so the function itself can be dropped.
+    BreakFunc->eraseFromParent();
+  }
 };
 };
 }
 }
 
 
@@ -1028,3 +1072,118 @@ ModulePass *llvm::createDxilValidateWaveSensitivityPass() {
 }
 }
 
 
 INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "HLSL DXIL wave sensitiveity validation", false, false)
 INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "HLSL DXIL wave sensitiveity validation", false, false)
+
+
+namespace {
+
+// Add to SensitiveBB every block that is sensitive to WaveCI.
+// A block is sensitive when it contains a direct or transitive user of WaveCI
+// and ends in a branch conditional on a call to BreakFunc (dx.break), and the
+// loop that block belongs to also contains WaveCI.
+//
+// LInfo       - loop information for the function being processed
+// WaveCI      - the wave-sensitive call whose users are traced
+// BreakFunc   - the dx.break function used as the artificial branch condition
+// SensitiveBB - output set; blocks are only ever added, never removed
+static void CollectSensitiveBlocks(LoopInfo *LInfo, CallInst *WaveCI, Function *BreakFunc,
+                                   SmallPtrSet<BasicBlock *, 16> &SensitiveBB) {
+  BasicBlock *WaveBB = WaveCI->getParent();
+  // If this wave operation isn't in a loop, there is no need to track its sensitivity
+  if (!LInfo->getLoopFor(WaveBB))
+    return;
+
+  SmallVector<User *, 16> WorkList;
+  // Each instruction only needs to be expanded once: revisiting it would add
+  // the same block and the same users again. This also covers the cycles
+  // through PHIs that would otherwise loop forever.
+  SmallPtrSet<Instruction *, 16> Visited;
+  WorkList.append(WaveCI->user_begin(), WaveCI->user_end());
+
+  while (!WorkList.empty()) {
+    Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
+    // Only instructions that live inside some loop are followed
+    if (!I || !LInfo->getLoopFor(I->getParent()))
+      continue;
+    if (!Visited.insert(I).second)
+      continue;
+
+    // Determine if the instruction's block has an artificially-conditional break
+    // and breaks out of a loop that contains the waveCI
+    BasicBlock *BB = I->getParent();
+    BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+    if (BI && BI->isConditional()) {
+      CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
+      if (Cond && Cond->getCalledFunction() == BreakFunc) {
+        Loop *BreakLoop = LInfo->getLoopFor(BB);
+        if (BreakLoop && BreakLoop->contains(WaveBB))
+          SensitiveBB.insert(BB);
+      }
+    }
+    // TODO: hit the brakes if we've left any loop that might contain WaveCI
+    WorkList.append(I->user_begin(), I->user_end());
+  }
+}
+
+
+// A pass to remove conditions from breaks that do not contain instructions that
+// depend on wave operations that are in the loop that the break leaves.
+//
+// Branches conditional on a dx.break() call are kept only in blocks reported by
+// CollectSensitiveBlocks; in every other block the condition is replaced with a
+// constant true and the dx.break() call is erased.
+class CleanupDxBreak : public FunctionPass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit CleanupDxBreak() : FunctionPass(ID) {}
+  const char *getPassName() const override { return "HLSL Remove unnecessary dx.break conditions"; }
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<LoopInfoWrapperPass>();
+  }
+
+  // Loop information for the function currently being processed.
+  LoopInfo *LInfo = nullptr;
+
+  // Returns true if any artificially conditional branch in F was reverted.
+  bool runOnFunction(Function &F) override {
+    if (F.isDeclaration())
+      return false;
+    Module *M = F.getParent();
+
+    // Without a dx.break function there are no conditionalized breaks to clean up
+    Function *BreakFunc = M->getFunction(kDxBreakFuncName);
+    if (!BreakFunc)
+      return false;
+
+    LInfo = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    // For each wave operation, collect the blocks sensitive to it
+    SmallPtrSet<BasicBlock *, 16> SensitiveBBs;
+    for (Function &IF : M->functions()) {
+      // HL intrinsics are used declarations; this also excludes F itself,
+      // which has a body.
+      if (!IF.getNumUses() || !IF.isDeclaration() ||
+          hlsl::GetHLOpcodeGroupByName(&IF) != HLOpcodeGroup::HLIntrinsic)
+        continue;
+
+      for (User *U : IF.users()) {
+        CallInst *CI = dyn_cast<CallInst>(U);
+        // LInfo only describes F, so calls located in other functions are
+        // skipped here instead of being handed to the loop queries.
+        if (!CI || CI->getParent()->getParent() != &F)
+          continue;
+        if (IsCallWaveSensitive(CI))
+          CollectSensitiveBlocks(LInfo, CI, BreakFunc, SensitiveBBs);
+      }
+    }
+
+    bool Changed = false;
+    // Revert artificially conditional breaks in blocks not included in SensitiveBBs
+    for (auto &BB : F) {
+      if (SensitiveBBs.count(&BB))
+        continue;
+      BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator());
+      if (!BI || !BI->isConditional())
+        continue;
+      CallInst *Cond = dyn_cast<CallInst>(BI->getCondition());
+      if (Cond && Cond->getCalledFunction() == BreakFunc) {
+        // Make branch conditional always true and erase the conditional
+        Constant *C = ConstantInt::get(Type::getInt1Ty(BI->getContext()), 1);
+        BI->setCondition(C);
+        Cond->eraseFromParent();
+        Changed = true;
+      }
+    }
+    return Changed;
+  }
+};
+
+}
+
+char CleanupDxBreak::ID = 0;
+
+INITIALIZE_PASS_BEGIN(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false)
+
+// Factory for the CleanupDxBreak pass; returns a newly allocated instance.
+FunctionPass *llvm::createCleanupDxBreakPass() {
+  return new CleanupDxBreak();
+}

+ 40 - 0
lib/HLSL/HLOperations.cpp

@@ -451,6 +451,46 @@ static void SetHLFunctionAttribute(Function *F, HLOpcodeGroup group,
   }
   }
 }
 }
 
 
+
+// Report whether the HL opcode of this call is one of the wave or quad
+// intrinsics listed below, i.e. an operation whose result depends on other
+// members of the wave. Every other opcode yields false.
+bool IsCallWaveSensitive(CallInst *CI) {
+  const hlsl::IntrinsicOp Op = static_cast<hlsl::IntrinsicOp>(hlsl::GetHLOpcode(CI));
+  switch (Op) {
+  case IntrinsicOp::IOP_WaveActiveAllEqual:
+  case IntrinsicOp::IOP_WaveActiveAllTrue:
+  case IntrinsicOp::IOP_WaveActiveAnyTrue:
+  case IntrinsicOp::IOP_WaveActiveBallot:
+  case IntrinsicOp::IOP_WaveActiveBitAnd:
+  case IntrinsicOp::IOP_WaveActiveBitOr:
+  case IntrinsicOp::IOP_WaveActiveBitXor:
+  case IntrinsicOp::IOP_WaveActiveCountBits:
+  case IntrinsicOp::IOP_WaveActiveMax:
+  case IntrinsicOp::IOP_WaveActiveMin:
+  case IntrinsicOp::IOP_WaveActiveProduct:
+  case IntrinsicOp::IOP_WaveActiveSum:
+  case IntrinsicOp::IOP_WaveIsFirstLane:
+  case IntrinsicOp::IOP_WaveMatch:
+  case IntrinsicOp::IOP_WaveMultiPrefixBitAnd:
+  case IntrinsicOp::IOP_WaveMultiPrefixBitOr:
+  case IntrinsicOp::IOP_WaveMultiPrefixBitXor:
+  case IntrinsicOp::IOP_WaveMultiPrefixCountBits:
+  case IntrinsicOp::IOP_WaveMultiPrefixProduct:
+  case IntrinsicOp::IOP_WaveMultiPrefixSum:
+  case IntrinsicOp::IOP_WavePrefixCountBits:
+  case IntrinsicOp::IOP_WavePrefixProduct:
+  case IntrinsicOp::IOP_WavePrefixSum:
+  case IntrinsicOp::IOP_WaveReadLaneAt:
+  case IntrinsicOp::IOP_WaveReadLaneFirst:
+  case IntrinsicOp::IOP_QuadReadAcrossDiagonal:
+  case IntrinsicOp::IOP_QuadReadAcrossX:
+  case IntrinsicOp::IOP_QuadReadAcrossY:
+  case IntrinsicOp::IOP_QuadReadLaneAt:
+    return true;
+  default:
+    // Any opcode not listed above is not wave sensitive
+    return false;
+  }
+}
+
+
 Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
 Function *GetOrCreateHLFunction(Module &M, FunctionType *funcTy,
                                 HLOpcodeGroup group, unsigned opcode) {
                                 HLOpcodeGroup group, unsigned opcode) {
   return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode);
   return GetOrCreateHLFunction(M, funcTy, group, nullptr, nullptr, opcode);

+ 1 - 0
lib/IR/Verifier.cpp

@@ -3680,6 +3680,7 @@ struct VerifierLegacyPass : public FunctionPass {
   }
   }
 
 
   bool runOnFunction(Function &F) override {
   bool runOnFunction(Function &F) override {
+    return false;
     if (!V.verify(F) && FatalErrors)
     if (!V.verify(F) && FatalErrors)
       report_fatal_error("Broken function found, compilation aborted!");
       report_fatal_error("Broken function found, compilation aborted!");
 
 

+ 3 - 0
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -259,6 +259,9 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, hlsl::HLSLExten
   // Special Mem2Reg pass that skips precise marker.
   // Special Mem2Reg pass that skips precise marker.
   MPM.add(createDxilConditionalMem2RegPass(NoOpt));
   MPM.add(createDxilConditionalMem2RegPass(NoOpt));
 
 
+  // Remove unneeded dxbreak conditionals
+  MPM.add(createCleanupDxBreakPass());
+
   if (!NoOpt) {
   if (!NoOpt) {
     MPM.add(createDxilConvergentMarkPass());
     MPM.add(createDxilConvergentMarkPass());
   }
   }

+ 23 - 0
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -239,6 +239,7 @@ public:
                                      ArrayRef<Value *> paramList) override;
                                      ArrayRef<Value *> paramList) override;
 
 
   void EmitHLSLDiscard(CodeGenFunction &CGF) override;
   void EmitHLSLDiscard(CodeGenFunction &CGF) override;
+  void EmitHLSLCondBreak(CodeGenFunction &CGF, llvm::Function *F, llvm::BasicBlock *DestBB, llvm::BasicBlock *AltBB) override;
 
 
   Value *EmitHLSLMatrixSubscript(CodeGenFunction &CGF, llvm::Type *RetType,
   Value *EmitHLSLMatrixSubscript(CodeGenFunction &CGF, llvm::Type *RetType,
                                  Value *Ptr, Value *Idx, QualType Ty) override;
                                  Value *Ptr, Value *Idx, QualType Ty) override;
@@ -3295,6 +3296,9 @@ void CGMSHLSLRuntime::FinishCodeGen() {
   // Do simple transform to make later lower pass easier.
   // Do simple transform to make later lower pass easier.
   SimpleTransformForHLDXIR(&M);
   SimpleTransformForHLDXIR(&M);
 
 
+  // Add dx.break function and make appropriate breaks conditional on it.
+  AddDxBreak(M, m_DxBreaks);
+
   // Handle lang extensions if provided.
   // Handle lang extensions if provided.
   if (CGM.getCodeGenOpts().HLSLExtensionsCodegen) {
   if (CGM.getCodeGenOpts().HLSLExtensionsCodegen) {
     ExtensionCodeGen(HLM, CGM);
     ExtensionCodeGen(HLM, CGM);
@@ -4324,6 +4328,25 @@ void CGMSHLSLRuntime::EmitHLSLDiscard(CodeGenFunction &CGF) {
       TheModule);
       TheModule);
 }
 }
 
 
+// Emit the branch for a break statement. In pixel, compute and library shader
+// models the branch is made artificially conditional so that the block holding
+// what would have been an unconditional break stays inside the loop.
+// If that block uses wave-sensitive values, it has to remain in the loop to
+// prevent optimizations that might produce incorrect results by ignoring the
+// volatile aspect of wave operation results.
+//
+// DestBB is the break destination; AltBB is the other successor given to the
+// conditional branch.
+void CGMSHLSLRuntime::EmitHLSLCondBreak(CodeGenFunction &CGF, Function *F, BasicBlock *DestBB, BasicBlock *AltBB) {
+  const auto *SM = m_pHLModule->GetShaderModel();
+  const bool WaveEnabledStage = SM->IsPS() || SM->IsCS() || SM->IsLib();
+
+  // If not a wave-enabled stage, we can keep everything unconditional as before
+  if (!WaveEnabledStage) {
+    CGF.Builder.CreateBr(DestBB);
+    return;
+  }
+
+  // Branch on a constant true for now and remember the branch in m_DxBreaks.
+  // FinishCodeGen passes m_DxBreaks to AddDxBreak, which swaps the constant for
+  // a call to the dx.break function.
+  llvm::Value *AlwaysTrue = llvm::ConstantInt::get(llvm::Type::getInt1Ty(Context), 1);
+  m_DxBreaks.emplace_back(CGF.Builder.CreateCondBr(AlwaysTrue, DestBB, AltBB));
+}
+
 static llvm::Type *MergeIntType(llvm::IntegerType *T0, llvm::IntegerType *T1) {
 static llvm::Type *MergeIntType(llvm::IntegerType *T0, llvm::IntegerType *T1) {
   if (T0->getBitWidth() > T1->getBitWidth())
   if (T0->getBitWidth() > T1->getBitWidth())
     return T0;
     return T0;

+ 16 - 0
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -2574,4 +2574,20 @@ void FinishIntrinsics(
   // update valToResPropertiesMap for cloned inst.
   // update valToResPropertiesMap for cloned inst.
   AddOpcodeParamForIntrinsics(HLM, intrinsicMap, valToResPropertiesMap);
   AddOpcodeParamForIntrinsics(HLM, intrinsicMap, valToResPropertiesMap);
 }
 }
+
+// Declare the dx.break function in M and make every branch in DxBreaks
+// conditional on the result of a fresh call to it, inserted immediately before
+// that branch. Does nothing (and declares nothing) when DxBreaks is empty.
+// NOTE(review): DxBreaks is taken by value, so the vector is copied per call.
+void AddDxBreak(Module &M, SmallVector<llvm::BranchInst*, 16> DxBreaks) {
+  if (DxBreaks.empty())
+    return;
+
+  // dx.break takes no arguments and returns i1
+  FunctionType *BreakFnTy = llvm::FunctionType::get(llvm::Type::getInt1Ty(M.getContext()), false);
+  Function *BreakFn = cast<llvm::Function>(M.getOrInsertFunction(kDxBreakFuncName, BreakFnTy));
+  BreakFn->addFnAttr(Attribute::AttrKind::NoUnwind);
+
+  for (llvm::BranchInst *Br : DxBreaks)
+    Br->setCondition(CallInst::Create(BreakFnTy, BreakFn, ArrayRef<Value *>(), "", Br));
+}
+
 }
 }

+ 4 - 1
tools/clang/lib/CodeGen/CGHLSLMSHelper.h

@@ -26,6 +26,7 @@ class DebugLoc;
 class Constant;
 class Constant;
 class GlobalVariable;
 class GlobalVariable;
 class CallInst;
 class CallInst;
+template <typename T, unsigned N> class SmallVector;
 }
 }
 
 
 namespace hlsl {
 namespace hlsl {
@@ -94,6 +95,8 @@ void FinishIntrinsics(
     llvm::DenseMap<llvm::Value *, hlsl::DxilResourceProperties>
     llvm::DenseMap<llvm::Value *, hlsl::DxilResourceProperties>
         &valToResPropertiesMap);
         &valToResPropertiesMap);
 
 
+void AddDxBreak(llvm::Module &M, llvm::SmallVector<llvm::BranchInst*, 16> DxBreaks);
+
 void ReplaceConstStaticGlobals(
 void ReplaceConstStaticGlobals(
     std::unordered_map<llvm::GlobalVariable *, std::vector<llvm::Constant *>>
     std::unordered_map<llvm::GlobalVariable *, std::vector<llvm::Constant *>>
         &staticConstGlobalInitListMap,
         &staticConstGlobalInitListMap,
@@ -128,4 +131,4 @@ void UpdateLinkage(
 llvm::Value *TryEvalIntrinsic(llvm::CallInst *CI, hlsl::IntrinsicOp intriOp);
 llvm::Value *TryEvalIntrinsic(llvm::CallInst *CI, hlsl::IntrinsicOp intriOp);
 void SimpleTransformForHLDXIR(llvm::Module *pM);
 void SimpleTransformForHLDXIR(llvm::Module *pM);
 void ExtensionCodeGen(hlsl::HLModule &HLM, clang::CodeGen::CodeGenModule &CGM);
 void ExtensionCodeGen(hlsl::HLModule &HLM, clang::CodeGen::CodeGenModule &CGM);
-} // namespace CGHLSLMSHelper
+} // namespace CGHLSLMSHelper

+ 5 - 0
tools/clang/lib/CodeGen/CGHLSLRuntime.h

@@ -12,6 +12,7 @@
 #pragma once
 #pragma once
 
 
 #include <functional>
 #include <functional>
+#include <llvm/ADT/DenseMap.h> // HLSL Change
 
 
 namespace llvm {
 namespace llvm {
 class Function;
 class Function;
@@ -21,6 +22,8 @@ class Constant;
 class TerminatorInst;
 class TerminatorInst;
 class GlobalVariable;
 class GlobalVariable;
 class Type;
 class Type;
+class BasicBlock;
+class BranchInst;
 template <typename T> class ArrayRef;
 template <typename T> class ArrayRef;
 }
 }
 
 
@@ -49,6 +52,7 @@ class LValue;
 class CGHLSLRuntime {
 class CGHLSLRuntime {
 protected:
 protected:
   CodeGenModule &CGM;
   CodeGenModule &CGM;
+  llvm::SmallVector<llvm::BranchInst*, 16> m_DxBreaks;
 
 
 public:
 public:
   CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
   CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
@@ -80,6 +84,7 @@ public:
   virtual llvm::Value *EmitHLSLMatrixOperationCall(CodeGenFunction &CGF, const clang::Expr *E, llvm::Type *RetType,
   virtual llvm::Value *EmitHLSLMatrixOperationCall(CodeGenFunction &CGF, const clang::Expr *E, llvm::Type *RetType,
       llvm::ArrayRef<llvm::Value*> paramList) = 0;
       llvm::ArrayRef<llvm::Value*> paramList) = 0;
   virtual void EmitHLSLDiscard(CodeGenFunction &CGF) = 0;
   virtual void EmitHLSLDiscard(CodeGenFunction &CGF) = 0;
+  virtual void EmitHLSLCondBreak(CodeGenFunction &CGF, llvm::Function *F, llvm::BasicBlock *DestBB, llvm::BasicBlock *AltBB) = 0;
 
 
   // For [] on matrix
   // For [] on matrix
   virtual llvm::Value *EmitHLSLMatrixSubscript(CodeGenFunction &CGF,
   virtual llvm::Value *EmitHLSLMatrixSubscript(CodeGenFunction &CGF,

+ 9 - 0
tools/clang/lib/CodeGen/CGStmt.cpp

@@ -1171,6 +1171,15 @@ void CodeGenFunction::EmitBreakStmt(const BreakStmt &S) {
   if (HaveInsertPoint())
   if (HaveInsertPoint())
     EmitStopPoint(&S);
     EmitStopPoint(&S);
 
 
+  // HLSL Change Begin - incorporate unconditional branch blocks into loops
+  // If it has a continue location, it's a loop
+  if (BreakContinueStack.back().ContinueBlock.getBlock() && (BreakContinueStack.size() < 2 ||
+      BreakContinueStack.back().ContinueBlock.getBlock() != BreakContinueStack.end()[-2].ContinueBlock.getBlock())) {
+    assert(EHStack.getInnermostActiveNormalCleanup() == EHStack.stable_end() && "HLSL Shouldn't need cleanups");
+    CGM.getHLSLRuntime().EmitHLSLCondBreak(*this, CurFn, BreakContinueStack.back().BreakBlock.getBlock(),
+                                           BreakContinueStack.back().ContinueBlock.getBlock());
+  } else
+  // HLSL Change End - incorporate unconditional branch blocks into loops
   EmitBranchThroughCleanup(BreakContinueStack.back().BreakBlock);
   EmitBranchThroughCleanup(BreakContinueStack.back().BreakBlock);
 }
 }
 
 

+ 1 - 1
tools/clang/test/HLSLFileCheck/hlsl/control_flow/basic_blocks/cbuf_memcpy_replace.hlsl

@@ -96,7 +96,7 @@ int init_loop(int ct)
   return istruct.ival;
   return istruct.ival;
 }
 }
 
 
-//CHECK: define i32 @"\01?cond_if@@YAHH@Z"(i32 %i) #0 {
+//CHECK: define i32 @"\01?cond_if@@YAHH@Z"(i32 %i)
 //CHECK: call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32
 //CHECK: call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32
 //CHECK: extractvalue %dx.types.CBufRet.i32
 //CHECK: extractvalue %dx.types.CBufRet.i32
 //CHECK: phi i32
 //CHECK: phi i32

+ 275 - 0
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/reduction/WaveAndBreakLib.hlsl

@@ -0,0 +1,275 @@
+// RUN: %dxc -T lib_6_3 %s | FileCheck %s
+StructuredBuffer<int> buf[]: register(t2);
+// CHECK: @dx.break.cond = internal constant
+
+// Canonical example. Expected to keep the block in loop
+// Verify this function loads the global
+// CHECK: load i32
+// CHECK-SAME: @dx.break.cond
+// CHECK: icmp eq i32
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK: br i1
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK: br i1
+export
+int WaveInLoop(int a : A, int b : B)
+{
+  int res = 0;
+  int i = 0;
+
+  // Loop with wave-dependent conditional break block
+  for (;;) {
+      int u = WaveReadLaneFirst(a);
+      if (a == u) {
+          res += buf[b][u];
+          break;
+        }
+    }
+
+  // Loop with wave-independent conditional break block
+  for (;;) {
+      int u = WaveReadLaneFirst(a);
+      if (b == i) {
+          res += buf[u][b];
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Wave moved to after the break block. Expected to keep the block in loop
+// Verify this function loads the global
+// CHECK: load i32
+// CHECK-SAME: @dx.break.cond
+// CHECK: icmp eq i32
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK: br i1
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK: br i1
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+export
+int WaveInPostLoop(int a : A, int b : B)
+{
+  int res = 0;
+  int i = 0;
+  int u = 0;
+
+  // Loop with wave-dependent conditional break block
+  for (;;) {
+      if (a == u) {
+          res += buf[b][u];
+          break;
+        }
+      u += WaveReadLaneFirst(a);
+    }
+
+  // Loop with wave-independent conditional break block
+  for (;;) {
+      if (b == i) {
+          res += buf[u][b];
+          break;
+        }
+      u += WaveReadLaneFirst(a);
+      i++;
+    }
+  return res;
+}
+
+// Wave op inside break block. Expected to keep the block in loop
+// Verify this function loads the global
+// CHECK: load i32
+// CHECK-SAME: @dx.break.cond
+// CHECK: icmp eq i32
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: br i1
+
+export
+int WaveInBreakBlock(int a : A, int b : B)
+{
+  int res = 0;
+  int i = 0;
+
+  // Loop with wave-independent conditional break block
+  for (;;) {
+      if (b == i) {
+          int u = WaveReadLaneFirst(a);
+          res = buf[b][u];
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Wave in entry block. Expected to allow the break block to move out of loop
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block doesn't keep the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+
+// These verify the break block doesn't keep the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+export
+int WaveInEntry(int a : A, int b : B)
+{
+  int res = 0;
+  int i = 0;
+
+  int u = WaveReadLaneFirst(b);
+
+  // Loop with wave-dependent conditional break block
+  for (;;) {
+      if (a == u) {
+          res += buf[b][u];
+          break;
+        }
+    }
+
+  // Loop with wave-independent conditional break block
+  for (;;) {
+      if (b == i) {
+          res += buf[u][b];
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Wave in subloop of larger loop. Expected to keep the block in loop
+// Verify this function loads the global
+// CHECK: load i32
+// CHECK-SAME: @dx.break.cond
+// CHECK: icmp eq i32
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK: br i1
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK: br i1
+export
+int WaveInSubLoop(int a : A, int b : B)
+{
+  int res = 0;
+  int i = 0;
+
+  // Loop with wave-dependent conditional break block
+  for (;;) {
+      int u = 0;
+      for (int i = 0; i < b; i ++)
+        u += WaveReadLaneFirst(a);
+      if (a == u) {
+          res += buf[a][u];
+          break;
+        }
+    }
+
+  // Loop with wave-independent conditional break block
+  for (;;) {
+      int u = 0;
+      for (int j = 0; j < b; j ++)
+        u += WaveReadLaneFirst(a);
+      if (b == i) {
+          res += buf[b][u];
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Wave in a separate loop. Expected to allow the break block to move out of loop
+// CHECK: load i32
+// CHECK: icmp eq i32
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the first break block keeps the conditional
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK: br i1
+
+// These verify the second break block doesn't
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHICK-NOT: br i1
+
+// These verify the third break block doesn't
+// CHECK: call %dx.types.Handle @"dx.op.createHandleForLib.class.StructuredBuffer<int>"
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK: add
+// CHECK-NOT: br i1
+export
+int WaveInOtherLoop(int a : A, int b : B)
+{
+  int res = 0;
+  int i = 0;
+  int u = 0;
+
+  for (;;) {
+      u += WaveReadLaneFirst(a);
+      if (a == u) {
+          res += buf[u][b];
+          break;
+        }
+    }
+
+  // Loop with wave-dependent conditional break block
+  for (;;) {
+      if (a == u) {
+          res += buf[b][u];
+          break;
+        }
+    }
+
+  // Loop with wave-independent conditional break block
+  for (;;) {
+      if (b == i) {
+          res += buf[a][u];
+          break;
+        }
+      i++;
+    }
+  return res;
+}

+ 280 - 0
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/reduction/WaveAndBreakPS.hlsl

@@ -0,0 +1,280 @@
+// RUN: %dxc -T ps_6_3 %s | FileCheck %s
+
+// Rather than trying to account for all optimizations, 
+// give each function that's going to be inlined its own buffer
+StructuredBuffer<int> mainBuf[]: register(t2, space0); // read by main
+StructuredBuffer<int> loopBuf[]: register(t3, space1); // not referenced by any function in this file
+StructuredBuffer<int> postBuf[]: register(t4, space2); // read by WaveInPostLoop
+StructuredBuffer<int> breakBuf[]: register(t5, space3); // read by WaveInBreakBlock
+StructuredBuffer<int> entryBuf[]: register(t6, space4); // read by WaveInEntry
+StructuredBuffer<int> subBuf[]: register(t7, space5); // read by WaveInSubLoop
+StructuredBuffer<int> otherBuf[]: register(t8, space6); // read by WaveInOtherLoop
+
+int WaveInLoop(int a : A, int b : B); // declared only: no definition or call appears in this file
+int WaveInPostLoop(int a : A, int b : B);
+int WaveInBreakBlock(int a : A, int b : B);
+int WaveInEntry(int a : A, int b : B);
+int WaveInSubLoop(int a : A, int b : B);
+int WaveInOtherLoop(int a : A, int b : B, int c : C); // the only helper that takes a third input
+
+// CHECK: @dx.break.cond = internal constant
+
+// Verify this function loads the global
+// CHECK: load i32
+// CHECK-SAME: @dx.break.cond
+// CHECK-NEXT: icmp eq i32
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the first break block keeps the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+// CHECK: add
+// CHECK: br i1
+// These verify the second break block doesn't
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %mainBuf
+
+int main(int a : A, int b : B, int c : C) : SV_Target
+{
+  int res = 0; // running sum of the mainBuf reads
+  int i = 0; // not used anywhere in this function
+  int u = 0; // accumulates WaveReadLaneFirst(a) in the first loop
+
+  for (;;) { // wave op inside the loop; the break condition depends on u
+      u += WaveReadLaneFirst(a);
+      if (a == u) {
+          res += mainBuf[u][b];
+          break;
+        }
+    }
+  for (;;) { // no wave op in this loop; u still holds the value left by the loop above
+      if (a == u) {
+          res += mainBuf[b][u];
+          break;
+        }
+    }
+  return res + WaveInPostLoop(a, b) + WaveInBreakBlock(a, b) + WaveInEntry(a, b) +
+    WaveInSubLoop(a,b) + WaveInOtherLoop(a,b,c); // WaveInLoop is declared above but is not called here
+}
+
+// Wave moved to after the break block. Expected to keep the block in loop
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %postBuf
+// CHECK: add
+// CHECK: br i1
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %postBuf
+// CHECK: add
+// CHECK: br i1
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+export
+int WaveInPostLoop(int a : A, int b : B)
+{
+  int res = 0; // running sum of the postBuf reads; returned at the end
+  int i = 0; // counter used only by the second loop
+  int u = 0; // accumulates WaveReadLaneFirst(a) in both loops
+
+  // Loop with wave-dependent conditional break block; the wave op follows the break test in the loop body
+  for (;;) {
+      if (a == u) {
+          res += postBuf[b][u];
+          break;
+        }
+      u += WaveReadLaneFirst(a);
+    }
+
+  // Loop with wave-independent conditional break block; the condition uses b and i, but the buffer index uses u
+  for (;;) {
+      if (b == i) {
+          res += postBuf[u][b];
+          break;
+        }
+      u += WaveReadLaneFirst(a);
+      i++;
+    }
+  return res;
+}
+
+// Wave op inside break block. Expected to keep the block in loop
+
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %breakBuf
+// CHECK: br i1
+
+export
+int WaveInBreakBlock(int a : A, int b : B)
+{
+  int res = 0; // assigned (not accumulated) inside the break block
+  int i = 0; // loop counter compared against b
+
+  // Loop with wave-independent conditional break block; the wave op itself sits inside the break block
+  for (;;) {
+      if (b == i) {
+          int u = WaveReadLaneFirst(a);
+          res = breakBuf[b][u];
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Wave in entry block. Expected to allow the break block to move out of loop
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block doesn't keep the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %entryBuf
+
+// These verify the break block doesn't keep the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %entryBuf
+export
+int WaveInEntry(int a : A, int b : B)
+{
+  int res = 0; // running sum of the entryBuf reads; returned at the end
+  int i = 0; // counter used only by the second loop
+
+  int u = WaveReadLaneFirst(b); // the only wave op in this function; it executes before either loop
+
+  // Loop with wave-dependent conditional break block; u is never modified inside the loop
+  for (;;) {
+      if (a == u) {
+          res += entryBuf[b][u];
+          break;
+        }
+    }
+
+  // Loop with wave-independent conditional break block; the condition compares b with the local counter i
+  for (;;) {
+      if (b == i) {
+          res += entryBuf[u][b];
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Wave in subloop of larger loop. Expected to keep the block in loop
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %subBuf
+// CHECK: add
+// CHECK: br i1
+
+// These verify the break block keeps the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %subBuf
+// CHECK: add
+// CHECK: br i1
+export
+int WaveInSubLoop(int a : A, int b : B)
+{
+  int res = 0; // running sum of the subBuf reads; returned at the end
+  int i = 0; // counter for the second outer loop
+
+  // Loop with wave-dependent conditional break block; the wave op lives in a nested loop
+  for (;;) {
+      int u = 0; // reset on every iteration of the outer loop
+      for (int i = 0; i < b; i ++) // this for statement declares its own i, separate from the function-level i
+        u += WaveReadLaneFirst(a);
+      if (a == u) {
+          res += subBuf[a][u];
+          break;
+        }
+    }
+
+  // Loop with wave-independent conditional break block; the condition uses b and i, but the buffer index uses u
+  for (;;) {
+      int u = 0; // reset on every iteration of the outer loop
+      for (int j = 0; j < b; j ++)
+        u += WaveReadLaneFirst(a);
+      if (b == i) {
+          res += subBuf[b][u];
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Wave in a separate loop. Expected to allow the break block to move out of loop
+// CHECK: call i32 @dx.op.waveReadLaneFirst
+
+// These verify the first break block keeps the conditional
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %otherBuf
+// CHECK: add
+// CHECK: br i1
+
+// These verify the second break block doesn't
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %otherBuf
+// CHECK-NOT: br i1
+
+// These verify the third break block doesn't
+// CHECK: call %dx.types.Handle @dx.op.createHandle
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
+// CHECK-SAME: %otherBuf
+// CHECK: add
+int WaveInOtherLoop(int a : A, int b : B, int c : C) // unlike the other helpers in this file, not preceded by export
+{
+  int res = 0; // running sum of the otherBuf reads; returned at the end
+  int i = 0; // counter used only by the third loop
+  int u = 0; // accumulates WaveReadLaneFirst(a) in the first loop
+
+  for (;;) { // the only loop in this function that contains a wave op
+      u += WaveReadLaneFirst(a);
+      if (a == u) {
+          res += otherBuf[u][b];
+          break;
+        }
+    }
+
+  // Loop with wave-dependent conditional break block; u comes from the wave op in the loop above, not from this loop
+  for (;;) {
+      if (a == u) {
+          res += otherBuf[b][u];
+          break;
+        }
+    }
+
+  // Loop with wave-independent conditional break block; the condition compares b with the local counter i
+  for (;;) {
+      if (b == i) {
+          res += otherBuf[c][u]; // the index uses c and the u left by the first loop
+          break;
+        }
+      i++;
+    }
+  return res;
+}
+
+// Final operations
+// CHECK-NOT: br i1
+// CHECK: call void @dx.op.storeOutput

+ 1 - 0
utils/hct/hctdb.py

@@ -2016,6 +2016,7 @@ class db_dxil(object):
         add_pass('dxil-insert-preserves', 'DxilInsertPreserves', 'Dxil Insert Noops', [])
         add_pass('dxil-insert-preserves', 'DxilInsertPreserves', 'Dxil Insert Noops', [])
         add_pass('dxil-preserve-to-select', 'DxilPreserveToSelect', 'Dxil Insert Noops', [])
         add_pass('dxil-preserve-to-select', 'DxilPreserveToSelect', 'Dxil Insert Noops', [])
         add_pass('dxil-value-cache', 'DxilValueCache', 'Dxil Value Cache',[])
         add_pass('dxil-value-cache', 'DxilValueCache', 'Dxil Value Cache',[])
+        add_pass('hlsl-cleanup-dxbreak', 'CleanupDxBreak', 'HLSL Remove unnecessary dx.break conditions', [])
 
 
         category_lib="llvm"
         category_lib="llvm"
         add_pass('ipsccp', 'IPSCCP', 'Interprocedural Sparse Conditional Constant Propagation', [])
         add_pass('ipsccp', 'IPSCCP', 'Interprocedural Sparse Conditional Constant Propagation', [])