Jelajahi Sumber

Merge DivergenceAnalysis from llvm3.8. (#528)

Xiang Li 8 tahun lalu
induk
melakukan
e5c0e5ffaa

+ 48 - 0
include/llvm/Analysis/DivergenceAnalysis.h

@@ -0,0 +1,48 @@
+//===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// The divergence analysis is an LLVM pass which can be used to find out
+// if a branch instruction in a GPU program is divergent or not. It can help
+// branch optimizations such as jump threading and loop unswitching to make
+// better decisions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
+#define LLVM_ANALYSIS_DIVERGENCEANALYSIS_H
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+class Value;
+
+/// Function pass that records which values of a function are divergent and
+/// answers divergence queries about them.
+///
+/// The set of divergent values is kept in DivergentValues; any value that is
+/// not in that set is reported as uniform.
+class DivergenceAnalysis : public FunctionPass {
+public:
+  /// Pass identification.
+  static char ID;
+
+  /// Constructs the pass and registers it with the global pass registry.
+  DivergenceAnalysis() : FunctionPass(ID) {
+    initializeDivergenceAnalysisPass(*PassRegistry::getPassRegistry());
+  }
+
+  /// Declares the analyses this pass requires and preserves.
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  /// Runs the analysis on \p F.
+  bool runOnFunction(Function &F) override;
+
+  /// Print all divergent branches in the function.
+  void print(raw_ostream &OS, const Module *) const override;
+
+  /// Returns true if \p V is divergent, i.e. it is present in the set of
+  /// divergent values.
+  bool isDivergent(const Value *V) const {
+    // count() yields a size; compare explicitly instead of relying on the
+    // implicit integer-to-bool conversion.
+    return DivergentValues.count(V) != 0;
+  }
+
+  /// Returns true if \p V is uniform/non-divergent.
+  bool isUniform(const Value *V) const { return !isDivergent(V); }
+
+private:
+  /// Stores all divergent values.
+  DenseSet<const Value *> DivergentValues;
+};
+} // End llvm namespace
+
+#endif // LLVM_ANALYSIS_DIVERGENCEANALYSIS_H

+ 47 - 63
lib/Analysis/DivergenceAnalysis.cpp

@@ -1,4 +1,4 @@
-//===- DivergenceAnalysis.cpp ------ Divergence Analysis ------------------===//
+//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
 //
 //
 //                     The LLVM Compiler Infrastructure
 //                     The LLVM Compiler Infrastructure
 //
 //
@@ -7,8 +7,8 @@
 //
 //
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
 //
 //
-// This file defines divergence analysis which determines whether a branch in a
-// GPU program is divergent. It can help branch optimizations such as jump
+// This file implements divergence analysis which determines whether a branch
+// in a GPU program is divergent. It can help branch optimizations such as jump
 // threading and loop unswitching to make better decisions.
 // threading and loop unswitching to make better decisions.
 //
 //
 // GPU programs typically use the SIMD execution model, where multiple threads
 // GPU programs typically use the SIMD execution model, where multiple threads
@@ -61,75 +61,31 @@
 // 2. memory as black box. It conservatively considers values loaded from
 // 2. memory as black box. It conservatively considers values loaded from
 //    generic or local address as divergent. This can be improved by leveraging
 //    generic or local address as divergent. This can be improved by leveraging
 //    pointer analysis.
 //    pointer analysis.
+//
 //===----------------------------------------------------------------------===//
 //===----------------------------------------------------------------------===//
 
 
-#include <vector>
-#include "llvm/IR/Dominators.h"
-#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
 #include "llvm/Analysis/Passes.h"
 #include "llvm/Analysis/Passes.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/IR/Function.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/Value.h"
-#include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Scalar.h"
+#include <vector>
 using namespace llvm;
 using namespace llvm;
 
 
