Browse Source

Add simplify to DxilValueCache for dx.break() (#2803)

* Add simplify to DxilValueCache for dx.break()

To allow loop unrolling, DxilValueCache now simplifies call instructions
that reference dx.break() to the constant boolean value true.

This required making the dx.break name constants available to Analysis
by moving them into DxilConstants.h.

As an incidental change, I corrected the spelling of "Simplify"
(previously "Simpilfy") in a couple of function names.
Greg Roth 5 years ago
parent
commit
7d665725a8

+ 3 - 0
include/dxc/DXIL/DxilConstants.h

@@ -1438,6 +1438,9 @@ namespace DXIL {
   extern const char* kFP32DenormValuePreserveString;
   extern const char* kFP32DenormValuePreserveString;
   extern const char* kFP32DenormValueFtzString;
   extern const char* kFP32DenormValueFtzString;
 
 
+  static const char *kDxBreakFuncName = "dx.break";
+  static const char *kDxBreakCondName = "dx.break.cond";
+
 } // namespace DXIL
 } // namespace DXIL
 
 
 } // namespace hlsl
 } // namespace hlsl

+ 0 - 2
include/dxc/HLSL/DxilGenerationPass.h

@@ -23,8 +23,6 @@ struct PostDominatorTree;
 }
 }
 
 
 namespace hlsl {
 namespace hlsl {
-extern char *kDxBreakFuncName;
-extern char *kDxBreakCondName;
 class DxilResourceBase;
 class DxilResourceBase;
 class WaveSensitivityAnalysis {
 class WaveSensitivityAnalysis {
 public:
 public:

+ 2 - 2
include/llvm/Analysis/DxilValueCache.h

@@ -60,8 +60,8 @@ private:
   Value *ProcessValue(Value *V, DominatorTree *DT);
   Value *ProcessValue(Value *V, DominatorTree *DT);
 
 
   Value *ProcessAndSimplify_PHI(Instruction *I, DominatorTree *DT);
   Value *ProcessAndSimplify_PHI(Instruction *I, DominatorTree *DT);
-  Value *ProcessAndSimpilfy_Br(Instruction *I, DominatorTree *DT);
-  Value *ProcessAndSimpilfy_Load(Instruction *LI, DominatorTree *DT);
+  Value *ProcessAndSimplify_Br(Instruction *I, DominatorTree *DT);
+  Value *ProcessAndSimplify_Load(Instruction *LI, DominatorTree *DT);
   Value *SimplifyAndCacheResult(Instruction *I, DominatorTree *DT);
   Value *SimplifyAndCacheResult(Instruction *I, DominatorTree *DT);
 
 
 public:
 public:

+ 13 - 4
lib/Analysis/DxilValueCache.cpp

@@ -12,6 +12,7 @@
 
 
 
 
 #include "llvm/Pass.h"
 #include "llvm/Pass.h"
+#include "dxc/DXIL/DxilConstants.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Constants.h"
@@ -146,7 +147,7 @@ Value *DxilValueCache::ProcessAndSimplify_PHI(Instruction *I, DominatorTree *DT)
   return Simplified;
   return Simplified;
 }
 }
 
 
-Value *DxilValueCache::ProcessAndSimpilfy_Br(Instruction *I, DominatorTree *DT) {
+Value *DxilValueCache::ProcessAndSimplify_Br(Instruction *I, DominatorTree *DT) {
 
 
   // The *only* reason we're paying special attention to the
   // The *only* reason we're paying special attention to the
   // branch inst, is to mark certain Basic Blocks as always
   // branch inst, is to mark certain Basic Blocks as always
@@ -192,7 +193,7 @@ Value *DxilValueCache::ProcessAndSimpilfy_Br(Instruction *I, DominatorTree *DT)
   return nullptr;
   return nullptr;
 }
 }
 
 
-Value *DxilValueCache::ProcessAndSimpilfy_Load(Instruction *I, DominatorTree *DT) {
+Value *DxilValueCache::ProcessAndSimplify_Load(Instruction *I, DominatorTree *DT) {
   LoadInst *LI = cast<LoadInst>(I);
   LoadInst *LI = cast<LoadInst>(I);
   Value *V = TryGetCachedValue(LI->getPointerOperand());
   Value *V = TryGetCachedValue(LI->getPointerOperand());
   if (Constant *ConstPtr = dyn_cast<Constant>(V)) {
   if (Constant *ConstPtr = dyn_cast<Constant>(V)) {
@@ -208,13 +209,13 @@ Value *DxilValueCache::SimplifyAndCacheResult(Instruction *I, DominatorTree *DT)
 
 
   Value *Simplified = nullptr;
   Value *Simplified = nullptr;
   if (Instruction::Br == I->getOpcode()) {
   if (Instruction::Br == I->getOpcode()) {
-    Simplified = ProcessAndSimpilfy_Br(I, DT);
+    Simplified = ProcessAndSimplify_Br(I, DT);
   }
   }
   else if (Instruction::PHI == I->getOpcode()) {
   else if (Instruction::PHI == I->getOpcode()) {
     Simplified = ProcessAndSimplify_PHI(I, DT);
     Simplified = ProcessAndSimplify_PHI(I, DT);
   }
   }
   else if (Instruction::Load == I->getOpcode()) {
   else if (Instruction::Load == I->getOpcode()) {
-    Simplified = ProcessAndSimpilfy_Load(I, DT);
+    Simplified = ProcessAndSimplify_Load(I, DT);
   }
   }
   else if (Instruction::GetElementPtr == I->getOpcode()) {
   else if (Instruction::GetElementPtr == I->getOpcode()) {
     SmallVector<Value *, 4> Ops;
     SmallVector<Value *, 4> Ops;
@@ -222,6 +223,14 @@ Value *DxilValueCache::SimplifyAndCacheResult(Instruction *I, DominatorTree *DT)
       Ops.push_back(TryGetCachedValue(I->getOperand(i)));
       Ops.push_back(TryGetCachedValue(I->getOperand(i)));
     Simplified = llvm::SimplifyGEPInst(Ops, DL, nullptr, DT);
     Simplified = llvm::SimplifyGEPInst(Ops, DL, nullptr, DT);
   }
   }
+  else if (Instruction::Call == I->getOpcode()) {
+    Module *M = I->getModule();
+    CallInst *CI = cast<CallInst>(I);
+    if (CI->getCalledFunction()->getName() == hlsl::DXIL::kDxBreakFuncName) {
+      llvm::Type *i1Ty = llvm::Type::getInt1Ty(M->getContext());
+      Simplified = llvm::ConstantInt::get(i1Ty, 1);
+    }
+  }
   // The rest of the checks use LLVM stock simplifications
   // The rest of the checks use LLVM stock simplifications
   else if (I->isBinaryOp()) {
   else if (I->isBinaryOp()) {
     Simplified =
     Simplified =

+ 5 - 6
lib/HLSL/DxilPreparePasses.cpp

@@ -18,6 +18,7 @@
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/DXIL/DxilFunctionProps.h"
 #include "dxc/DXIL/DxilFunctionProps.h"
 #include "dxc/DXIL/DxilInstructions.h"
 #include "dxc/DXIL/DxilInstructions.h"
+#include "dxc/DXIL/DxilConstants.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IRBuilder.h"
@@ -32,6 +33,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/DxilValueCache.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include <memory>
 #include <memory>
 #include <unordered_set>
 #include <unordered_set>
@@ -54,9 +56,6 @@ public:
 };
 };
 }
 }
 
 
-char *hlsl::kDxBreakFuncName = "dx.break";
-char *hlsl::kDxBreakCondName = "dx.break.cond";
-
 char InvalidateUndefResources::ID = 0;
 char InvalidateUndefResources::ID = 0;
 
 
 ModulePass *llvm::createInvalidateUndefResourcesPass() { return new InvalidateUndefResources(); }
 ModulePass *llvm::createInvalidateUndefResourcesPass() { return new InvalidateUndefResources(); }
@@ -633,7 +632,7 @@ private:
 
 
   // Convert all uses of dx.break() into per-function load/cmp of dx.break.cond global constant
   // Convert all uses of dx.break() into per-function load/cmp of dx.break.cond global constant
   void LowerDxBreak(Module &M) {
   void LowerDxBreak(Module &M) {
-    if (Function *BreakFunc = M.getFunction(kDxBreakFuncName)) {
+    if (Function *BreakFunc = M.getFunction(DXIL::kDxBreakFuncName)) {
       if (BreakFunc->getNumUses()) {
       if (BreakFunc->getNumUses()) {
         llvm::Type *i32Ty = llvm::Type::getInt32Ty(M.getContext());
         llvm::Type *i32Ty = llvm::Type::getInt32Ty(M.getContext());
         Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
         Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
@@ -641,7 +640,7 @@ private:
         Constant *InitialValue = ConstantDataArray::get(M.getContext(), Values);
         Constant *InitialValue = ConstantDataArray::get(M.getContext(), Values);
         Constant *GV = new GlobalVariable(M, i32ArrayTy, true,
         Constant *GV = new GlobalVariable(M, i32ArrayTy, true,
                                           GlobalValue::InternalLinkage,
                                           GlobalValue::InternalLinkage,
-                                          InitialValue, kDxBreakCondName);
+                                          InitialValue, DXIL::kDxBreakCondName);
 
 
         Constant *Indices[] = { ConstantInt::get(i32Ty, 0), ConstantInt::get(i32Ty, 0) };
         Constant *Indices[] = { ConstantInt::get(i32Ty, 0), ConstantInt::get(i32Ty, 0) };
         Constant *Gep = ConstantExpr::getGetElementPtr(nullptr, GV, Indices);
         Constant *Gep = ConstantExpr::getGetElementPtr(nullptr, GV, Indices);
@@ -1135,7 +1134,7 @@ public:
     // Only check ps and lib profile.
     // Only check ps and lib profile.
     Module *M = F.getEntryBlock().getModule();
     Module *M = F.getEntryBlock().getModule();
 
 
-    Function *BreakFunc = M->getFunction(kDxBreakFuncName);
+    Function *BreakFunc = M->getFunction(DXIL::kDxBreakFuncName);
     if (!BreakFunc)
     if (!BreakFunc)
       return false;
       return false;
 
 

+ 2 - 1
tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

@@ -18,6 +18,7 @@
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DxilValueCache.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 
 
@@ -2581,7 +2582,7 @@ void AddDxBreak(Module &M, SmallVector<llvm::BranchInst*, 16> DxBreaks) {
 
 
   // Create the dx.break function
   // Create the dx.break function
   FunctionType *FT = llvm::FunctionType::get(llvm::Type::getInt1Ty(M.getContext()), false);
   FunctionType *FT = llvm::FunctionType::get(llvm::Type::getInt1Ty(M.getContext()), false);
-  Function *func = cast<llvm::Function>(M.getOrInsertFunction(kDxBreakFuncName, FT));
+  Function *func = cast<llvm::Function>(M.getOrInsertFunction(DXIL::kDxBreakFuncName, FT));
   func->addFnAttr(Attribute::AttrKind::NoUnwind);
   func->addFnAttr(Attribute::AttrKind::NoUnwind);
 
 
   for(llvm::BranchInst *BI : DxBreaks) {
   for(llvm::BranchInst *BI : DxBreaks) {

+ 48 - 0
tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/reduction/WaveBreakAndUnrollCS.hlsl

@@ -0,0 +1,48 @@
+// RUN: %dxc -T cs_6_0 %s | FileCheck %s
+// A test of explicit loop unrolling on a loop that uses a wave op in a break block
+
+// CHECK: void @main
+// CHECK: @dx.op.cbufferLoadLegacy.f32
+// CHECK: @dx.op.waveActiveOp.f32
+// CHECK: @dx.op.cbufferLoadLegacy.i32
+// CHECK: br i1
+// CHECK: @dx.op.cbufferLoadLegacy.f32
+// CHECK: @dx.op.waveActiveOp.f32
+// CHECK: @dx.op.cbufferLoadLegacy.i32
+// CHECK: br i1
+// CHECK: @dx.op.cbufferLoadLegacy.f32
+// CHECK: @dx.op.waveActiveOp.f32
+// CHECK: @dx.op.cbufferLoadLegacy.i32
+// CHECK: br i1
+// CHECK: @dx.op.cbufferLoadLegacy.f32
+// CHECK: @dx.op.waveActiveOp.f32
+// CHECK: @dx.op.cbufferLoadLegacy.i32
+// CHECK: br i1
+// CHECK: @dx.op.cbufferLoadLegacy.f32
+// CHECK: @dx.op.waveActiveOp.f32
+// CHECK: @dx.op.cbufferLoadLegacy.i32
+// CHECK: br i1
+// CHECK: @dx.op.cbufferLoadLegacy.f32
+// CHECK: @dx.op.waveActiveOp.f32
+// CHECK-NOT: @dx.op.waveActiveOp.f32
+// CHECK: br
+// CHECK: ret void
+
+RWStructuredBuffer<float> u0;
+uint C;
+float f;
+[numthreads(64,1,1)]
+void main(uint GI : SV_GroupIndex)
+{
+    float r = 0;
+    [unroll]
+    for (int i = 0; i < C && i < 64; ++i) {
+        r += WaveActiveSum(f);
+        if (i > 4) {
+          r *= 2;
+          break;
+        }
+    }
+
+    u0[GI] = r;
+}