LowerSwitch.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // The LowerSwitch transformation rewrites switch instructions with a sequence
  11. // of branches, which allows targets to get away with not implementing the
  12. // switch instruction until it is convenient.
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #include "llvm/Transforms/Scalar.h"
  16. #include "llvm/ADT/STLExtras.h"
  17. #include "llvm/IR/CFG.h"
  18. #include "llvm/IR/Constants.h"
  19. #include "llvm/IR/Function.h"
  20. #include "llvm/IR/Instructions.h"
  21. #include "llvm/IR/LLVMContext.h"
  22. #include "llvm/Pass.h"
  23. #include "llvm/Support/Compiler.h"
  24. #include "llvm/Support/Debug.h"
  25. #include "llvm/Support/raw_ostream.h"
  26. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  27. #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
  28. #include <algorithm>
  29. using namespace llvm;
  30. #define DEBUG_TYPE "lower-switch"
  31. namespace {
  32. struct IntRange {
  33. int64_t Low, High;
  34. };
  35. // Return true iff R is covered by Ranges.
  36. static bool IsInRanges(const IntRange &R,
  37. const std::vector<IntRange> &Ranges) {
  38. // Note: Ranges must be sorted, non-overlapping and non-adjacent.
  39. // Find the first range whose High field is >= R.High,
  40. // then check if the Low field is <= R.Low. If so, we
  41. // have a Range that covers R.
  42. auto I = std::lower_bound(
  43. Ranges.begin(), Ranges.end(), R,
  44. [](const IntRange &A, const IntRange &B) { return A.High < B.High; });
  45. return I != Ranges.end() && I->Low <= R.Low;
  46. }
  47. /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
  48. /// instructions.
  49. class LowerSwitch : public FunctionPass {
  50. public:
  51. static char ID; // Pass identification, replacement for typeid
  52. LowerSwitch() : FunctionPass(ID) {
  53. initializeLowerSwitchPass(*PassRegistry::getPassRegistry());
  54. }
  55. bool runOnFunction(Function &F) override;
  56. void getAnalysisUsage(AnalysisUsage &AU) const override {
  57. // This is a cluster of orthogonal Transforms
  58. AU.addPreserved<UnifyFunctionExitNodes>();
  59. AU.addPreservedID(LowerInvokePassID);
  60. }
  61. struct CaseRange {
  62. ConstantInt* Low;
  63. ConstantInt* High;
  64. BasicBlock* BB;
  65. CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb)
  66. : Low(low), High(high), BB(bb) {}
  67. };
  68. typedef std::vector<CaseRange> CaseVector;
  69. typedef std::vector<CaseRange>::iterator CaseItr;
  70. private:
  71. void processSwitchInst(SwitchInst *SI);
  72. BasicBlock *switchConvert(CaseItr Begin, CaseItr End,
  73. ConstantInt *LowerBound, ConstantInt *UpperBound,
  74. Value *Val, BasicBlock *Predecessor,
  75. BasicBlock *OrigBlock, BasicBlock *Default,
  76. const std::vector<IntRange> &UnreachableRanges);
  77. BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock,
  78. BasicBlock *Default);
  79. unsigned Clusterify(CaseVector &Cases, SwitchInst *SI);
  80. };
  81. /// The comparison function for sorting the switch case values in the vector.
  82. /// WARNING: Case ranges should be disjoint!
  83. struct CaseCmp {
  84. bool operator () (const LowerSwitch::CaseRange& C1,
  85. const LowerSwitch::CaseRange& C2) {
  86. const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low);
  87. const ConstantInt* CI2 = cast<const ConstantInt>(C2.High);
  88. return CI1->getValue().slt(CI2->getValue());
  89. }
  90. };
  91. }
// Storage for the pass identifier declared in the class ("replacement for
// typeid"); the constructor hands it to FunctionPass.
char LowerSwitch::ID = 0;
// Register the pass under the name "lowerswitch" with its description.
INITIALIZE_PASS(LowerSwitch, "lowerswitch",
                "Lower SwitchInst's to branches", false, false)