-#define DEBUG_TYPE "divergence"
-
-namespace {
-class DivergenceAnalysis : public FunctionPass {
-public:
-  static char ID;
-
-  DivergenceAnalysis() : FunctionPass(ID) {
-    initializeDivergenceAnalysisPass(*PassRegistry::getPassRegistry());
-  }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<DominatorTreeWrapperPass>();
-    AU.addRequired<PostDominatorTree>();
-    AU.setPreservesAll();
-  }
-
-  bool runOnFunction(Function &F) override;
-
-  // Print all divergent branches in the function.
-  void print(raw_ostream &OS, const Module *) const override;
-
-  // Returns true if V is divergent.
-  bool isDivergent(const Value *V) const { return DivergentValues.count(V); }
-  // Returns true if V is uniform/non-divergent.
-  bool isUniform(const Value *V) const { return !isDivergent(V); }
-
-private:
-  // Stores all divergent values.
-  DenseSet<const Value *> DivergentValues;
-};
-} // End of anonymous namespace
-
-// Register this pass.
-char DivergenceAnalysis::ID = 0;
-INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
-                      false, true)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
-INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
-                    false, true)
-
 namespace {
 namespace {
 
 
 class DivergencePropagator {
 class DivergencePropagator {
 public:
 public:
-  DivergencePropagator(Function &F, TargetTransformInfo &TTI,
-                       DominatorTree &DT, PostDominatorTree &PDT,
-                       DenseSet<const Value *> &DV)
+  DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
+                       PostDominatorTree &PDT, DenseSet<const Value *> &DV)
       : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {}
       : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {}
   void populateWithSourcesOfDivergence();
   void populateWithSourcesOfDivergence();
   void propagate();
   void propagate();
@@ -140,7 +96,7 @@ private:
   // A helper function that explores sync dependents of TI.
   // A helper function that explores sync dependents of TI.
   void exploreSyncDependency(TerminatorInst *TI);
   void exploreSyncDependency(TerminatorInst *TI);
   // Computes the influence region from Start to End. This region includes all
   // Computes the influence region from Start to End. This region includes all
-  // basic blocks on any path from Start to End.
+  // basic blocks on any simple path from Start to End.
   void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
   void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
                               DenseSet<BasicBlock *> &InfluenceRegion);
                               DenseSet<BasicBlock *> &InfluenceRegion);
   // Finds all users of I that are outside the influence region, and add these
   // Finds all users of I that are outside the influence region, and add these
@@ -153,7 +109,7 @@ private:
   DominatorTree &DT;
   DominatorTree &DT;
   PostDominatorTree &PDT;
   PostDominatorTree &PDT;
   std::vector<Value *> Worklist; // Stack for DFS.
   std::vector<Value *> Worklist; // Stack for DFS.
-  DenseSet<const Value *> &DV; // Stores all divergent values.
+  DenseSet<const Value *> &DV;   // Stores all divergent values.
 };
 };
 
 
 void DivergencePropagator::populateWithSourcesOfDivergence() {
 void DivergencePropagator::populateWithSourcesOfDivergence() {
@@ -165,6 +121,7 @@ void DivergencePropagator::populateWithSourcesOfDivergence() {
       DV.insert(&I);
       DV.insert(&I);
     }
     }
   }
   }
+
   for (auto &Arg : F.args()) {
   for (auto &Arg : F.args()) {
     if (TTI.isSourceOfDivergence(&Arg)) {
     if (TTI.isSourceOfDivergence(&Arg)) {
       Worklist.push_back(&Arg);
       Worklist.push_back(&Arg);
@@ -191,8 +148,8 @@ void DivergencePropagator::exploreSyncDependency(TerminatorInst *TI) {
   for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
   for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
     // A PHINode is uniform if it returns the same value no matter which path is
     // A PHINode is uniform if it returns the same value no matter which path is
     // taken.
     // taken.
-    if (!cast<PHINode>(I)->hasConstantValue() && DV.insert(I).second)
-      Worklist.push_back(I);
+    if (!cast<PHINode>(I)->hasConstantValue() && DV.insert(&*I).second)
+      Worklist.push_back(&*I);
   }
   }
 
 
   // Propagation rule 2: if a value defined in a loop is used outside, the user
   // Propagation rule 2: if a value defined in a loop is used outside, the user
@@ -242,21 +199,33 @@ void DivergencePropagator::findUsersOutsideInfluenceRegion(
   }
   }
 }
 }
 
 
