Float2Int.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. //===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // This file implements the Float2Int pass, which aims to demote floating
  11. // point operations to work on integers, where that is losslessly possible.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #define DEBUG_TYPE "float2int"
  15. #include "llvm/ADT/APInt.h"
  16. #include "llvm/ADT/APSInt.h"
  17. #include "llvm/ADT/DenseMap.h"
  18. #include "llvm/ADT/EquivalenceClasses.h"
  19. #include "llvm/ADT/MapVector.h"
  20. #include "llvm/ADT/SmallVector.h"
  21. #include "llvm/IR/ConstantRange.h"
  22. #include "llvm/IR/Constants.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/InstIterator.h"
  25. #include "llvm/IR/Instructions.h"
  26. #include "llvm/IR/Module.h"
  27. #include "llvm/Pass.h"
  28. #include "llvm/Support/Debug.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. #include "llvm/Transforms/Scalar.h"
  31. #include <deque>
  32. #include <functional> // For std::function
  33. using namespace llvm;
  // The algorithm is simple. Start at instructions that convert from the
  // float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
  // graph, using an equivalence data structure to unify graphs that interfere.
  //
  // Mappable instructions are those with an integer counterpart that, given
  // integer domain inputs, produce an integer output; fadd, for example.
  //
  // If a non-mappable instruction is seen, this entire def-use graph is marked
  // as non-transformable. If we see an instruction that converts from the
  // integer domain to the FP domain (uitofp, sitofp), we terminate our walk.
  44. /// The largest integer type worth dealing with.
  45. static cl::opt<unsigned>
  46. MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
  47. cl::desc("Max integer bitwidth to consider in float2int"
  48. "(default=64)"));
  49. namespace {
  50. struct Float2Int : public FunctionPass {
  51. static char ID; // Pass identification, replacement for typeid
  52. Float2Int() : FunctionPass(ID) {
  53. initializeFloat2IntPass(*PassRegistry::getPassRegistry());
  54. }
  55. bool runOnFunction(Function &F) override;
  56. void getAnalysisUsage(AnalysisUsage &AU) const override {
  57. AU.setPreservesCFG();
  58. }
  59. void findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots);
  60. ConstantRange seen(Instruction *I, ConstantRange R);
  61. ConstantRange badRange();
  62. ConstantRange unknownRange();
  63. ConstantRange validateRange(ConstantRange R);
  64. void walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots);
  65. void walkForwards();
  66. bool validateAndTransform();
  67. Value *convert(Instruction *I, Type *ToTy);
  68. void cleanup();
  69. MapVector<Instruction*, ConstantRange > SeenInsts;
  70. SmallPtrSet<Instruction*,8> Roots;
  71. EquivalenceClasses<Instruction*> ECs;
  72. MapVector<Instruction*, Value*> ConvertedInsts;
  73. LLVMContext *Ctx;
  74. };
  75. }
  76. char Float2Int::ID = 0;
  77. INITIALIZE_PASS(Float2Int, "float2int", "Float to int", false, false)
  78. // Given a FCmp predicate, return a matching ICmp predicate if one
  79. // exists, otherwise return BAD_ICMP_PREDICATE.
  80. static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
  81. switch (P) {
  82. case CmpInst::FCMP_OEQ:
  83. case CmpInst::FCMP_UEQ:
  84. return CmpInst::ICMP_EQ;
  85. case CmpInst::FCMP_OGT:
  86. case CmpInst::FCMP_UGT:
  87. return CmpInst::ICMP_SGT;
  88. case CmpInst::FCMP_OGE:
  89. case CmpInst::FCMP_UGE:
  90. return CmpInst::ICMP_SGE;
  91. case CmpInst::FCMP_OLT:
  92. case CmpInst::FCMP_ULT:
  93. return CmpInst::ICMP_SLT;
  94. case CmpInst::FCMP_OLE:
  95. case CmpInst::FCMP_ULE:
  96. return CmpInst::ICMP_SLE;
  97. case CmpInst::FCMP_ONE:
  98. case CmpInst::FCMP_UNE:
  99. return CmpInst::ICMP_NE;
  100. default:
  101. return CmpInst::BAD_ICMP_PREDICATE;
  102. }
  103. }
  104. // Given a floating point binary operator, return the matching
  105. // integer version.
  106. static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
  107. switch (Opcode) {
  108. default: llvm_unreachable("Unhandled opcode!");
  109. case Instruction::FAdd: return Instruction::Add;
  110. case Instruction::FSub: return Instruction::Sub;
  111. case Instruction::FMul: return Instruction::Mul;
  112. }
  113. }
  114. // Find the roots - instructions that convert from the FP domain to
  115. // integer domain.
  116. void Float2Int::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) {
  117. for (auto &I : inst_range(F)) {
  118. switch (I.getOpcode()) {
  119. default: break;
  120. case Instruction::FPToUI:
  121. case Instruction::FPToSI:
  122. Roots.insert(&I);
  123. break;
  124. case Instruction::FCmp:
  125. if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
  126. CmpInst::BAD_ICMP_PREDICATE)
  127. Roots.insert(&I);
  128. break;
  129. }
  130. }
  131. }
  132. // Helper - mark I as having been traversed, having range R.
  133. ConstantRange Float2Int::seen(Instruction *I, ConstantRange R) {
  134. DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
  135. if (SeenInsts.find(I) != SeenInsts.end())
  136. SeenInsts.find(I)->second = R;
  137. else
  138. SeenInsts.insert(std::make_pair(I, R));
  139. return R;
  140. }
  141. // Helper - get a range representing a poison value.
  142. ConstantRange Float2Int::badRange() {
  143. return ConstantRange(MaxIntegerBW + 1, true);
  144. }
  145. ConstantRange Float2Int::unknownRange() {
  146. return ConstantRange(MaxIntegerBW + 1, false);
  147. }
  148. ConstantRange Float2Int::validateRange(ConstantRange R) {
  149. if (R.getBitWidth() > MaxIntegerBW + 1)
  150. return badRange();
  151. return R;
  152. }
  // The most obvious way to structure the search is a depth-first, eager
  // search from each root. However, that requires direct recursion and so
  // can only handle small instruction sequences. Instead, we split the search
  // up into two phases:
  //   - walkBackwards: A worklist-driven walk of the use-def graph starting
  //                    from the roots. Populate "SeenInsts" with interesting
  //                    instructions and poison values if they're obvious and
  //                    cheap to compute. Calculate the equivalence set
  //                    structure while we're here too.
  //   - walkForwards:  Visit the instructions in SeenInsts so that defs are
  //                    processed before their uses, and calculate the real
  //                    range info.
  164. // Breadth-first walk of the use-def graph; determine the set of nodes
  165. // we care about and eagerly determine if some of them are poisonous.
  166. void Float2Int::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) {
  167. std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
  168. while (!Worklist.empty()) {
  169. Instruction *I = Worklist.back();
  170. Worklist.pop_back();
  171. if (SeenInsts.find(I) != SeenInsts.end())
  172. // Seen already.
  173. continue;
  174. switch (I->getOpcode()) {
  175. // FIXME: Handle select and phi nodes.
  176. default:
  177. // Path terminated uncleanly.
  178. seen(I, badRange());
  179. break;
  180. case Instruction::UIToFP: {
  181. // Path terminated cleanly.
  182. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
  183. APInt Min = APInt::getMinValue(BW).zextOrSelf(MaxIntegerBW+1);
  184. APInt Max = APInt::getMaxValue(BW).zextOrSelf(MaxIntegerBW+1);
  185. seen(I, validateRange(ConstantRange(Min, Max)));
  186. continue;
  187. }
  188. case Instruction::SIToFP: {
  189. // Path terminated cleanly.
  190. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
  191. APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(MaxIntegerBW+1);
  192. APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(MaxIntegerBW+1);
  193. seen(I, validateRange(ConstantRange(SMin, SMax)));
  194. continue;
  195. }
  196. case Instruction::FAdd:
  197. case Instruction::FSub:
  198. case Instruction::FMul:
  199. case Instruction::FPToUI:
  200. case Instruction::FPToSI:
  201. case Instruction::FCmp:
  202. seen(I, unknownRange());
  203. break;
  204. }
  205. for (Value *O : I->operands()) {
  206. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  207. // Unify def-use chains if they interfere.
  208. ECs.unionSets(I, OI);
  209. if (SeenInsts.find(I)->second != badRange())
  210. Worklist.push_back(OI);
  211. } else if (!isa<ConstantFP>(O)) {
  212. // Not an instruction or ConstantFP? we can't do anything.
  213. seen(I, badRange());
  214. }
  215. }
  216. }
  217. }
  218. // Walk forwards down the list of seen instructions, so we visit defs before
  219. // uses.
  220. void Float2Int::walkForwards() {
  221. for (auto It = SeenInsts.rbegin(), E = SeenInsts.rend(); It != E; ++It) {
  222. if (It->second != unknownRange())
  223. continue;
  224. Instruction *I = It->first;
  225. std::function<ConstantRange(ArrayRef<ConstantRange>)> Op;
  226. switch (I->getOpcode()) {
  227. // FIXME: Handle select and phi nodes.
  228. default:
  229. case Instruction::UIToFP:
  230. case Instruction::SIToFP:
  231. llvm_unreachable("Should have been handled in walkForwards!");
  232. case Instruction::FAdd:
  233. Op = [](ArrayRef<ConstantRange> Ops) {
  234. assert(Ops.size() == 2 && "FAdd is a binary operator!");
  235. return Ops[0].add(Ops[1]);
  236. };
  237. break;
  238. case Instruction::FSub:
  239. Op = [](ArrayRef<ConstantRange> Ops) {
  240. assert(Ops.size() == 2 && "FSub is a binary operator!");
  241. return Ops[0].sub(Ops[1]);
  242. };
  243. break;
  244. case Instruction::FMul:
  245. Op = [](ArrayRef<ConstantRange> Ops) {
  246. assert(Ops.size() == 2 && "FMul is a binary operator!");
  247. return Ops[0].multiply(Ops[1]);
  248. };
  249. break;
  250. //
  251. // Root-only instructions - we'll only see these if they're the
  252. // first node in a walk.
  253. //
  254. case Instruction::FPToUI:
  255. case Instruction::FPToSI:
  256. Op = [](ArrayRef<ConstantRange> Ops) {
  257. assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!");
  258. return Ops[0];
  259. };
  260. break;
  261. case Instruction::FCmp:
  262. Op = [](ArrayRef<ConstantRange> Ops) {
  263. assert(Ops.size() == 2 && "FCmp is a binary operator!");
  264. return Ops[0].unionWith(Ops[1]);
  265. };
  266. break;
  267. }
  268. bool Abort = false;
  269. SmallVector<ConstantRange,4> OpRanges;
  270. for (Value *O : I->operands()) {
  271. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  272. assert(SeenInsts.find(OI) != SeenInsts.end() &&
  273. "def not seen before use!");
  274. OpRanges.push_back(SeenInsts.find(OI)->second);
  275. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
  276. // Work out if the floating point number can be losslessly represented
  277. // as an integer.
  278. // APFloat::convertToInteger(&Exact) purports to do what we want, but
  279. // the exactness can be too precise. For example, negative zero can
  280. // never be exactly converted to an integer.
  281. //
  282. // Instead, we ask APFloat to round itself to an integral value - this
  283. // preserves sign-of-zero - then compare the result with the original.
  284. //
  285. APFloat F = CF->getValueAPF();
  286. // First, weed out obviously incorrect values. Non-finite numbers
  287. // can't be represented and neither can negative zero, unless
  288. // we're in fast math mode.
  289. if (!F.isFinite() ||
  290. (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
  291. !I->hasNoSignedZeros())) {
  292. seen(I, badRange());
  293. Abort = true;
  294. break;
  295. }
  296. APFloat NewF = F;
  297. auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
  298. if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) {
  299. seen(I, badRange());
  300. Abort = true;
  301. break;
  302. }
  303. // OK, it's representable. Now get it.
  304. APSInt Int(MaxIntegerBW+1, false);
  305. bool Exact;
  306. CF->getValueAPF().convertToInteger(Int,
  307. APFloat::rmNearestTiesToEven,
  308. &Exact);
  309. OpRanges.push_back(ConstantRange(Int));
  310. } else {
  311. llvm_unreachable("Should have already marked this as badRange!");
  312. }
  313. }
  314. // Reduce the operands' ranges to a single range and return.
  315. if (!Abort)
  316. seen(I, Op(OpRanges));
  317. }
  318. }
  319. // If there is a valid transform to be done, do it.
  320. bool Float2Int::validateAndTransform() {
  321. bool MadeChange = false;
  322. // Iterate over every disjoint partition of the def-use graph.
  323. for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
  324. ConstantRange R(MaxIntegerBW + 1, false);
  325. bool Fail = false;
  326. Type *ConvertedToTy = nullptr;
  327. // For every member of the partition, union all the ranges together.
  328. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  329. MI != ME; ++MI) {
  330. Instruction *I = *MI;
  331. auto SeenI = SeenInsts.find(I);
  332. if (SeenI == SeenInsts.end())
  333. continue;
  334. R = R.unionWith(SeenI->second);
  335. // We need to ensure I has no users that have not been seen.
  336. // If it does, transformation would be illegal.
  337. //
  338. // Don't count the roots, as they terminate the graphs.
  339. if (Roots.count(I) == 0) {
  340. // Set the type of the conversion while we're here.
  341. if (!ConvertedToTy)
  342. ConvertedToTy = I->getType();
  343. for (User *U : I->users()) {
  344. Instruction *UI = dyn_cast<Instruction>(U);
  345. if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
  346. DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
  347. Fail = true;
  348. break;
  349. }
  350. }
  351. }
  352. if (Fail)
  353. break;
  354. }
  355. // If the set was empty, or we failed, or the range is poisonous,
  356. // bail out.
  357. if (ECs.member_begin(It) == ECs.member_end() || Fail ||
  358. R.isFullSet() || R.isSignWrappedSet())
  359. continue;
  360. assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
  361. // The number of bits required is the maximum of the upper and
  362. // lower limits, plus one so it can be signed.
  363. unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
  364. R.getUpper().getMinSignedBits()) + 1;
  365. DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
  366. // If we've run off the realms of the exactly representable integers,
  367. // the floating point result will differ from an integer approximation.
  368. // Do we need more bits than are in the mantissa of the type we converted
  369. // to? semanticsPrecision returns the number of mantissa bits plus one
  370. // for the sign bit.
  371. unsigned MaxRepresentableBits
  372. = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
  373. if (MinBW > MaxRepresentableBits) {
  374. DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
  375. continue;
  376. }
  377. if (MinBW > 64) {
  378. DEBUG(dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
  379. continue;
  380. }
  381. // OK, R is known to be representable. Now pick a type for it.
  382. // FIXME: Pick the smallest legal type that will fit.
  383. Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
  384. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  385. MI != ME; ++MI)
  386. convert(*MI, Ty);
  387. MadeChange = true;
  388. }
  389. return MadeChange;
  390. }
  391. Value *Float2Int::convert(Instruction *I, Type *ToTy) {
  392. if (ConvertedInsts.find(I) != ConvertedInsts.end())
  393. // Already converted this instruction.
  394. return ConvertedInsts[I];
  395. SmallVector<Value*,4> NewOperands;
  396. for (Value *V : I->operands()) {
  397. // Don't recurse if we're an instruction that terminates the path.
  398. if (I->getOpcode() == Instruction::UIToFP ||
  399. I->getOpcode() == Instruction::SIToFP) {
  400. NewOperands.push_back(V);
  401. } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
  402. NewOperands.push_back(convert(VI, ToTy));
  403. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
  404. APSInt Val(ToTy->getPrimitiveSizeInBits(), /*IsUnsigned=*/false);
  405. bool Exact;
  406. CF->getValueAPF().convertToInteger(Val,
  407. APFloat::rmNearestTiesToEven,
  408. &Exact);
  409. NewOperands.push_back(ConstantInt::get(ToTy, Val));
  410. } else {
  411. llvm_unreachable("Unhandled operand type?");
  412. }
  413. }
  414. // Now create a new instruction.
  415. IRBuilder<> IRB(I);
  416. Value *NewV = nullptr;
  417. switch (I->getOpcode()) {
  418. default: llvm_unreachable("Unhandled instruction!");
  419. case Instruction::FPToUI:
  420. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
  421. break;
  422. case Instruction::FPToSI:
  423. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
  424. break;
  425. case Instruction::FCmp: {
  426. CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
  427. assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
  428. NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
  429. break;
  430. }
  431. case Instruction::UIToFP:
  432. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
  433. break;
  434. case Instruction::SIToFP:
  435. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
  436. break;
  437. case Instruction::FAdd:
  438. case Instruction::FSub:
  439. case Instruction::FMul:
  440. NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
  441. NewOperands[0], NewOperands[1],
  442. I->getName());
  443. break;
  444. }
  445. // If we're a root instruction, RAUW.
  446. if (Roots.count(I))
  447. I->replaceAllUsesWith(NewV);
  448. ConvertedInsts[I] = NewV;
  449. return NewV;
  450. }
  451. // Perform dead code elimination on the instructions we just modified.
  452. void Float2Int::cleanup() {
  453. for (auto I = ConvertedInsts.rbegin(), E = ConvertedInsts.rend();
  454. I != E; ++I)
  455. I->first->eraseFromParent();
  456. }
  457. bool Float2Int::runOnFunction(Function &F) {
  458. if (skipOptnoneFunction(F))
  459. return false;
  460. DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
  461. // Clear out all state.
  462. ECs = EquivalenceClasses<Instruction*>();
  463. SeenInsts.clear();
  464. ConvertedInsts.clear();
  465. Roots.clear();
  466. Ctx = &F.getParent()->getContext();
  467. findRoots(F, Roots);
  468. walkBackwards(Roots);
  469. walkForwards();
  470. bool Modified = validateAndTransform();
  471. if (Modified)
  472. cleanup();
  473. return Modified;
  474. }
  475. FunctionPass *llvm::createFloat2IntPass() {
  476. return new Float2Int();
  477. }