DxilSimpleGVNHoist.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilSimpleGVNHoist.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // A simple version of GVN hoist for DXIL. //
  9. // Based on GVNHoist in LLVM 6.0. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "llvm/IR/Dominators.h"
  14. #include "llvm/IR/Instructions.h"
  15. #include "llvm/ADT/PostOrderIterator.h"
  16. #include "llvm/ADT/DenseMapInfo.h"
  17. #include "llvm/ADT/DenseSet.h"
  18. #include "llvm/IR/IntrinsicInst.h"
  19. #include "llvm/IR/CFG.h"
  20. using namespace llvm;
  21. using namespace hlsl;
  22. ///////////////////////////////////////////////////////////////////////////////
  23. namespace {
  24. struct Expression {
  25. uint32_t opcode;
  26. Type *type;
  27. bool commutative = false;
  28. SmallVector<uint32_t, 4> varargs;
  29. Expression(uint32_t o = ~2U) : opcode(o) {}
  30. bool operator==(const Expression &other) const {
  31. if (opcode != other.opcode)
  32. return false;
  33. if (opcode == ~0U || opcode == ~1U)
  34. return true;
  35. if (type != other.type)
  36. return false;
  37. if (varargs != other.varargs)
  38. return false;
  39. return true;
  40. }
  41. friend hash_code hash_value(const Expression &Value) {
  42. return hash_combine(
  43. Value.opcode, Value.type,
  44. hash_combine_range(Value.varargs.begin(), Value.varargs.end()));
  45. }
  46. };
  47. }
  48. namespace llvm {
  49. template <> struct DenseMapInfo<Expression> {
  50. static inline Expression getEmptyKey() { return ~0U; }
  51. static inline Expression getTombstoneKey() { return ~1U; }
  52. static unsigned getHashValue(const Expression &e) {
  53. using llvm::hash_value;
  54. return static_cast<unsigned>(hash_value(e));
  55. }
  56. static bool isEqual(const Expression &LHS, const Expression &RHS) {
  57. return LHS == RHS;
  58. }
  59. };
  60. } // namespace llvm
  61. namespace {
  62. // Simple Value table which support DXIL operation.
  63. class ValueTable {
  64. DenseMap<Value *, uint32_t> valueNumbering;
  65. DenseMap<Expression, uint32_t> expressionNumbering;
  66. // Expressions is the vector of Expression. ExprIdx is the mapping from
  67. // value number to the index of Expression in Expressions. We use it
  68. // instead of a DenseMap because filling such mapping is faster than
  69. // filling a DenseMap and the compile time is a little better.
  70. uint32_t nextExprNumber;
  71. std::vector<Expression> Expressions;
  72. std::vector<uint32_t> ExprIdx;
  73. DominatorTree *DT;
  74. uint32_t nextValueNumber = 1;
  75. Expression createExpr(Instruction *I);
  76. Expression createCmpExpr(unsigned Opcode, CmpInst::Predicate Predicate,
  77. Value *LHS, Value *RHS);
  78. Expression createExtractvalueExpr(ExtractValueInst *EI);
  79. uint32_t lookupOrAddCall(CallInst *C);
  80. std::pair<uint32_t, bool> assignExpNewValueNum(Expression &exp);
  81. public:
  82. ValueTable();
  83. ValueTable(const ValueTable &Arg);
  84. ValueTable(ValueTable &&Arg);
  85. ~ValueTable();
  86. uint32_t lookupOrAdd(Value *V);
  87. uint32_t lookup(Value *V, bool Verify = true) const;
  88. uint32_t lookupOrAddCmp(unsigned Opcode, CmpInst::Predicate Pred, Value *LHS,
  89. Value *RHS);
  90. bool exists(Value *V) const;
  91. void add(Value *V, uint32_t num);
  92. void clear();
  93. void erase(Value *v);
  94. void setDomTree(DominatorTree *D) { DT = D; }
  95. uint32_t getNextUnusedValueNumber() { return nextValueNumber; }
  96. void verifyRemoved(const Value *) const;
  97. };
  98. //===----------------------------------------------------------------------===//
  99. // ValueTable Internal Functions
  100. //===----------------------------------------------------------------------===//
  101. Expression ValueTable::createExpr(Instruction *I) {
  102. Expression e;
  103. e.type = I->getType();
  104. e.opcode = I->getOpcode();
  105. for (Instruction::op_iterator OI = I->op_begin(), OE = I->op_end();
  106. OI != OE; ++OI)
  107. e.varargs.push_back(lookupOrAdd(*OI));
  108. if (I->isCommutative()) {
  109. // Ensure that commutative instructions that only differ by a permutation
  110. // of their operands get the same value number by sorting the operand value
  111. // numbers. Since all commutative instructions have two operands it is more
  112. // efficient to sort by hand rather than using, say, std::sort.
  113. assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!");
  114. if (e.varargs[0] > e.varargs[1])
  115. std::swap(e.varargs[0], e.varargs[1]);
  116. e.commutative = true;
  117. }
  118. if (CmpInst *C = dyn_cast<CmpInst>(I)) {
  119. // Sort the operand value numbers so x<y and y>x get the same value number.
  120. CmpInst::Predicate Predicate = C->getPredicate();
  121. if (e.varargs[0] > e.varargs[1]) {
  122. std::swap(e.varargs[0], e.varargs[1]);
  123. Predicate = CmpInst::getSwappedPredicate(Predicate);
  124. }
  125. e.opcode = (C->getOpcode() << 8) | Predicate;
  126. e.commutative = true;
  127. }
  128. else if (InsertValueInst *E = dyn_cast<InsertValueInst>(I)) {
  129. for (InsertValueInst::idx_iterator II = E->idx_begin(), IE = E->idx_end();
  130. II != IE; ++II)
  131. e.varargs.push_back(*II);
  132. }
  133. return e;
  134. }
  135. Expression ValueTable::createCmpExpr(unsigned Opcode,
  136. CmpInst::Predicate Predicate,
  137. Value *LHS, Value *RHS) {
  138. assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
  139. "Not a comparison!");
  140. Expression e;
  141. e.type = CmpInst::makeCmpResultType(LHS->getType());
  142. e.varargs.push_back(lookupOrAdd(LHS));
  143. e.varargs.push_back(lookupOrAdd(RHS));
  144. // Sort the operand value numbers so x<y and y>x get the same value number.
  145. if (e.varargs[0] > e.varargs[1]) {
  146. std::swap(e.varargs[0], e.varargs[1]);
  147. Predicate = CmpInst::getSwappedPredicate(Predicate);
  148. }
  149. e.opcode = (Opcode << 8) | Predicate;
  150. e.commutative = true;
  151. return e;
  152. }
  153. Expression ValueTable::createExtractvalueExpr(ExtractValueInst *EI) {
  154. assert(EI && "Not an ExtractValueInst?");
  155. Expression e;
  156. e.type = EI->getType();
  157. e.opcode = 0;
  158. IntrinsicInst *I = dyn_cast<IntrinsicInst>(EI->getAggregateOperand());
  159. if (I != nullptr && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) {
  160. // EI might be an extract from one of our recognised intrinsics. If it
  161. // is we'll synthesize a semantically equivalent expression instead on
  162. // an extract value expression.
  163. switch (I->getIntrinsicID()) {
  164. case Intrinsic::sadd_with_overflow:
  165. case Intrinsic::uadd_with_overflow:
  166. e.opcode = Instruction::Add;
  167. break;
  168. case Intrinsic::ssub_with_overflow:
  169. case Intrinsic::usub_with_overflow:
  170. e.opcode = Instruction::Sub;
  171. break;
  172. case Intrinsic::smul_with_overflow:
  173. case Intrinsic::umul_with_overflow:
  174. e.opcode = Instruction::Mul;
  175. break;
  176. default:
  177. break;
  178. }
  179. if (e.opcode != 0) {
  180. // Intrinsic recognized. Grab its args to finish building the expression.
  181. assert(I->getNumArgOperands() == 2 &&
  182. "Expect two args for recognised intrinsics.");
  183. e.varargs.push_back(lookupOrAdd(I->getArgOperand(0)));
  184. e.varargs.push_back(lookupOrAdd(I->getArgOperand(1)));
  185. return e;
  186. }
  187. }
  188. // Not a recognised intrinsic. Fall back to producing an extract value
  189. // expression.
  190. e.opcode = EI->getOpcode();
  191. for (Instruction::op_iterator OI = EI->op_begin(), OE = EI->op_end();
  192. OI != OE; ++OI)
  193. e.varargs.push_back(lookupOrAdd(*OI));
  194. for (ExtractValueInst::idx_iterator II = EI->idx_begin(), IE = EI->idx_end();
  195. II != IE; ++II)
  196. e.varargs.push_back(*II);
  197. return e;
  198. }
  199. //===----------------------------------------------------------------------===//
  200. // ValueTable External Functions
  201. //===----------------------------------------------------------------------===//
  202. ValueTable::ValueTable() = default;
  203. ValueTable::ValueTable(const ValueTable &) = default;
  204. ValueTable::ValueTable(ValueTable &&) = default;
  205. ValueTable::~ValueTable() = default;
  206. /// add - Insert a value into the table with a specified value number.
  207. void ValueTable::add(Value *V, uint32_t num) {
  208. valueNumbering.insert(std::make_pair(V, num));
  209. }
  210. uint32_t ValueTable::lookupOrAddCall(CallInst *C) {
  211. Function *F = C->getCalledFunction();
  212. bool bSafe = false;
  213. if (F->hasFnAttribute(Attribute::ReadNone)) {
  214. bSafe = true;
  215. } else if (F->hasFnAttribute(Attribute::ReadOnly)) {
  216. if (hlsl::OP::IsDxilOpFunc(F)) {
  217. DXIL::OpCode Opcode = hlsl::OP::GetDxilOpFuncCallInst(C);
  218. switch (Opcode) {
  219. default:
  220. break;
  221. // TODO: make buffer/texture load on srv safe.
  222. case DXIL::OpCode::CreateHandleForLib:
  223. case DXIL::OpCode::CBufferLoad:
  224. case DXIL::OpCode::CBufferLoadLegacy:
  225. case DXIL::OpCode::Sample:
  226. case DXIL::OpCode::SampleBias:
  227. case DXIL::OpCode::SampleCmp:
  228. case DXIL::OpCode::SampleCmpLevelZero:
  229. case DXIL::OpCode::SampleGrad:
  230. case DXIL::OpCode::CheckAccessFullyMapped:
  231. case DXIL::OpCode::GetDimensions:
  232. case DXIL::OpCode::TextureGather:
  233. case DXIL::OpCode::TextureGatherCmp:
  234. case DXIL::OpCode::Texture2DMSGetSamplePosition:
  235. case DXIL::OpCode::RenderTargetGetSampleCount:
  236. case DXIL::OpCode::RenderTargetGetSamplePosition:
  237. case DXIL::OpCode::CalculateLOD:
  238. bSafe = true;
  239. break;
  240. }
  241. }
  242. }
  243. if (bSafe) {
  244. Expression exp = createExpr(C);
  245. uint32_t e = assignExpNewValueNum(exp).first;
  246. valueNumbering[C] = e;
  247. return e;
  248. } else {
  249. // Not sure safe or not, always use new value number.
  250. valueNumbering[C] = nextValueNumber;
  251. return nextValueNumber++;
  252. }
  253. }
  254. /// Returns true if a value number exists for the specified value.
  255. bool ValueTable::exists(Value *V) const { return valueNumbering.count(V) != 0; }
  256. /// lookup_or_add - Returns the value number for the specified value, assigning
  257. /// it a new number if it did not have one before.
  258. uint32_t ValueTable::lookupOrAdd(Value *V) {
  259. DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V);
  260. if (VI != valueNumbering.end())
  261. return VI->second;
  262. if (!isa<Instruction>(V)) {
  263. valueNumbering[V] = nextValueNumber;
  264. return nextValueNumber++;
  265. }
  266. Instruction* I = cast<Instruction>(V);
  267. Expression exp;
  268. switch (I->getOpcode()) {
  269. case Instruction::Call:
  270. return lookupOrAddCall(cast<CallInst>(I));
  271. case Instruction::Add:
  272. case Instruction::FAdd:
  273. case Instruction::Sub:
  274. case Instruction::FSub:
  275. case Instruction::Mul:
  276. case Instruction::FMul:
  277. case Instruction::UDiv:
  278. case Instruction::SDiv:
  279. case Instruction::FDiv:
  280. case Instruction::URem:
  281. case Instruction::SRem:
  282. case Instruction::FRem:
  283. case Instruction::Shl:
  284. case Instruction::LShr:
  285. case Instruction::AShr:
  286. case Instruction::And:
  287. case Instruction::Or:
  288. case Instruction::Xor:
  289. case Instruction::ICmp:
  290. case Instruction::FCmp:
  291. case Instruction::Trunc:
  292. case Instruction::ZExt:
  293. case Instruction::SExt:
  294. case Instruction::FPToUI:
  295. case Instruction::FPToSI:
  296. case Instruction::UIToFP:
  297. case Instruction::SIToFP:
  298. case Instruction::FPTrunc:
  299. case Instruction::FPExt:
  300. case Instruction::PtrToInt:
  301. case Instruction::IntToPtr:
  302. case Instruction::BitCast:
  303. case Instruction::Select:
  304. case Instruction::ExtractElement:
  305. case Instruction::InsertElement:
  306. case Instruction::ShuffleVector:
  307. case Instruction::InsertValue:
  308. case Instruction::GetElementPtr:
  309. exp = createExpr(I);
  310. break;
  311. case Instruction::ExtractValue:
  312. exp = createExtractvalueExpr(cast<ExtractValueInst>(I));
  313. break;
  314. case Instruction::PHI:
  315. valueNumbering[V] = nextValueNumber;
  316. return nextValueNumber++;
  317. default:
  318. valueNumbering[V] = nextValueNumber;
  319. return nextValueNumber++;
  320. }
  321. uint32_t e = assignExpNewValueNum(exp).first;
  322. valueNumbering[V] = e;
  323. return e;
  324. }
  325. /// Returns the value number of the specified value. Fails if
  326. /// the value has not yet been numbered.
  327. uint32_t ValueTable::lookup(Value *V, bool Verify) const {
  328. DenseMap<Value*, uint32_t>::const_iterator VI = valueNumbering.find(V);
  329. if (Verify) {
  330. assert(VI != valueNumbering.end() && "Value not numbered?");
  331. return VI->second;
  332. }
  333. return (VI != valueNumbering.end()) ? VI->second : 0;
  334. }
  335. /// Returns the value number of the given comparison,
  336. /// assigning it a new number if it did not have one before. Useful when
  337. /// we deduced the result of a comparison, but don't immediately have an
  338. /// instruction realizing that comparison to hand.
  339. uint32_t ValueTable::lookupOrAddCmp(unsigned Opcode,
  340. CmpInst::Predicate Predicate,
  341. Value *LHS, Value *RHS) {
  342. Expression exp = createCmpExpr(Opcode, Predicate, LHS, RHS);
  343. return assignExpNewValueNum(exp).first;
  344. }
  345. /// Remove all entries from the ValueTable.
  346. void ValueTable::clear() {
  347. valueNumbering.clear();
  348. expressionNumbering.clear();
  349. nextValueNumber = 1;
  350. Expressions.clear();
  351. ExprIdx.clear();
  352. nextExprNumber = 0;
  353. }
  354. /// Remove a value from the value numbering.
  355. void ValueTable::erase(Value *V) {
  356. valueNumbering.erase(V);
  357. }
  358. /// verifyRemoved - Verify that the value is removed from all internal data
  359. /// structures.
  360. void ValueTable::verifyRemoved(const Value *V) const {
  361. for (DenseMap<Value*, uint32_t>::const_iterator
  362. I = valueNumbering.begin(), E = valueNumbering.end(); I != E; ++I) {
  363. assert(I->first != V && "Inst still occurs in value numbering map!");
  364. }
  365. }
  366. /// Return a pair the first field showing the value number of \p Exp and the
  367. /// second field showing whether it is a value number newly created.
  368. std::pair<uint32_t, bool>
  369. ValueTable::assignExpNewValueNum(Expression &Exp) {
  370. uint32_t &e = expressionNumbering[Exp];
  371. bool CreateNewValNum = !e;
  372. if (CreateNewValNum) {
  373. Expressions.push_back(Exp);
  374. if (ExprIdx.size() < nextValueNumber + 1)
  375. ExprIdx.resize(nextValueNumber * 2);
  376. e = nextValueNumber;
  377. ExprIdx[nextValueNumber++] = nextExprNumber++;
  378. }
  379. return {e, CreateNewValNum};
  380. }
  381. } // namespace
  382. namespace {
  383. // Reduce code size for pattern like this:
  384. // if (a.x > 0) {
  385. // r = tex.Sample(ss, uv)-1;
  386. // } else {
  387. // if (a.y > 0)
  388. // r = tex.Sample(ss, uv);
  389. // else
  390. // r = tex.Sample(ss, uv) + 3;
  391. // }
  392. class DxilSimpleGVNHoist : public FunctionPass {
  393. public:
  394. static char ID; // Pass identification, replacement for typeid
  395. explicit DxilSimpleGVNHoist() : FunctionPass(ID) {}
  396. const char *getPassName() const override {
  397. return "DXIL simple GVN hoist";
  398. }
  399. bool runOnFunction(Function &F) override;
  400. private:
  401. bool tryToHoist(BasicBlock *BB, BasicBlock *Succ0, BasicBlock *Succ1);
  402. };
  403. char DxilSimpleGVNHoist::ID = 0;
  404. bool HasOnePred(BasicBlock *BB) {
  405. if (pred_empty(BB))
  406. return false;
  407. auto pred = pred_begin(BB);
  408. pred++;
  409. if (pred != pred_end(BB))
  410. return false;
  411. return true;
  412. }
  413. bool DxilSimpleGVNHoist::tryToHoist(BasicBlock *BB, BasicBlock *Succ0,
  414. BasicBlock *Succ1) {
  415. // ValueNumber Succ0 and Succ1.
  416. ValueTable VT;
  417. DenseMap<uint32_t, SmallVector<Instruction *, 2>> VNtoInsts;
  418. for (Instruction &I : *Succ0) {
  419. uint32_t V = VT.lookupOrAdd(&I);
  420. VNtoInsts[V].emplace_back(&I);
  421. }
  422. std::vector<uint32_t> HoistCandidateVN;
  423. for (Instruction &I : *Succ1) {
  424. uint32_t V = VT.lookupOrAdd(&I);
  425. if (!VNtoInsts.count(V))
  426. continue;
  427. VNtoInsts[V].emplace_back(&I);
  428. HoistCandidateVN.emplace_back(V);
  429. }
  430. if (HoistCandidateVN.empty()) {
  431. return false;
  432. }
  433. DenseSet<uint32_t> ProcessedVN;
  434. Instruction *TI = BB->getTerminator();
  435. // Hoist need to be in order, so operand could hoist before its users.
  436. for (uint32_t VN : HoistCandidateVN) {
  437. // Skip processed VN
  438. if (ProcessedVN.count(VN))
  439. continue;
  440. ProcessedVN.insert(VN);
  441. auto &Insts = VNtoInsts[VN];
  442. if (Insts.size() == 1)
  443. continue;
  444. bool bHoist = false;
  445. for (Instruction *I : Insts) {
  446. if (I->getParent() == Succ1) {
  447. bHoist = true;
  448. break;
  449. }
  450. }
  451. Instruction *FirstI = Insts.front();
  452. if (bHoist) {
  453. // When operand is different, need to hoist operand.
  454. auto it = Insts.begin();
  455. it++;
  456. bool bHasDifferentOperand = false;
  457. unsigned NumOps = FirstI->getNumOperands();
  458. for (; it != Insts.end(); it++) {
  459. Instruction *I = *it;
  460. assert(NumOps == I->getNumOperands());
  461. for (unsigned i = 0; i < NumOps; i++) {
  462. if (FirstI->getOperand(i) != I->getOperand(i)) {
  463. bHasDifferentOperand = true;
  464. break;
  465. }
  466. }
  467. if (bHasDifferentOperand)
  468. break;
  469. }
  470. // TODO: hoist operands.
  471. if (bHasDifferentOperand)
  472. continue;
  473. // Move FirstI to BB.
  474. FirstI->removeFromParent();
  475. FirstI->insertBefore(TI);
  476. }
  477. // Replace all insts with same value number with firstI.
  478. auto it = Insts.begin();
  479. it++;
  480. for (; it != Insts.end(); it++) {
  481. Instruction *I = *it;
  482. I->replaceAllUsesWith(FirstI);
  483. I->eraseFromParent();
  484. }
  485. Insts.clear();
  486. }
  487. return true;
  488. }
  489. bool DxilSimpleGVNHoist::runOnFunction(Function &F) {
  490. BasicBlock &Entry = F.getEntryBlock();
  491. bool bUpdated = false;
  492. for (auto it = po_begin(&Entry); it != po_end(&Entry); it++) {
  493. BasicBlock *BB = *it;
  494. TerminatorInst *TI = BB->getTerminator();
  495. if (TI->getNumSuccessors() != 2)
  496. continue;
  497. BasicBlock *Succ0 = TI->getSuccessor(0);
  498. BasicBlock *Succ1 = TI->getSuccessor(1);
  499. if (BB == Succ0)
  500. continue;
  501. if (BB == Succ1)
  502. continue;
  503. if (!HasOnePred(Succ0))
  504. continue;
  505. if (!HasOnePred(Succ1))
  506. continue;
  507. bUpdated |= tryToHoist(BB, Succ0, Succ1);
  508. }
  509. return bUpdated;
  510. }
  511. }
  512. FunctionPass *llvm::createDxilSimpleGVNHoistPass() {
  513. return new DxilSimpleGVNHoist();
  514. }
  515. INITIALIZE_PASS(DxilSimpleGVNHoist, "dxil-gvn-hoist",
  516. "DXIL simple gvn hoist", false, false)