Kaynağa Gözat

Add a simple GVNHoist to reduce dxil code size. (#1897)

* Add a simple GVNHoist to reduce dxil code size.

* Add res_may_alias option.
Xiang Li 6 yıl önce
ebeveyn
işleme
27a7bf4b0b

+ 2 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -66,6 +66,7 @@ ModulePass *createDxilPromoteStaticResources();
 ModulePass *createDxilLegalizeResources();
 ModulePass *createDxilLegalizeEvalOperationsPass();
 FunctionPass *createDxilLegalizeSampleOffsetPass();
+FunctionPass *createDxilSimpleGVNHoistPass();
 ModulePass *createFailUndefResourcePass();
 FunctionPass *createSimplifyInstPass();
 ModulePass *createDxilTranslateRawBuffer();
@@ -96,6 +97,7 @@ void initializeDxilPromoteStaticResourcesPass(llvm::PassRegistry&);
 void initializeDxilLegalizeResourcesPass(llvm::PassRegistry&);
 void initializeDxilLegalizeEvalOperationsPass(llvm::PassRegistry&);
 void initializeDxilLegalizeSampleOffsetPassPass(llvm::PassRegistry&);
+void initializeDxilSimpleGVNHoistPass(llvm::PassRegistry&);
 void initializeFailUndefResourcePass(llvm::PassRegistry&);
 void initializeSimplifyInstPass(llvm::PassRegistry&);
 void initializeDxilTranslateRawBufferPass(llvm::PassRegistry&);

+ 1 - 0
include/dxc/Support/HLSLOptions.h

@@ -162,6 +162,7 @@ public:
   bool LegacyResourceReservation = false; // OPT_flegacy_resource_reservation
   unsigned long AutoBindingSpace = UINT_MAX; // OPT_auto_binding_space
   bool ExportShadersOnly = false; // OPT_export_shaders_only
+  bool ResMayAlias = false; // OPT_res_may_alias
 
   bool IsRootSignatureProfile();
   bool IsLibraryProfile();

+ 2 - 2
include/dxc/Support/HLSLOptions.td

@@ -365,11 +365,11 @@ def mergeUAVs : JoinedOrSeparate<["-", "/"], "mergeUAVs">, MetaVarName<"<file>">
   HelpText<"Merge UAV slots of template shader and current shader">;
 def matchUAVs : JoinedOrSeparate<["-", "/"], "matchUAVs">, MetaVarName<"<file>">, Group<hlslcomp_Group>,
   HelpText<"Match template shader UAV slots in current shader">;
-def res_may_alias : Flag<["-", "/"], "res_may_alias">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
-  HelpText<"Assume that UAVs/SRVs may alias">;
 def enable_unbounded_descriptor_tables : Flag<["-", "/"], "enable_unbounded_descriptor_tables">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
   HelpText<"Enables unbounded descriptor tables">;
 */
+def res_may_alias : Flag<["-", "/"], "res_may_alias">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
+  HelpText<"Assume that UAVs/SRVs may alias">;
 def all_resources_bound : Flag<["-", "/"], "all_resources_bound">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
   HelpText<"Enables agressive flattening">;
 

+ 1 - 0
include/llvm/Transforms/IPO/PassManagerBuilder.h

@@ -128,6 +128,7 @@ public:
   bool PrepareForLTO;
   bool HLSLHighLevel = false; // HLSL Change
   hlsl::HLSLExtensionsCodegenHelper *HLSLExtensionsCodeGen = nullptr; // HLSL Change
+  bool HLSLResMayAlias = false; // HLSL Change
 
 private:
   /// ExtensionList - This is list of all of the extensions that are registered.

+ 1 - 0
lib/DxcSupport/HLSLOptions.cpp

@@ -544,6 +544,7 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
   opts.LegacyMacroExpansion = Args.hasFlag(OPT_flegacy_macro_expansion, OPT_INVALID, false);
   opts.LegacyResourceReservation = Args.hasFlag(OPT_flegacy_resource_reservation, OPT_INVALID, false);
   opts.ExportShadersOnly = Args.hasFlag(OPT_export_shaders_only, OPT_INVALID, false);
+  opts.ResMayAlias = Args.hasFlag(OPT_res_may_alias, OPT_INVALID, false);
 
   if (opts.DefaultColMajor && opts.DefaultRowMajor) {
     errors << "Cannot specify /Zpr and /Zpc together, use /? to get usage information";

+ 1 - 0
lib/HLSL/CMakeLists.txt

@@ -16,6 +16,7 @@ add_llvm_library(LLVMHLSL
   DxilPackSignatureElement.cpp
   DxilPatchShaderRecordBindings.cpp
   DxilPreserveAllOutputs.cpp
+  DxilSimpleGVNHoist.cpp
   DxilSignatureValidation.cpp
   DxilTargetLowering.cpp
   DxilTargetTransformInfo.cpp

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -104,6 +104,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDxilPreserveAllOutputsPass(Registry);
     initializeDxilPromoteLocalResourcesPass(Registry);
     initializeDxilPromoteStaticResourcesPass(Registry);
+    initializeDxilSimpleGVNHoistPass(Registry);
     initializeDxilTranslateRawBufferPass(Registry);
     initializeDynamicIndexingVectorToArrayPass(Registry);
     initializeEarlyCSELegacyPassPass(Registry);

+ 566 - 0
lib/HLSL/DxilSimpleGVNHoist.cpp

@@ -0,0 +1,566 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// DxilSimpleGVNHoist.cpp                                                    //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// A simple version of GVN hoist for DXIL.                                   //
+// Based on GVNHoist in LLVM 6.0.                                            //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/DxilGenerationPass.h"
+#include "dxc/DXIL/DxilOperations.h"
+
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/CFG.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+///////////////////////////////////////////////////////////////////////////////
+namespace {
+/// Hashable description of a computation, used as the key when assigning
+/// value numbers: two values whose Expressions compare equal receive the same
+/// value number.
+struct Expression {
+  /// Opcode of the computation. ~0U and ~1U are reserved for the DenseMap
+  /// empty/tombstone keys; ~2U is the default "not filled in yet" value.
+  uint32_t opcode;
+  /// Result type of the computation. Null-initialized so that copying a
+  /// sentinel or default-constructed Expression never reads an indeterminate
+  /// pointer.
+  Type *type = nullptr;
+  /// Set when the operand order was canonicalized. It takes no part in
+  /// equality or hashing.
+  bool commutative = false;
+  /// Value numbers of the operands, optionally followed by literal indices.
+  SmallVector<uint32_t, 4> varargs;
+
+  // Deliberately implicit: DenseMapInfo<Expression> creates its sentinel keys
+  // directly from a raw opcode value.
+  Expression(uint32_t o = ~2U) : opcode(o) {}
+
+  /// Two expressions are equal when opcode, type and operand list all match.
+  /// The sentinel opcodes compare equal on the opcode alone.
+  bool operator==(const Expression &other) const {
+    if (opcode != other.opcode)
+      return false;
+    // Empty/tombstone keys carry no meaningful type or operands.
+    if (opcode == ~0U || opcode == ~1U)
+      return true;
+    if (type != other.type)
+      return false;
+    if (varargs != other.varargs)
+      return false;
+    return true;
+  }
+
+  /// Hash over exactly the fields that operator== compares.
+  friend hash_code hash_value(const Expression &Value) {
+    return hash_combine(
+        Value.opcode, Value.type,
+        hash_combine_range(Value.varargs.begin(), Value.varargs.end()));
+  }
+};
+
+}
+
+namespace llvm {
+/// DenseMap traits for Expression keys: two reserved sentinel opcodes, with
+/// hashing and equality forwarded to Expression itself.
+template <> struct DenseMapInfo<Expression> {
+  // Opcode ~0U is reserved for the empty key.
+  static inline Expression getEmptyKey() { return Expression(~0U); }
+  // Opcode ~1U is reserved for the tombstone key.
+  static inline Expression getTombstoneKey() { return Expression(~1U); }
+
+  static unsigned getHashValue(const Expression &Key) {
+    using llvm::hash_value;
+
+    return static_cast<unsigned>(hash_value(Key));
+  }
+
+  static bool isEqual(const Expression &A, const Expression &B) {
+    return A == B;
+  }
+};
+} // namespace llvm
+
+namespace {
+// Simple value table which supports DXIL operations.
+//
+// Maps IR values to value numbers. Values that are given the same Expression
+// share one number; everything else receives a number of its own.
+class ValueTable {
+  // Value -> value number.
+  DenseMap<Value *, uint32_t> valueNumbering;
+  // Expression -> value number. 0 means "no number assigned yet".
+  DenseMap<Expression, uint32_t> expressionNumbering;
+
+  // Expressions is the vector of Expression. ExprIdx is the mapping from
+  // value number to the index of Expression in Expressions. We use it
+  // instead of a DenseMap because filling such mapping is faster than
+  // filling a DenseMap and the compile time is a little better.
+  //
+  // Starts at 0 so that a freshly constructed table is usable without a
+  // preceding clear().
+  uint32_t nextExprNumber = 0;
+
+  std::vector<Expression> Expressions;
+  std::vector<uint32_t> ExprIdx;
+
+  // Optional dominator tree. In the code shown here it is only stored by
+  // setDomTree(); null until then.
+  DominatorTree *DT = nullptr;
+
+  // Value numbers start at 1; 0 is what lookup() returns for "not numbered".
+  uint32_t nextValueNumber = 1;
+
+  Expression createExpr(Instruction *I);
+  Expression createCmpExpr(unsigned Opcode, CmpInst::Predicate Predicate,
+                           Value *LHS, Value *RHS);
+  Expression createExtractvalueExpr(ExtractValueInst *EI);
+  uint32_t lookupOrAddCall(CallInst *C);
+
+  std::pair<uint32_t, bool> assignExpNewValueNum(Expression &exp);
+
+public:
+  ValueTable();
+  ValueTable(const ValueTable &Arg);
+  ValueTable(ValueTable &&Arg);
+  ~ValueTable();
+
+  uint32_t lookupOrAdd(Value *V);
+  uint32_t lookup(Value *V, bool Verify = true) const;
+  uint32_t lookupOrAddCmp(unsigned Opcode, CmpInst::Predicate Pred, Value *LHS,
+                          Value *RHS);
+  bool exists(Value *V) const;
+  void add(Value *V, uint32_t num);
+  void clear();
+  void erase(Value *v);
+  void setDomTree(DominatorTree *D) { DT = D; }
+  uint32_t getNextUnusedValueNumber() { return nextValueNumber; }
+  void verifyRemoved(const Value *) const;
+};
+
+//===----------------------------------------------------------------------===//
+//                     ValueTable Internal Functions
+//===----------------------------------------------------------------------===//
+
+/// Build the Expression of a generic instruction from its result type, its
+/// opcode and the value numbers of all operands. Commutative instructions and
+/// compares get their two operand numbers sorted; insertvalue additionally
+/// appends its literal indices.
+Expression ValueTable::createExpr(Instruction *I) {
+  Expression Result;
+  Result.type = I->getType();
+  Result.opcode = I->getOpcode();
+  for (Value *Operand : I->operands())
+    Result.varargs.push_back(lookupOrAdd(Operand));
+
+  if (I->isCommutative()) {
+    // Instructions that differ only by a permutation of their operands must
+    // end up with the same value number. Every commutative instruction has
+    // exactly two operands, so one conditional swap sorts them.
+    assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!");
+    if (Result.varargs[0] > Result.varargs[1])
+      std::swap(Result.varargs[0], Result.varargs[1]);
+    Result.commutative = true;
+  }
+
+  if (CmpInst *Cmp = dyn_cast<CmpInst>(I)) {
+    // Sort the operand numbers and swap the predicate along with them, so
+    // that x<y and y>x get the same value number.
+    CmpInst::Predicate Pred = Cmp->getPredicate();
+    if (Result.varargs[0] > Result.varargs[1]) {
+      std::swap(Result.varargs[0], Result.varargs[1]);
+      Pred = CmpInst::getSwappedPredicate(Pred);
+    }
+    // The predicate is folded into the low 8 bits of the opcode.
+    Result.opcode = (Cmp->getOpcode() << 8) | Pred;
+    Result.commutative = true;
+    return Result;
+  }
+
+  if (InsertValueInst *Insert = dyn_cast<InsertValueInst>(I)) {
+    // The aggregate indices are literals, not operands; append them so that
+    // inserts at different positions stay distinct.
+    for (auto Idx = Insert->idx_begin(), IdxEnd = Insert->idx_end();
+         Idx != IdxEnd; ++Idx)
+      Result.varargs.push_back(*Idx);
+  }
+
+  return Result;
+}
+
+/// Build the Expression of a comparison given by opcode, predicate and the
+/// two operands, without needing an actual CmpInst. The operand numbers are
+/// sorted and the predicate swapped accordingly.
+Expression ValueTable::createCmpExpr(unsigned Opcode,
+                                     CmpInst::Predicate Predicate, Value *LHS,
+                                     Value *RHS) {
+  assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
+         "Not a comparison!");
+  Expression Cmp;
+  Cmp.type = CmpInst::makeCmpResultType(LHS->getType());
+
+  const uint32_t LHSNum = lookupOrAdd(LHS);
+  const uint32_t RHSNum = lookupOrAdd(RHS);
+  if (LHSNum > RHSNum) {
+    // Store the smaller number first so x<y and y>x get the same value
+    // number.
+    Cmp.varargs.push_back(RHSNum);
+    Cmp.varargs.push_back(LHSNum);
+    Predicate = CmpInst::getSwappedPredicate(Predicate);
+  } else {
+    Cmp.varargs.push_back(LHSNum);
+    Cmp.varargs.push_back(RHSNum);
+  }
+
+  // The predicate is folded into the low 8 bits of the opcode.
+  Cmp.opcode = (Opcode << 8) | Predicate;
+  Cmp.commutative = true;
+  return Cmp;
+}
+
+/// Build the Expression of an extractvalue. Extracting element 0 from one of
+/// the recognised *_with_overflow intrinsics is expressed as the matching
+/// plain add/sub/mul on the intrinsic's two arguments; every other extract is
+/// described by its operands followed by its literal indices.
+Expression ValueTable::createExtractvalueExpr(ExtractValueInst *EI) {
+  assert(EI && "Not an ExtractValueInst?");
+  Expression Result;
+  Result.type = EI->getType();
+  Result.opcode = 0;
+
+  IntrinsicInst *I = dyn_cast<IntrinsicInst>(EI->getAggregateOperand());
+  const bool bFirstFieldOfIntrinsic =
+      I != nullptr && EI->getNumIndices() == 1 && *EI->idx_begin() == 0;
+  if (bFirstFieldOfIntrinsic) {
+    // EI might extract from one of the recognised intrinsics. If so, a
+    // semantically equivalent arithmetic expression is synthesized instead of
+    // an extractvalue expression.
+    switch (I->getIntrinsicID()) {
+    case Intrinsic::sadd_with_overflow:
+    case Intrinsic::uadd_with_overflow:
+      Result.opcode = Instruction::Add;
+      break;
+    case Intrinsic::ssub_with_overflow:
+    case Intrinsic::usub_with_overflow:
+      Result.opcode = Instruction::Sub;
+      break;
+    case Intrinsic::smul_with_overflow:
+    case Intrinsic::umul_with_overflow:
+      Result.opcode = Instruction::Mul;
+      break;
+    default:
+      break;
+    }
+
+    if (Result.opcode != 0) {
+      // Intrinsic recognised: its two arguments complete the expression.
+      assert(I->getNumArgOperands() == 2 &&
+             "Expect two args for recognised intrinsics.");
+      for (unsigned ArgNo = 0; ArgNo != 2; ++ArgNo)
+        Result.varargs.push_back(lookupOrAdd(I->getArgOperand(ArgNo)));
+      return Result;
+    }
+  }
+
+  // Not a recognised intrinsic: fall back to a plain extractvalue expression.
+  Result.opcode = EI->getOpcode();
+  for (Value *Operand : EI->operands())
+    Result.varargs.push_back(lookupOrAdd(Operand));
+
+  for (auto Idx = EI->idx_begin(), IdxEnd = EI->idx_end(); Idx != IdxEnd;
+       ++Idx)
+    Result.varargs.push_back(*Idx);
+
+  return Result;
+}
+
+//===----------------------------------------------------------------------===//
+//                     ValueTable External Functions
+//===----------------------------------------------------------------------===//
+
+// All special members are compiler-generated; they are defaulted here, out of
+// line, rather than inside the class definition.
+ValueTable::ValueTable() = default;
+ValueTable::ValueTable(const ValueTable &) = default;
+ValueTable::ValueTable(ValueTable &&) = default;
+ValueTable::~ValueTable() = default;
+
+/// add - Record \p num as the value number of \p V. This goes through
+/// DenseMap::insert, so a number already recorded for \p V is kept.
+void ValueTable::add(Value *V, uint32_t num) {
+  valueNumbering.insert(std::pair<Value *, uint32_t>(V, num));
+}
+
+/// Assign a value number to the call \p C and return it.
+///
+/// A call is numbered by expression (so identical calls share a number) only
+/// when it is known to be safe: the callee is marked ReadNone, or it is
+/// marked ReadOnly and is one of the DXIL operations listed below. Every
+/// other call, including an indirect call whose callee is unknown, receives a
+/// number of its own and is therefore never merged.
+uint32_t ValueTable::lookupOrAddCall(CallInst *C) {
+  // Null for indirect calls; such calls must be treated as unsafe instead of
+  // dereferencing the missing callee.
+  Function *F = C->getCalledFunction();
+  bool bSafe = false;
+  if (F == nullptr) {
+    // Unknown callee: nothing can be assumed about its memory behavior.
+  } else if (F->hasFnAttribute(Attribute::ReadNone)) {
+    bSafe = true;
+  } else if (F->hasFnAttribute(Attribute::ReadOnly)) {
+    if (hlsl::OP::IsDxilOpFunc(F)) {
+      DXIL::OpCode Opcode = hlsl::OP::GetDxilOpFuncCallInst(C);
+      switch (Opcode) {
+      default:
+        break;
+        // TODO: make buffer/texture load on srv safe.
+      case DXIL::OpCode::CreateHandleForLib:
+      case DXIL::OpCode::CBufferLoad:
+      case DXIL::OpCode::CBufferLoadLegacy:
+      case DXIL::OpCode::Sample:
+      case DXIL::OpCode::SampleBias:
+      case DXIL::OpCode::SampleCmp:
+      case DXIL::OpCode::SampleCmpLevelZero:
+      case DXIL::OpCode::SampleGrad:
+      case DXIL::OpCode::CheckAccessFullyMapped:
+      case DXIL::OpCode::GetDimensions:
+      case DXIL::OpCode::TextureGather:
+      case DXIL::OpCode::TextureGatherCmp:
+      case DXIL::OpCode::Texture2DMSGetSamplePosition:
+      case DXIL::OpCode::RenderTargetGetSampleCount:
+      case DXIL::OpCode::RenderTargetGetSamplePosition:
+      case DXIL::OpCode::CalculateLOD:
+        bSafe = true;
+        break;
+      }
+    }
+  }
+  if (bSafe) {
+    // Number by expression so that identical calls share a value number.
+    Expression exp = createExpr(C);
+    uint32_t e = assignExpNewValueNum(exp).first;
+    valueNumbering[C] = e;
+    return e;
+  } else {
+    // Not sure safe or not, always use new value number.
+    valueNumbering[C] = nextValueNumber;
+    return nextValueNumber++;
+  }
+}
+
+/// Tells whether \p V already has a value number recorded in this table.
+bool ValueTable::exists(Value *V) const {
+  return valueNumbering.find(V) != valueNumbering.end();
+}
+
+/// lookup_or_add - Return the value number of \p V, numbering it first if it
+/// has none yet. Non-instructions, PHIs and every opcode not listed in the
+/// switch receive a number of their own; calls are delegated to
+/// lookupOrAddCall; the listed opcodes are numbered by expression.
+uint32_t ValueTable::lookupOrAdd(Value *V) {
+  auto Known = valueNumbering.find(V);
+  if (Known != valueNumbering.end())
+    return Known->second;
+
+  // Gives V a brand-new number that no other value shares.
+  auto numberAsUnique = [this, V]() {
+    const uint32_t Fresh = nextValueNumber++;
+    valueNumbering[V] = Fresh;
+    return Fresh;
+  };
+
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return numberAsUnique();
+
+  Expression exp;
+  switch (I->getOpcode()) {
+  case Instruction::Call:
+    return lookupOrAddCall(cast<CallInst>(I));
+  case Instruction::ExtractValue:
+    exp = createExtractvalueExpr(cast<ExtractValueInst>(I));
+    break;
+  // Binary arithmetic and logic.
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::FDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  // Comparisons.
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+  // Casts.
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPToUI:
+  case Instruction::FPToSI:
+  case Instruction::UIToFP:
+  case Instruction::SIToFP:
+  case Instruction::FPTrunc:
+  case Instruction::FPExt:
+  case Instruction::PtrToInt:
+  case Instruction::IntToPtr:
+  case Instruction::BitCast:
+  // Select, vector/aggregate manipulation and address computation.
+  case Instruction::Select:
+  case Instruction::ExtractElement:
+  case Instruction::InsertElement:
+  case Instruction::ShuffleVector:
+  case Instruction::InsertValue:
+  case Instruction::GetElementPtr:
+    exp = createExpr(I);
+    break;
+  case Instruction::PHI:
+  default:
+    return numberAsUnique();
+  }
+
+  const uint32_t Num = assignExpNewValueNum(exp).first;
+  valueNumbering[V] = Num;
+  return Num;
+}
+
+/// Return the value number already recorded for \p V. With \p Verify set the
+/// value is expected to be numbered (checked by an assert); otherwise 0 is
+/// returned for a value that has no number.
+uint32_t ValueTable::lookup(Value *V, bool Verify) const {
+  auto VI = valueNumbering.find(V);
+  if (!Verify)
+    return (VI == valueNumbering.end()) ? 0 : VI->second;
+
+  assert(VI != valueNumbering.end() && "Value not numbered?");
+  return VI->second;
+}
+
+/// Return the value number of the comparison described by opcode, predicate
+/// and operands, assigning a new number if that comparison has none yet.
+/// Useful when the result of a comparison was deduced but no instruction
+/// realizing that comparison is at hand.
+uint32_t ValueTable::lookupOrAddCmp(unsigned Opcode,
+                                    CmpInst::Predicate Predicate, Value *LHS,
+                                    Value *RHS) {
+  Expression CmpExpr = createCmpExpr(Opcode, Predicate, LHS, RHS);
+  const std::pair<uint32_t, bool> Numbered = assignExpNewValueNum(CmpExpr);
+  return Numbered.first;
+}
+
+/// Reset the table to its initial state: all maps and vectors are emptied and
+/// both counters return to their starting values.
+void ValueTable::clear() {
+  // Value side.
+  valueNumbering.clear();
+  nextValueNumber = 1;
+
+  // Expression side.
+  expressionNumbering.clear();
+  Expressions.clear();
+  ExprIdx.clear();
+  nextExprNumber = 0;
+}
+
+/// Remove \p V from the value numbering. Only the value -> number map is
+/// touched; the expression tables are left as they are.
+void ValueTable::erase(Value *V) {
+  valueNumbering.erase(V);
+}
+
+/// verifyRemoved - Assert that \p V no longer appears as a key of the value
+/// numbering map. The check consists solely of asserts.
+void ValueTable::verifyRemoved(const Value *V) const {
+  for (auto I = valueNumbering.begin(), E = valueNumbering.end(); I != E;
+       ++I) {
+    assert(I->first != V && "Inst still occurs in value numbering map!");
+  }
+}
+
+/// Return a pair whose first field is the value number of \p Exp and whose
+/// second field tells whether that number was newly created by this call.
+std::pair<uint32_t, bool>
+ValueTable::assignExpNewValueNum(Expression &Exp) {
+  // operator[] default-inserts 0, which is never a valid value number.
+  uint32_t &Slot = expressionNumbering[Exp];
+  if (Slot != 0)
+    return {Slot, false};
+
+  Expressions.push_back(Exp);
+  // Grow the number -> expression index table so that the slot for the new
+  // value number exists.
+  if (ExprIdx.size() < nextValueNumber + 1)
+    ExprIdx.resize(nextValueNumber * 2);
+  Slot = nextValueNumber;
+  ExprIdx[nextValueNumber++] = nextExprNumber++;
+  return {Slot, true};
+}
+
+} // namespace
+
+namespace {
+// Reduce code size for pattern like this:
+// if (a.x > 0) {
+//  r = tex.Sample(ss, uv)-1;
+// } else {
+//  if (a.y > 0)
+//    r = tex.Sample(ss, uv);
+//  else
+//    r = tex.Sample(ss, uv) + 3;
+// }
+//
+// Function pass that looks at every block with exactly two successors, each
+// of which has that block as its only predecessor, and moves computations
+// that occur in both successors in front of the block's terminator.
+class DxilSimpleGVNHoist : public FunctionPass {
+
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilSimpleGVNHoist() : FunctionPass(ID) {}
+
+  // Human-readable pass name.
+  const char *getPassName() const override {
+    return "DXIL simple GVN hoist";
+  }
+
+  // Visits the blocks reachable from the entry block in post order and
+  // returns true when any hoisting attempt reported a change.
+  bool runOnFunction(Function &F) override;
+
+private:
+  // Value-numbers Succ0 and Succ1 and hoists the computations they share
+  // into BB.
+  bool tryToHoist(BasicBlock *BB, BasicBlock *Succ0, BasicBlock *Succ1);
+};
+
+// Storage for the pass identification member declared in the class.
+char DxilSimpleGVNHoist::ID = 0;
+
+// Returns true when the predecessor list of BB holds exactly one entry.
+bool HasOnePred(BasicBlock *BB) {
+  auto PI = pred_begin(BB);
+  const auto PE = pred_end(BB);
+  // No predecessor at all.
+  if (PI == PE)
+    return false;
+
+  // Exactly one entry means that stepping once reaches the end.
+  ++PI;
+  return PI == PE;
+}
+
+/// Hoist computations shared by \p Succ0 and \p Succ1 into \p BB.
+///
+/// Both successors are value-numbered with one table. For every value number
+/// that occurs in both blocks, the first instruction found in Succ0 is moved
+/// in front of BB's terminator and all other instructions with that number
+/// are replaced by it and erased.
+///
+/// A group is left untouched when the instruction to hoist still has an
+/// operand defined in Succ0: moving it into BB would place the use ahead of
+/// its definition.
+///
+/// NOTE(review): value numbers are built from type, opcode and operands only,
+/// so the surviving instruction keeps its own wrapping/exact/fast-math flags
+/// and metadata - confirm this is acceptable when the merged duplicates
+/// differ in those.
+///
+/// \returns true when at least one instruction was moved or erased.
+bool DxilSimpleGVNHoist::tryToHoist(BasicBlock *BB, BasicBlock *Succ0,
+                                    BasicBlock *Succ1) {
+  // ValueNumber Succ0 and Succ1.
+  ValueTable VT;
+  DenseMap<uint32_t, SmallVector<Instruction *, 2>> VNtoInsts;
+  for (Instruction &I : *Succ0) {
+    uint32_t V = VT.lookupOrAdd(&I);
+    VNtoInsts[V].emplace_back(&I);
+  }
+
+  // A value number that shows up in both successors is a hoist candidate.
+  // Candidates are recorded in Succ1's instruction order.
+  std::vector<uint32_t> HoistCandidateVN;
+
+  for (Instruction &I : *Succ1) {
+    uint32_t V = VT.lookupOrAdd(&I);
+    auto Found = VNtoInsts.find(V);
+    if (Found == VNtoInsts.end())
+      continue;
+    Found->second.emplace_back(&I);
+    HoistCandidateVN.emplace_back(V);
+  }
+
+  if (HoistCandidateVN.empty()) {
+    return false;
+  }
+
+  bool bChanged = false;
+  DenseSet<uint32_t> ProcessedVN;
+  Instruction *TI = BB->getTerminator();
+  // Hoist need to be in order, so operand could hoist before its users.
+  for (uint32_t VN : HoistCandidateVN) {
+    // Skip processed VN
+    if (ProcessedVN.count(VN))
+      continue;
+    ProcessedVN.insert(VN);
+
+    auto &Insts = VNtoInsts[VN];
+    if (Insts.size() == 1)
+      continue;
+    bool bHoist = false;
+    for (Instruction *I : Insts) {
+      if (I->getParent() == Succ1) {
+        bHoist = true;
+        break;
+      }
+    }
+
+    // The list was started while scanning Succ0, so its front lives there.
+    Instruction *FirstI = Insts.front();
+    if (bHoist) {
+      // Operands shared with Succ1 were hoisted by an earlier iteration. An
+      // operand still sitting in Succ0 means the Succ1 twin reached the same
+      // value number through a different value, and FirstI cannot legally
+      // move above that operand.
+      bool bOperandInSucc0 = false;
+      for (Value *Op : FirstI->operands()) {
+        Instruction *OpI = dyn_cast<Instruction>(Op);
+        if (OpI && OpI->getParent() == Succ0) {
+          bOperandInSucc0 = true;
+          break;
+        }
+      }
+      if (bOperandInSucc0)
+        continue;
+
+      // Move FirstI to BB.
+      FirstI->removeFromParent();
+      FirstI->insertBefore(TI);
+    }
+    // Replace all insts with same value number with firstI.
+    auto it = Insts.begin();
+    ++it;
+    for (; it != Insts.end(); ++it) {
+      Instruction *I = *it;
+      I->replaceAllUsesWith(FirstI);
+      I->eraseFromParent();
+    }
+    Insts.clear();
+    bChanged = true;
+  }
+  return bChanged;
+}
+
+/// Walk the blocks reachable from the entry block in post order and try to
+/// hoist from every block that has exactly two successors, neither of which
+/// is the block itself, and each of which has a single predecessor entry.
+/// Returns true when any attempt reported a change.
+bool DxilSimpleGVNHoist::runOnFunction(Function &F) {
+  bool bUpdated = false;
+  BasicBlock *EntryBB = &F.getEntryBlock();
+  for (auto POI = po_begin(EntryBB), POE = po_end(EntryBB); POI != POE;
+       ++POI) {
+    BasicBlock *BB = *POI;
+    TerminatorInst *Term = BB->getTerminator();
+    if (Term->getNumSuccessors() != 2)
+      continue;
+
+    BasicBlock *Succ0 = Term->getSuccessor(0);
+    BasicBlock *Succ1 = Term->getSuccessor(1);
+    // A block that branches back to itself is not considered.
+    const bool bSelfLoop = (Succ0 == BB) || (Succ1 == BB);
+    if (bSelfLoop)
+      continue;
+
+    // Both successors must be entered from BB only.
+    if (!HasOnePred(Succ0) || !HasOnePred(Succ1))
+      continue;
+
+    if (tryToHoist(BB, Succ0, Succ1))
+      bUpdated = true;
+  }
+  return bUpdated;
+}
+
+}
+
+// Factory for the pass: returns a newly allocated DxilSimpleGVNHoist.
+FunctionPass *llvm::createDxilSimpleGVNHoistPass() {
+  return new DxilSimpleGVNHoist();
+}
+
+// Registers the pass under the command-line name "dxil-gvn-hoist".
+INITIALIZE_PASS(DxilSimpleGVNHoist, "dxil-gvn-hoist",
+                "DXIL simple gvn hoist", false, false)

+ 2 - 0
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -419,6 +419,8 @@ void PassManagerBuilder::populateModulePassManager(
     if (EnableMLSM)
       MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds
     MPM.add(createGVNPass(DisableGVNLoadPRE));  // Remove redundancies
+    if (!HLSLResMayAlias)
+      MPM.add(createDxilSimpleGVNHoistPass()); // HLSL Change - GVN hoist for code size.
   }
   // HLSL Change Begins.
   // HLSL don't allow memcpy and memset.

+ 2 - 0
tools/clang/include/clang/Frontend/CodeGenOptions.h

@@ -212,6 +212,8 @@ public:
   /// DefaultLinkage Internal, External, or Default.  If Default, default
   /// function linkage is determined by library target.
   hlsl::DXIL::DefaultLinkage DefaultLinkage = hlsl::DXIL::DefaultLinkage::Default;
+  /// Assume UAVs/SRVs may alias.
+  bool HLSLResMayAlias = false;
   // HLSL Change Ends
 
   // SPIRV Change Starts

+ 1 - 0
tools/clang/lib/CodeGen/BackendUtil.cpp

@@ -324,6 +324,7 @@ void EmitAssemblyHelper::CreatePasses() {
   PMBuilder.LoopVectorize = CodeGenOpts.VectorizeLoop;
   PMBuilder.HLSLHighLevel = CodeGenOpts.HLSLHighLevel; // HLSL Change
   PMBuilder.HLSLExtensionsCodeGen = CodeGenOpts.HLSLExtensionsCodegen.get(); // HLSL Change
+  PMBuilder.HLSLResMayAlias = CodeGenOpts.HLSLResMayAlias; // HLSL Change
 
   PMBuilder.DisableUnitAtATime = !CodeGenOpts.UnitAtATime;
   PMBuilder.DisableUnrollLoops = !CodeGenOpts.UnrollLoops;

+ 23 - 0
tools/clang/test/CodeGenHLSL/quick-test/simple_gvn_hoist.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -T ps_6_0 -E main %s | FileCheck %s
+
+// CHECK: call %dx.types.ResRet.f32 @dx.op.sample.f32
+// Make sure only one sample call exists.
+// CHECK-NOT:call %dx.types.ResRet.f32 @dx.op.sample.f32
+
+Texture2D<float4> tex;
+SamplerState ss;
+
+// Every branch samples tex with the same sampler and coordinates; only the
+// arithmetic applied to the sampled value differs between the branches.
+float4 main(float2 uv:UV, float2 a:A) : SV_Target {
+  float4 r = 0;
+  if (a.x > 0) {
+    r = tex.Sample(ss, uv)-1;
+  } else {
+    if (a.y > 0)
+      r = tex.Sample(ss, uv);
+    else
+      r = tex.Sample(ss, uv) + 3;
+  }
+
+  return r;
+
+}

+ 1 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -923,6 +923,7 @@ public:
       compiler.getCodeGenOpts().UnrollLoops = true;
 
     compiler.getCodeGenOpts().HLSLHighLevel = Opts.CodeGenHighLevel;
+    compiler.getCodeGenOpts().HLSLResMayAlias = Opts.ResMayAlias;
     compiler.getCodeGenOpts().HLSLAllResourcesBound = Opts.AllResourcesBound;
     compiler.getCodeGenOpts().HLSLDefaultRowMajor = Opts.DefaultRowMajor;
     compiler.getCodeGenOpts().HLSLPreferControlFlow = Opts.PreferFlowControl;

+ 1 - 0
utils/hct/hctdb.py

@@ -1564,6 +1564,7 @@ class db_dxil(object):
         add_pass('simplify-inst', 'SimplifyInst', 'Simplify Instructions', [])
         add_pass('hlsl-dxil-precise', 'DxilPrecisePropagatePass', 'DXIL precise attribute propagate', [])
         add_pass('dxil-legalize-sample-offset', 'DxilLegalizeSampleOffsetPass', 'DXIL legalize sample offset', [])
+        add_pass('dxil-gvn-hoist', 'DxilSimpleGVNHoist', 'DXIL simple gvn hoist', [])
         add_pass('hlsl-hlensure', 'HLEnsureMetadata', 'HLSL High-Level Metadata Ensure', [])
         add_pass('multi-dim-one-dim', 'MultiDimArrayToOneDimArray', 'Flatten multi-dim array into one-dim array', [])
         add_pass('resource-handle', 'ResourceToHandle', 'Lower resource into handle', [])