// Publicly exposed interface to pass...
// LowerSwitchID aliases LowerSwitch::ID, so code outside this file can refer
// to the pass without seeing the class, which lives in an anonymous namespace.
char &llvm::LowerSwitchID = LowerSwitch::ID;
// createLowerSwitchPass - Interface to this file...
// Returns a newly heap-allocated LowerSwitch pass.
// NOTE(review): the caller presumably takes ownership (e.g. a pass manager);
// nothing in this file deletes the object - confirm against the caller.
FunctionPass *llvm::createLowerSwitchPass() {
  return new LowerSwitch();
}
  101. bool LowerSwitch::runOnFunction(Function &F) {
  102. bool Changed = false;
  103. for (Function::iterator I = F.begin(), E = F.end(); I != E; ) {
  104. BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks
  105. if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) {
  106. Changed = true;
  107. processSwitchInst(SI);
  108. }
  109. }
  110. return Changed;
  111. }
  112. // operator<< - Used for debugging purposes.
  113. //
  114. static raw_ostream& operator<<(raw_ostream &O,
  115. const LowerSwitch::CaseVector &C)
  116. LLVM_ATTRIBUTE_USED;
  117. static raw_ostream& operator<<(raw_ostream &O,
  118. const LowerSwitch::CaseVector &C) {
  119. O << "[";
  120. for (LowerSwitch::CaseVector::const_iterator B = C.begin(),
  121. E = C.end(); B != E; ) {
  122. O << *B->Low << " -" << *B->High;
  123. if (++B != E) O << ", ";
  124. }
  125. return O << "]";
  126. }
  127. // \brief Update the first occurrence of the "switch statement" BB in the PHI
  128. // node with the "new" BB. The other occurrences will:
  129. //
  130. // 1) Be updated by subsequent calls to this function. Switch statements may
  131. // have more than one outcoming edge into the same BB if they all have the same
  132. // value. When the switch statement is converted these incoming edges are now
  133. // coming from multiple BBs.
  134. // 2) Removed if subsequent incoming values now share the same case, i.e.,
  135. // multiple outcome edges are condensed into one. This is necessary to keep the
  136. // number of phi values equal to the number of branches to SuccBB.
  137. static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
  138. unsigned NumMergedCases) {
  139. for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI();
  140. I != IE; ++I) {
  141. PHINode *PN = cast<PHINode>(I);
  142. // Only update the first occurence.
  143. unsigned Idx = 0, E = PN->getNumIncomingValues();
  144. unsigned LocalNumMergedCases = NumMergedCases;
  145. for (; Idx != E; ++Idx) {
  146. if (PN->getIncomingBlock(Idx) == OrigBB) {
  147. PN->setIncomingBlock(Idx, NewBB);
  148. break;
  149. }
  150. }
  151. // Remove additional occurences coming from condensed cases and keep the
  152. // number of incoming values equal to the number of branches to SuccBB.
  153. SmallVector<unsigned, 8> Indices;
  154. for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx)
  155. if (PN->getIncomingBlock(Idx) == OrigBB) {
  156. Indices.push_back(Idx);
  157. LocalNumMergedCases--;
  158. }
  159. // Remove incoming values in the reverse order to prevent invalidating
  160. // *successive* index.
  161. for (auto III = Indices.rbegin(), IIE = Indices.rend(); III != IIE; ++III)
  162. PN->removeIncomingValue(*III);
  163. }
  164. }