+// Helper for computeInfluenceRegion. Visits every successor of "ThisBB" and,
+// for each one that is not "End" and is not yet part of "InfluenceRegion",
+// adds it to the region and pushes it onto "InfluenceStack" so that it gets
+// explored later.
+static void
+addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End,
+                               DenseSet<BasicBlock *> &InfluenceRegion,
+                               std::vector<BasicBlock *> &InfluenceStack) {
+  for (BasicBlock *Next : successors(ThisBB)) {
+    // "End" itself is never added to the region.
+    if (Next == End)
+      continue;
+    // insert() reports whether the block was newly added; blocks that are
+    // already in the region must not be queued a second time.
+    const bool NewlyAdded = InfluenceRegion.insert(Next).second;
+    if (NewlyAdded)
+      InfluenceStack.push_back(Next);
+  }
+}
+
 void DivergencePropagator::computeInfluenceRegion(
 void DivergencePropagator::computeInfluenceRegion(
     BasicBlock *Start, BasicBlock *End,
     BasicBlock *Start, BasicBlock *End,
     DenseSet<BasicBlock *> &InfluenceRegion) {
     DenseSet<BasicBlock *> &InfluenceRegion) {
   assert(PDT.properlyDominates(End, Start) &&
   assert(PDT.properlyDominates(End, Start) &&
          "End does not properly dominate Start");
          "End does not properly dominate Start");
+
+  // The influence region starts from the end of "Start" to the beginning of
+  // "End". Therefore, "Start" should not be in the region unless "Start" is in
+  // a loop that doesn't contain "End".
   std::vector<BasicBlock *> InfluenceStack;
   std::vector<BasicBlock *> InfluenceStack;
-  InfluenceStack.push_back(Start);
-  InfluenceRegion.insert(Start);
+  addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack);
   while (!InfluenceStack.empty()) {
   while (!InfluenceStack.empty()) {
     BasicBlock *BB = InfluenceStack.back();
     BasicBlock *BB = InfluenceStack.back();
     InfluenceStack.pop_back();
     InfluenceStack.pop_back();
-    for (BasicBlock *Succ : successors(BB)) {
-      if (End != Succ && InfluenceRegion.insert(Succ).second)
-        InfluenceStack.push_back(Succ);
-    }
+    addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack);
   }
   }
 }
 }
 
 
@@ -286,10 +255,25 @@ void DivergencePropagator::propagate() {
 
 
 } /// end namespace anonymous
 } /// end namespace anonymous
 
 
+// Register this pass.
+char DivergenceAnalysis::ID = 0;
+INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                      false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
+INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                    false, true)
+
 FunctionPass *llvm::createDivergenceAnalysisPass() {
 FunctionPass *llvm::createDivergenceAnalysisPass() {
   return new DivergenceAnalysis();
   return new DivergenceAnalysis();
 }
 }
 
 
+// Declares the analyses this pass depends on: the dominator tree (through
+// DominatorTreeWrapperPass) and the post-dominator tree. The pass marks all
+// analyses as preserved.
+void DivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<PostDominatorTree>();
+  AU.setPreservesAll();
+}
+
 bool DivergenceAnalysis::runOnFunction(Function &F) {
 bool DivergenceAnalysis::runOnFunction(Function &F) {
   auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
   auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
   if (TTIWP == nullptr)
   if (TTIWP == nullptr)

+ 2 - 0
lib/HLSL/CMakeLists.txt

@@ -33,6 +33,8 @@ add_llvm_library(LLVMHLSL
   DxilShaderModel.cpp
   DxilShaderModel.cpp
   DxilSignature.cpp
   DxilSignature.cpp
   DxilSignatureElement.cpp
   DxilSignatureElement.cpp
+  DxilTargetLowering.cpp
+  DxilTargetTransformInfo.cpp
   DxilTypeSystem.cpp
   DxilTypeSystem.cpp
   DxilUtil.cpp
   DxilUtil.cpp
   DxilValidation.cpp
   DxilValidation.cpp

+ 34 - 0
lib/HLSL/DxilTargetLowering.cpp

@@ -0,0 +1,34 @@
+//===-- DxilTargetLowering.cpp - Implement the DxilTargetLowering class ---===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Empty implementation of TargetLoweringBase::InstructionOpcodeToISD and
+// TargetLoweringBase::getTypeLegalizationCost to make TargetTransformInfo
+// compile.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Target/TargetLowering.h"
+
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+//  TargetTransformInfo Helpers
+//===----------------------------------------------------------------------===//
+
+// Stub: the opcode is ignored and 0 is returned for every input.
+int TargetLoweringBase::InstructionOpcodeToISD(unsigned Opcode) const {
+  (void)Opcode; // Intentionally unused in this stub.
+  const int StubResult = 0;
+  return StubResult;
+}
+
+// Stub: reports a fixed cost of 1 for every type, paired with the simple value
+// type that getValueType() yields for Ty under the data layout DL.
+std::pair<unsigned, MVT>
+TargetLoweringBase::getTypeLegalizationCost(const DataLayout &DL,
+                                            Type *Ty) const {
+  const unsigned UnitCost = 1;
+  const EVT ValueTy = getValueType(DL, Ty);
+  return std::make_pair(UnitCost, ValueTy.getSimpleVT());
+}

+ 96 - 0
lib/HLSL/DxilTargetTransformInfo.cpp

@@ -0,0 +1,96 @@
+//===-- DxilTargetTransformInfo.cpp - DXIL specific TTI pass     ----------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This file implements a TargetTransformInfo analysis pass specific to
+// DXIL. Only isSourceOfDivergence is implemented, for use by
+// DivergenceAnalysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "DxilTargetTransformInfo.h"
+#include "dxc/HLSL/DxilModule.h"
+#include "dxc/HLSL/DxilOperations.h"
+#include "llvm/CodeGen/BasicTTIImpl.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+#define DEBUG_TYPE "DXILtti"
+
+// For BasicTTIImpl
+cl::opt<unsigned>
+    llvm::PartialUnrollingThreshold("partial-unrolling-threshold", cl::init(0),
+                                    cl::desc("Threshold for partial unrolling"),
+                                    cl::Hidden);
+
+// Builds the DXIL TTI implementation for function F. The base class receives
+// TM together with the data layout of F's parent module; the OP object of DM
+// and the ThreadGroup flag are stored for later use by isSourceOfDivergence.
+DxilTTIImpl::DxilTTIImpl(const TargetMachine *TM, const Function &F,
+                         hlsl::DxilModule &DM, bool ThreadGroup)
+    : BaseT(TM, F.getParent()->getDataLayout()), m_pHlslOP(DM.GetOP()),
+      m_isThreadGroup(ThreadGroup) {}
+
+namespace {
+// Classifies the DXIL operation called by CI: returns true when that
+// operation is treated as a source of divergence and false otherwise.
+// hlslOP is used to map the call to its DXIL opcode. ThreadGroup only
+// affects the answer for GroupId.
+bool IsDxilOpSourceOfDivergence(const CallInst *CI, OP *hlslOP,
+                                bool ThreadGroup) {
+  const DXIL::OpCode Op = hlslOP->GetDxilOpFuncCallInst(CI);
+  switch (Op) {
+  // GroupId is the only opcode whose answer depends on the ThreadGroup flag:
+  // it is reported as divergent only when the flag is not set.
+  case DXIL::OpCode::GroupId:
+    return !ThreadGroup;
+
+  // Opcodes that are always reported as divergent.
+  case DXIL::OpCode::AtomicBinOp:
+  case DXIL::OpCode::AtomicCompareExchange:
+  case DXIL::OpCode::BufferUpdateCounter:
+  case DXIL::OpCode::Coverage:
+  case DXIL::OpCode::CycleCounterLegacy:
+  case DXIL::OpCode::DomainLocation:
+  case DXIL::OpCode::EvalCentroid:
+  case DXIL::OpCode::EvalSampleIndex:
+  case DXIL::OpCode::EvalSnapped:
+  case DXIL::OpCode::FlattenedThreadIdInGroup:
+  case DXIL::OpCode::GSInstanceID:
+  case DXIL::OpCode::InnerCoverage:
+  case DXIL::OpCode::LoadInput:
+  case DXIL::OpCode::LoadOutputControlPoint:
+  case DXIL::OpCode::LoadPatchConstant:
+  case DXIL::OpCode::OutputControlPointID:
+  case DXIL::OpCode::PrimitiveID:
+  case DXIL::OpCode::RenderTargetGetSampleCount:
+  case DXIL::OpCode::RenderTargetGetSamplePosition:
+  case DXIL::OpCode::ThreadId:
+  case DXIL::OpCode::ThreadIdInGroup:
+    return true;
+
+  // Every other opcode is reported as not divergent.
+  default:
+    return false;
+  }
+}
+} // end anonymous namespace
+
+/// Reports whether \p V is treated as a source of divergence.
+///
+/// \returns true if the result of the value could potentially be
+/// different across dispatch or thread group. Arguments, atomic
+/// instructions and calls that are not DXIL operations always yield true;
+/// calls to DXIL operations are classified by opcode; every other value
+/// yields false.
+bool DxilTTIImpl::isSourceOfDivergence(const Value *V) const {
+  // Every function argument is reported as divergent. Only the type test is
+  // needed, so use isa<> rather than binding an unused dyn_cast<> result.
+  if (isa<Argument>(V))
+    return true;
+
+  // Atomics are divergent because they are executed sequentially: when an
+  // atomic operation refers to the same address in each thread, then each
+  // thread after the first sees the value written by the previous thread as
+  // original value.
+  if (isa<AtomicRMWInst>(V) || isa<AtomicCmpXchgInst>(V))
+    return true;
+
+  if (const CallInst *CI = dyn_cast<CallInst>(V)) {
+    // Assume that any call which is not a DXIL intrinsic function call is a
+    // source of divergence.
+    if (!m_pHlslOP->IsDxilOpFuncCallInst(CI))
+      return true;
+    // DXIL operations are classified individually by opcode.
+    return IsDxilOpSourceOfDivergence(CI, m_pHlslOP, m_isThreadGroup);
+  }
+
+  return false;
+}

+ 43 - 0
lib/HLSL/DxilTargetTransformInfo.h

@@ -0,0 +1,43 @@
+//===-- DxilTargetTransformInfo.h - DXIL specific TTI -------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file declares a TargetTransformInfo analysis pass specific to DXIL.
+/// Only isSourceOfDivergence is implemented, for use by DivergenceAnalysis.
+///
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "llvm/CodeGen/BasicTTIImpl.h"
+
+namespace hlsl {
+class DxilModule;
+class OP;
+}
+
+namespace llvm {
+
+/// TTI implementation for DXIL, layered on BasicTTIImplBase. It reports that
+/// branch divergence exists and provides isSourceOfDivergence.
+class DxilTTIImpl final : public BasicTTIImplBase<DxilTTIImpl> {
+  using BaseT = BasicTTIImplBase<DxilTTIImpl>;
+  using TTI = TargetTransformInfo;
+  friend BaseT;
+
+  /// DXIL operation helper taken from the DxilModule given to the
+  /// constructor.
+  hlsl::OP *m_pHlslOP;
+  /// Copy of the ThreadGroup flag given to the constructor.
+  bool m_isThreadGroup;
+
+  /// No subtarget information is provided; always returns nullptr.
+  const TargetSubtargetInfo *getST() const { return nullptr; }
+  /// No target lowering is provided; always returns nullptr.
+  const TargetLowering *getTLI() const { return nullptr; }
+
+public:
+  explicit DxilTTIImpl(const TargetMachine *TM, const Function &F,
+                       hlsl::DxilModule &DM, bool ThreadGroup);
+
+  /// Always true.
+  bool hasBranchDivergence() { return true; }
+
+  /// Returns true if \p V is treated as a source of divergence.
+  bool isSourceOfDivergence(const Value *V) const;
+};
+
+} // end namespace llvm