// switchConvert - Convert the switch statement into a binary lookup of
// the case values. The function recursively builds this tree.
// LowerBound and UpperBound are used to keep track of the bounds for Val
// that have already been checked by a block emitted by one of the previous
// calls to switchConvert in the call stack.
//
// Parameters:
//   Begin, End        - the clusters this subtree has to dispatch on; the
//                       code assumes the range is non-empty.
//   LowerBound,
//   UpperBound        - inclusive bounds already established for Val, or
//                       null when no bound is known on that side.
//   Val               - the value being switched on.
//   Predecessor       - the block that will branch to the block returned
//                       here; used as the new incoming block when PHI nodes
//                       are fixed up.
//   OrigBlock         - the block that contained the switch.
//   Default           - where leaf blocks branch when Val matches no case.
//   UnreachableRanges - value ranges known never to occur; must be sorted,
//                       non-overlapping and non-adjacent (see IsInRanges).
//
// Returns the block at the root of the generated subtree: a "NodeBlock", a
// "LeafBlock", or the case's own successor when no test is needed at all.
BasicBlock *
LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
                           ConstantInt *UpperBound, Value *Val,
                           BasicBlock *Predecessor, BasicBlock *OrigBlock,
                           BasicBlock *Default,
                           const std::vector<IntRange> &UnreachableRanges) {
  unsigned Size = End - Begin;

  if (Size == 1) {
    // Check if the Case Range is perfectly squeezed in between
    // already checked Upper and Lower bounds. If it is then we can avoid
    // emitting the code that checks if the value actually falls in the range
    // because the bounds already tell us so.
    if (Begin->Low == LowerBound && Begin->High == UpperBound) {
      // The whole cluster collapses into the single edge from Predecessor,
      // so every case value beyond the first is a PHI entry to drop.
      unsigned NumMergedCases = 0;
      if (LowerBound && UpperBound)
        NumMergedCases =
            UpperBound->getSExtValue() - LowerBound->getSExtValue();
      fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
      return Begin->BB;
    }
    return newLeafBlock(*Begin, Val, OrigBlock, Default);
  }

  // Split the clusters in two halves; the first cluster of the right half is
  // the pivot.
  unsigned Mid = Size / 2;
  std::vector<CaseRange> LHS(Begin, Begin + Mid);
  DEBUG(dbgs() << "LHS: " << LHS << "\n");
  std::vector<CaseRange> RHS(Begin + Mid, End);
  DEBUG(dbgs() << "RHS: " << RHS << "\n");

  CaseRange &Pivot = *(Begin + Mid);
  DEBUG(dbgs() << "Pivot ==> "
               << Pivot.Low->getValue()
               << " -" << Pivot.High->getValue() << "\n");

  // NewLowerBound here should never be the integer minimal value.
  // This is because it is computed from a case range that is never
  // the smallest, so there is always a case range that has at least
  // a smaller value.
  ConstantInt *NewLowerBound = Pivot.Low;

  // Because NewLowerBound is never the smallest representable integer
  // it is safe here to subtract one.
  ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(),
                                                NewLowerBound->getValue() - 1);

  if (!UnreachableRanges.empty()) {
    // Check if the gap between LHS's highest and NewLowerBound is unreachable.
    // If so, the left subtree's upper bound can be tightened to LHS's highest
    // value, which gives the Size == 1 shortcut above a chance to fire.
    int64_t GapLow = LHS.back().High->getSExtValue() + 1;
    int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
    IntRange Gap = { GapLow, GapHigh };
    if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
      NewUpperBound = LHS.back().High;
  }

  DEBUG(dbgs() << "LHS Bounds ==> ";
        if (LowerBound) {
          dbgs() << LowerBound->getSExtValue();
        } else {
          dbgs() << "NONE";
        }
        dbgs() << " - " << NewUpperBound->getSExtValue() << "\n";
        dbgs() << "RHS Bounds ==> ";
        dbgs() << NewLowerBound->getSExtValue() << " - ";
        if (UpperBound) {
          dbgs() << UpperBound->getSExtValue() << "\n";
        } else {
          dbgs() << "NONE\n";
        });

  // Create a new node that checks if the value is < pivot. Go to the
  // left branch if it is and right branch if not.
  // Both the block and the compare are created detached here; they are only
  // attached to the function / block after the two subtrees have been built.
  Function* F = OrigBlock->getParent();
  BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock");

  ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT,
                                Val, Pivot.Low, "Pivot");

  // NewNode is the predecessor of whatever block each subtree returns.
  BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound,
                                      NewUpperBound, Val, NewNode, OrigBlock,
                                      Default, UnreachableRanges);
  BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound,
                                      UpperBound, Val, NewNode, OrigBlock,
                                      Default, UnreachableRanges);

  // Place NewNode immediately after OrigBlock, i.e. ahead of the blocks the
  // recursive calls inserted at that same position.
  Function::iterator FI = OrigBlock;
  F->getBasicBlockList().insert(++FI, NewNode);
  NewNode->getInstList().push_back(Comp);

  BranchInst::Create(LBranch, RBranch, Comp, NewNode);
  return NewNode;
}
  250. // newLeafBlock - Create a new leaf block for the binary lookup tree. It
  251. // checks if the switch's value == the case's value. If not, then it
  252. // jumps to the default branch. At this point in the tree, the value
  253. // can't be another valid case value, so the jump to the "default" branch
  254. // is warranted.
  255. //
  256. BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val,
  257. BasicBlock* OrigBlock,
  258. BasicBlock* Default)
  259. {
  260. Function* F = OrigBlock->getParent();
  261. BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock");
  262. Function::iterator FI = OrigBlock;
  263. F->getBasicBlockList().insert(++FI, NewLeaf);
  264. // Emit comparison
  265. ICmpInst* Comp = nullptr;
  266. if (Leaf.Low == Leaf.High) {
  267. // Make the seteq instruction...
  268. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val,
  269. Leaf.Low, "SwitchLeaf");
  270. } else {
  271. // Make range comparison
  272. if (Leaf.Low->isMinValue(true /*isSigned*/)) {
  273. // Val >= Min && Val <= Hi --> Val <= Hi
  274. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High,
  275. "SwitchLeaf");
  276. } else if (Leaf.Low->isZero()) {
  277. // Val >= 0 && Val <= Hi --> Val <=u Hi
  278. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High,
  279. "SwitchLeaf");
  280. } else {
  281. // Emit V-Lo <=u Hi-Lo
  282. Constant* NegLo = ConstantExpr::getNeg(Leaf.Low);
  283. Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo,
  284. Val->getName()+".off",
  285. NewLeaf);
  286. Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High);
  287. Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound,
  288. "SwitchLeaf");
  289. }
  290. }
  291. // Make the conditional branch...
  292. BasicBlock* Succ = Leaf.BB;
  293. BranchInst::Create(Succ, Default, Comp, NewLeaf);
  294. // If there were any PHI nodes in this successor, rewrite one entry
  295. // from OrigBlock to come from NewLeaf.
  296. for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
  297. PHINode* PN = cast<PHINode>(I);
  298. // Remove all but one incoming entries from the cluster
  299. uint64_t Range = Leaf.High->getSExtValue() -
  300. Leaf.Low->getSExtValue();
  301. for (uint64_t j = 0; j < Range; ++j) {
  302. PN->removeIncomingValue(OrigBlock);
  303. }
  304. int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
  305. assert(BlockIdx != -1 && "Switch didn't go to this successor??");
  306. PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf);
  307. }
  308. return NewLeaf;
  309. }
  310. // Clusterify - Transform simple list of Cases into list of CaseRange's
  311. unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
  312. unsigned numCmps = 0;
  313. // Start with "simple" cases
  314. for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i)
  315. Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(),
  316. i.getCaseSuccessor()));
  317. std::sort(Cases.begin(), Cases.end(), CaseCmp());
  318. // Merge case into clusters
  319. if (Cases.size() >= 2) {
  320. CaseItr I = Cases.begin();
  321. for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
  322. int64_t nextValue = J->Low->getSExtValue();
  323. int64_t currentValue = I->High->getSExtValue();
  324. BasicBlock* nextBB = J->BB;
  325. BasicBlock* currentBB = I->BB;
  326. // If the two neighboring cases go to the same destination, merge them
  327. // into a single case.
  328. assert(nextValue > currentValue && "Cases should be strictly ascending");
  329. if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
  330. I->High = J->High;
  331. // FIXME: Combine branch weights.
  332. } else if (++I != J) {
  333. *I = *J;
  334. }
  335. }
  336. Cases.erase(std::next(I), Cases.end());
  337. }
  338. for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) {
  339. if (I->Low != I->High)
  340. // A range counts double, since it requires two compares.
  341. ++numCmps;
  342. }
  343. return numCmps;
  344. }
  345. // processSwitchInst - Replace the specified switch instruction with a sequence
  346. // of chained if-then insts in a balanced binary search.
  347. //
  348. void LowerSwitch::processSwitchInst(SwitchInst *SI) {
  349. BasicBlock *CurBlock = SI->getParent();
  350. BasicBlock *OrigBlock = CurBlock;
  351. Function *F = CurBlock->getParent();
  352. Value *Val = SI->getCondition(); // The value we are switching on...
  353. BasicBlock* Default = SI->getDefaultDest();
  354. // If there is only the default destination, just branch.
  355. if (!SI->getNumCases()) {
  356. BranchInst::Create(Default, CurBlock);
  357. SI->eraseFromParent();
  358. return;
  359. }
  360. // Prepare cases vector.
  361. CaseVector Cases;
  362. unsigned numCmps = Clusterify(Cases, SI);
  363. DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
  364. << ". Total compares: " << numCmps << "\n");
  365. DEBUG(dbgs() << "Cases: " << Cases << "\n");
  366. (void)numCmps;
  367. ConstantInt *LowerBound = nullptr;
  368. ConstantInt *UpperBound = nullptr;
  369. std::vector<IntRange> UnreachableRanges;
  370. if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) {
  371. // Make the bounds tightly fitted around the case value range, becase we
  372. // know that the value passed to the switch must be exactly one of the case
  373. // values.
  374. assert(!Cases.empty());
  375. LowerBound = Cases.front().Low;
  376. UpperBound = Cases.back().High;
  377. DenseMap<BasicBlock *, unsigned> Popularity;
  378. unsigned MaxPop = 0;
  379. BasicBlock *PopSucc = nullptr;
  380. IntRange R = { INT64_MIN, INT64_MAX };
  381. UnreachableRanges.push_back(R);
  382. for (const auto &I : Cases) {
  383. int64_t Low = I.Low->getSExtValue();
  384. int64_t High = I.High->getSExtValue();
  385. IntRange &LastRange = UnreachableRanges.back();
  386. if (LastRange.Low == Low) {
  387. // There is nothing left of the previous range.
  388. UnreachableRanges.pop_back();
  389. } else {
  390. // Terminate the previous range.
  391. assert(Low > LastRange.Low);
  392. LastRange.High = Low - 1;
  393. }
  394. if (High != INT64_MAX) {
  395. IntRange R = { High + 1, INT64_MAX };
  396. UnreachableRanges.push_back(R);
  397. }
  398. // Count popularity.
  399. int64_t N = High - Low + 1;
  400. unsigned &Pop = Popularity[I.BB];
  401. if ((Pop += N) > MaxPop) {
  402. MaxPop = Pop;
  403. PopSucc = I.BB;
  404. }
  405. }
  406. #ifndef NDEBUG
  407. /* UnreachableRanges should be sorted and the ranges non-adjacent. */
  408. for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
  409. I != E; ++I) {
  410. assert(I->Low <= I->High);
  411. auto Next = I + 1;
  412. if (Next != E) {
  413. assert(Next->Low > I->High);
  414. }
  415. }
  416. #endif
  417. // Use the most popular block as the new default, reducing the number of
  418. // cases.
  419. assert(MaxPop > 0 && PopSucc);
  420. Default = PopSucc;
  421. Cases.erase(std::remove_if(
  422. Cases.begin(), Cases.end(),
  423. [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }),
  424. Cases.end());
  425. // If there are no cases left, just branch.
  426. if (Cases.empty()) {
  427. BranchInst::Create(Default, CurBlock);
  428. SI->eraseFromParent();
  429. return;
  430. }
  431. }
  432. // Create a new, empty default block so that the new hierarchy of
  433. // if-then statements go to this and the PHI nodes are happy.
  434. BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault");
  435. F->getBasicBlockList().insert(Default, NewDefault);
  436. BranchInst::Create(Default, NewDefault);
  437. // If there is an entry in any PHI nodes for the default edge, make sure
  438. // to update them as well.
  439. for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) {
  440. PHINode *PN = cast<PHINode>(I);
  441. int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
  442. assert(BlockIdx != -1 && "Switch didn't go to this successor??");
  443. PN->setIncomingBlock((unsigned)BlockIdx, NewDefault);
  444. }
  445. BasicBlock *SwitchBlock =
  446. switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
  447. OrigBlock, OrigBlock, NewDefault, UnreachableRanges);
  448. // Branch to our shiny new if-then stuff...
  449. BranchInst::Create(SwitchBlock, OrigBlock);
  450. // We are now done with the switch instruction, delete it.
  451. BasicBlock *OldDefault = SI->getDefaultDest();
  452. CurBlock->getInstList().erase(SI);
  453. // If the Default block has no more predecessors just remove it.
  454. if (pred_begin(OldDefault) == pred_end(OldDefault))
  455. DeleteDeadBlock(OldDefault);
  456. }