Float2Int.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. //===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // This file implements the Float2Int pass, which aims to demote floating
  11. // point operations to work on integers, where that is losslessly possible.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #define DEBUG_TYPE "float2int"
  15. #include "llvm/ADT/APInt.h"
  16. #include "llvm/ADT/APSInt.h"
  17. #include "llvm/ADT/DenseMap.h"
  18. #include "llvm/ADT/EquivalenceClasses.h"
  19. #include "llvm/ADT/MapVector.h"
  20. #include "llvm/ADT/SmallVector.h"
  21. #include "llvm/IR/ConstantRange.h"
  22. #include "llvm/IR/Constants.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/InstIterator.h"
  25. #include "llvm/IR/Instructions.h"
  26. #include "llvm/IR/Module.h"
  27. #include "llvm/Pass.h"
  28. #include "llvm/Support/Debug.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. #include "llvm/Transforms/Scalar.h"
  31. #include <deque>
  32. #include <functional> // For std::function
  33. using namespace llvm;
  34. // The algorithm is simple. Start at instructions that convert from the
  35. // float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
  36. // graph, using an equivalence datastructure to unify graphs that interfere.
  37. //
  38. // Mappable instructions are those with an integer corollary that, given
  39. // integer domain inputs, produce an integer output; fadd, for example.
  40. //
  41. // If a non-mappable instruction is seen, this entire def-use graph is marked
  42. // as non-transformable. If we see an instruction that converts from the
  43. // integer domain to FP domain (uitofp,sitofp), we terminate our walk.
  44. #if 0 // HLSL Change Starts - option pending
  45. /// The largest integer type worth dealing with.
  46. static cl::opt<unsigned>
  47. MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
  48. cl::desc("Max integer bitwidth to consider in float2int"
  49. "(default=64)"));
  50. #else
  51. static const unsigned MaxIntegerBW = 64;
  52. #endif // HLSL Change Ends
  53. namespace {
  54. struct Float2Int : public FunctionPass {
  55. static char ID; // Pass identification, replacement for typeid
  56. Float2Int() : FunctionPass(ID) {
  57. initializeFloat2IntPass(*PassRegistry::getPassRegistry());
  58. }
  59. bool runOnFunction(Function &F) override;
  60. void getAnalysisUsage(AnalysisUsage &AU) const override {
  61. AU.setPreservesCFG();
  62. }
  63. void findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots);
  64. ConstantRange seen(Instruction *I, ConstantRange R);
  65. ConstantRange badRange();
  66. ConstantRange unknownRange();
  67. ConstantRange validateRange(ConstantRange R);
  68. void walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots);
  69. void walkForwards();
  70. bool validateAndTransform();
  71. Value *convert(Instruction *I, Type *ToTy);
  72. void cleanup();
  73. MapVector<Instruction*, ConstantRange > SeenInsts;
  74. SmallPtrSet<Instruction*,8> Roots;
  75. EquivalenceClasses<Instruction*> ECs;
  76. MapVector<Instruction*, Value*> ConvertedInsts;
  77. LLVMContext *Ctx;
  78. };
  79. }
  80. char Float2Int::ID = 0;
  81. INITIALIZE_PASS(Float2Int, "float2int", "Float to int", false, false)
  82. // Given a FCmp predicate, return a matching ICmp predicate if one
  83. // exists, otherwise return BAD_ICMP_PREDICATE.
  84. static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
  85. switch (P) {
  86. case CmpInst::FCMP_OEQ:
  87. case CmpInst::FCMP_UEQ:
  88. return CmpInst::ICMP_EQ;
  89. case CmpInst::FCMP_OGT:
  90. case CmpInst::FCMP_UGT:
  91. return CmpInst::ICMP_SGT;
  92. case CmpInst::FCMP_OGE:
  93. case CmpInst::FCMP_UGE:
  94. return CmpInst::ICMP_SGE;
  95. case CmpInst::FCMP_OLT:
  96. case CmpInst::FCMP_ULT:
  97. return CmpInst::ICMP_SLT;
  98. case CmpInst::FCMP_OLE:
  99. case CmpInst::FCMP_ULE:
  100. return CmpInst::ICMP_SLE;
  101. case CmpInst::FCMP_ONE:
  102. case CmpInst::FCMP_UNE:
  103. return CmpInst::ICMP_NE;
  104. default:
  105. return CmpInst::BAD_ICMP_PREDICATE;
  106. }
  107. }
  108. // Given a floating point binary operator, return the matching
  109. // integer version.
  110. static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
  111. switch (Opcode) {
  112. default: llvm_unreachable("Unhandled opcode!");
  113. case Instruction::FAdd: return Instruction::Add;
  114. case Instruction::FSub: return Instruction::Sub;
  115. case Instruction::FMul: return Instruction::Mul;
  116. }
  117. }
  118. // Find the roots - instructions that convert from the FP domain to
  119. // integer domain.
  120. void Float2Int::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) {
  121. for (auto &I : inst_range(F)) {
  122. switch (I.getOpcode()) {
  123. default: break;
  124. case Instruction::FPToUI:
  125. case Instruction::FPToSI:
  126. Roots.insert(&I);
  127. break;
  128. case Instruction::FCmp:
  129. if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
  130. CmpInst::BAD_ICMP_PREDICATE)
  131. Roots.insert(&I);
  132. break;
  133. }
  134. }
  135. }
  136. // Helper - mark I as having been traversed, having range R.
  137. ConstantRange Float2Int::seen(Instruction *I, ConstantRange R) {
  138. DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
  139. if (SeenInsts.find(I) != SeenInsts.end())
  140. SeenInsts.find(I)->second = R;
  141. else
  142. SeenInsts.insert(std::make_pair(I, R));
  143. return R;
  144. }
  145. // Helper - get a range representing a poison value.
  146. ConstantRange Float2Int::badRange() {
  147. return ConstantRange(MaxIntegerBW + 1, true);
  148. }
  149. ConstantRange Float2Int::unknownRange() {
  150. return ConstantRange(MaxIntegerBW + 1, false);
  151. }
  152. ConstantRange Float2Int::validateRange(ConstantRange R) {
  153. if (R.getBitWidth() > MaxIntegerBW + 1)
  154. return badRange();
  155. return R;
  156. }
  157. // The most obvious way to structure the search is a depth-first, eager
  158. // search from each root. However, that requires direct recursion and so
  159. // can only handle small instruction sequences. Instead, we split the search
  160. // up into two phases:
  161. // - walkBackwards: A breadth-first walk of the use-def graph starting from
  162. // the roots. Populate "SeenInsts" with interesting
  163. // instructions and poison values if they're obvious and
  164. // cheap to compute. Calculate the equivalence set structure
  165. // while we're here too.
  166. // - walkForwards: Iterate over SeenInsts in reverse order, so we visit
  167. // defs before their uses. Calculate the real range info.
  168. // Breadth-first walk of the use-def graph; determine the set of nodes
  169. // we care about and eagerly determine if some of them are poisonous.
  170. void Float2Int::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) {
  171. std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
  172. while (!Worklist.empty()) {
  173. Instruction *I = Worklist.back();
  174. Worklist.pop_back();
  175. if (SeenInsts.find(I) != SeenInsts.end())
  176. // Seen already.
  177. continue;
  178. switch (I->getOpcode()) {
  179. // FIXME: Handle select and phi nodes.
  180. default:
  181. // Path terminated uncleanly.
  182. seen(I, badRange());
  183. break;
  184. case Instruction::UIToFP: {
  185. // Path terminated cleanly.
  186. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
  187. APInt Min = APInt::getMinValue(BW).zextOrSelf(MaxIntegerBW+1);
  188. APInt Max = APInt::getMaxValue(BW).zextOrSelf(MaxIntegerBW+1);
  189. seen(I, validateRange(ConstantRange(Min, Max)));
  190. continue;
  191. }
  192. case Instruction::SIToFP: {
  193. // Path terminated cleanly.
  194. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
  195. APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(MaxIntegerBW+1);
  196. APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(MaxIntegerBW+1);
  197. seen(I, validateRange(ConstantRange(SMin, SMax)));
  198. continue;
  199. }
  200. case Instruction::FAdd:
  201. case Instruction::FSub:
  202. case Instruction::FMul:
  203. case Instruction::FPToUI:
  204. case Instruction::FPToSI:
  205. case Instruction::FCmp:
  206. seen(I, unknownRange());
  207. break;
  208. }
  209. for (Value *O : I->operands()) {
  210. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  211. // Unify def-use chains if they interfere.
  212. ECs.unionSets(I, OI);
  213. if (SeenInsts.find(I)->second != badRange())
  214. Worklist.push_back(OI);
  215. } else if (!isa<ConstantFP>(O)) {
  216. // Not an instruction or ConstantFP? we can't do anything.
  217. seen(I, badRange());
  218. }
  219. }
  220. }
  221. }
  222. // Walk forwards down the list of seen instructions, so we visit defs before
  223. // uses.
  224. void Float2Int::walkForwards() {
  225. for (auto It = SeenInsts.rbegin(), E = SeenInsts.rend(); It != E; ++It) {
  226. if (It->second != unknownRange())
  227. continue;
  228. Instruction *I = It->first;
  229. std::function<ConstantRange(ArrayRef<ConstantRange>)> Op;
  230. switch (I->getOpcode()) {
  231. // FIXME: Handle select and phi nodes.
  232. default:
  233. case Instruction::UIToFP:
  234. case Instruction::SIToFP:
  235. llvm_unreachable("Should have been handled in walkForwards!");
  236. case Instruction::FAdd:
  237. Op = [](ArrayRef<ConstantRange> Ops) {
  238. assert(Ops.size() == 2 && "FAdd is a binary operator!");
  239. return Ops[0].add(Ops[1]);
  240. };
  241. break;
  242. case Instruction::FSub:
  243. Op = [](ArrayRef<ConstantRange> Ops) {
  244. assert(Ops.size() == 2 && "FSub is a binary operator!");
  245. return Ops[0].sub(Ops[1]);
  246. };
  247. break;
  248. case Instruction::FMul:
  249. Op = [](ArrayRef<ConstantRange> Ops) {
  250. assert(Ops.size() == 2 && "FMul is a binary operator!");
  251. return Ops[0].multiply(Ops[1]);
  252. };
  253. break;
  254. //
  255. // Root-only instructions - we'll only see these if they're the
  256. // first node in a walk.
  257. //
  258. case Instruction::FPToUI:
  259. case Instruction::FPToSI:
  260. Op = [](ArrayRef<ConstantRange> Ops) {
  261. assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!");
  262. return Ops[0];
  263. };
  264. break;
  265. case Instruction::FCmp:
  266. Op = [](ArrayRef<ConstantRange> Ops) {
  267. assert(Ops.size() == 2 && "FCmp is a binary operator!");
  268. return Ops[0].unionWith(Ops[1]);
  269. };
  270. break;
  271. }
  272. bool Abort = false;
  273. SmallVector<ConstantRange,4> OpRanges;
  274. for (Value *O : I->operands()) {
  275. if (Instruction *OI = dyn_cast<Instruction>(O)) {
  276. assert(SeenInsts.find(OI) != SeenInsts.end() &&
  277. "def not seen before use!");
  278. OpRanges.push_back(SeenInsts.find(OI)->second);
  279. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
  280. // Work out if the floating point number can be losslessly represented
  281. // as an integer.
  282. // APFloat::convertToInteger(&Exact) purports to do what we want, but
  283. // the exactness can be too precise. For example, negative zero can
  284. // never be exactly converted to an integer.
  285. //
  286. // Instead, we ask APFloat to round itself to an integral value - this
  287. // preserves sign-of-zero - then compare the result with the original.
  288. //
  289. APFloat F = CF->getValueAPF();
  290. // First, weed out obviously incorrect values. Non-finite numbers
  291. // can't be represented and neither can negative zero, unless
  292. // we're in fast math mode.
  293. if (!F.isFinite() ||
  294. (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
  295. !I->hasNoSignedZeros())) {
  296. seen(I, badRange());
  297. Abort = true;
  298. break;
  299. }
  300. APFloat NewF = F;
  301. auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
  302. if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) {
  303. seen(I, badRange());
  304. Abort = true;
  305. break;
  306. }
  307. // OK, it's representable. Now get it.
  308. APSInt Int(MaxIntegerBW+1, false);
  309. bool Exact;
  310. CF->getValueAPF().convertToInteger(Int,
  311. APFloat::rmNearestTiesToEven,
  312. &Exact);
  313. OpRanges.push_back(ConstantRange(Int));
  314. } else {
  315. llvm_unreachable("Should have already marked this as badRange!");
  316. }
  317. }
  318. // Reduce the operands' ranges to a single range and return.
  319. if (!Abort)
  320. seen(I, Op(OpRanges));
  321. }
  322. }
  323. // If there is a valid transform to be done, do it.
  324. bool Float2Int::validateAndTransform() {
  325. bool MadeChange = false;
  326. // Iterate over every disjoint partition of the def-use graph.
  327. for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
  328. ConstantRange R(MaxIntegerBW + 1, false);
  329. bool Fail = false;
  330. Type *ConvertedToTy = nullptr;
  331. // For every member of the partition, union all the ranges together.
  332. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  333. MI != ME; ++MI) {
  334. Instruction *I = *MI;
  335. auto SeenI = SeenInsts.find(I);
  336. if (SeenI == SeenInsts.end())
  337. continue;
  338. R = R.unionWith(SeenI->second);
  339. // We need to ensure I has no users that have not been seen.
  340. // If it does, transformation would be illegal.
  341. //
  342. // Don't count the roots, as they terminate the graphs.
  343. if (Roots.count(I) == 0) {
  344. // Set the type of the conversion while we're here.
  345. if (!ConvertedToTy)
  346. ConvertedToTy = I->getType();
  347. for (User *U : I->users()) {
  348. Instruction *UI = dyn_cast<Instruction>(U);
  349. if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
  350. DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
  351. Fail = true;
  352. break;
  353. }
  354. }
  355. }
  356. if (Fail)
  357. break;
  358. }
  359. // If the set was empty, or we failed, or the range is poisonous,
  360. // bail out.
  361. if (ECs.member_begin(It) == ECs.member_end() || Fail ||
  362. R.isFullSet() || R.isSignWrappedSet())
  363. continue;
  364. assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
  365. // The number of bits required is the maximum of the upper and
  366. // lower limits, plus one so it can be signed.
  367. unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
  368. R.getUpper().getMinSignedBits()) + 1;
  369. DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
  370. // If we've run off the realms of the exactly representable integers,
  371. // the floating point result will differ from an integer approximation.
  372. // Do we need more bits than are in the mantissa of the type we converted
  373. // to? semanticsPrecision returns the number of mantissa bits plus one
  374. // for the sign bit.
  375. unsigned MaxRepresentableBits
  376. = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
  377. if (MinBW > MaxRepresentableBits) {
  378. DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
  379. continue;
  380. }
  381. if (MinBW > 64) {
  382. DEBUG(dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
  383. continue;
  384. }
  385. // OK, R is known to be representable. Now pick a type for it.
  386. // FIXME: Pick the smallest legal type that will fit.
  387. Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
  388. for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
  389. MI != ME; ++MI)
  390. convert(*MI, Ty);
  391. MadeChange = true;
  392. }
  393. return MadeChange;
  394. }
  395. Value *Float2Int::convert(Instruction *I, Type *ToTy) {
  396. if (ConvertedInsts.find(I) != ConvertedInsts.end())
  397. // Already converted this instruction.
  398. return ConvertedInsts[I];
  399. SmallVector<Value*,4> NewOperands;
  400. for (Value *V : I->operands()) {
  401. // Don't recurse if we're an instruction that terminates the path.
  402. if (I->getOpcode() == Instruction::UIToFP ||
  403. I->getOpcode() == Instruction::SIToFP) {
  404. NewOperands.push_back(V);
  405. } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
  406. NewOperands.push_back(convert(VI, ToTy));
  407. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
  408. APSInt Val(ToTy->getPrimitiveSizeInBits(), /*IsUnsigned=*/false);
  409. bool Exact;
  410. CF->getValueAPF().convertToInteger(Val,
  411. APFloat::rmNearestTiesToEven,
  412. &Exact);
  413. NewOperands.push_back(ConstantInt::get(ToTy, Val));
  414. } else {
  415. llvm_unreachable("Unhandled operand type?");
  416. }
  417. }
  418. // Now create a new instruction.
  419. IRBuilder<> IRB(I);
  420. Value *NewV = nullptr;
  421. switch (I->getOpcode()) {
  422. default: llvm_unreachable("Unhandled instruction!");
  423. case Instruction::FPToUI:
  424. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
  425. break;
  426. case Instruction::FPToSI:
  427. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
  428. break;
  429. case Instruction::FCmp: {
  430. CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
  431. assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
  432. NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
  433. break;
  434. }
  435. case Instruction::UIToFP:
  436. NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
  437. break;
  438. case Instruction::SIToFP:
  439. NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
  440. break;
  441. case Instruction::FAdd:
  442. case Instruction::FSub:
  443. case Instruction::FMul:
  444. NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
  445. NewOperands[0], NewOperands[1],
  446. I->getName());
  447. break;
  448. }
  449. // If we're a root instruction, RAUW.
  450. if (Roots.count(I))
  451. I->replaceAllUsesWith(NewV);
  452. ConvertedInsts[I] = NewV;
  453. return NewV;
  454. }
  455. // Perform dead code elimination on the instructions we just modified.
  456. void Float2Int::cleanup() {
  457. for (auto I = ConvertedInsts.rbegin(), E = ConvertedInsts.rend();
  458. I != E; ++I)
  459. I->first->eraseFromParent();
  460. }
  461. bool Float2Int::runOnFunction(Function &F) {
  462. if (skipOptnoneFunction(F))
  463. return false;
  464. DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
  465. // Clear out all state.
  466. ECs = EquivalenceClasses<Instruction*>();
  467. SeenInsts.clear();
  468. ConvertedInsts.clear();
  469. Roots.clear();
  470. Ctx = &F.getParent()->getContext();
  471. findRoots(F, Roots);
  472. walkBackwards(Roots);
  473. walkForwards();
  474. bool Modified = validateAndTransform();
  475. if (Modified)
  476. cleanup();
  477. return Modified;
  478. }
  479. FunctionPass *llvm::createFloat2IntPass() {
  480. return new Float2Int();
  481. }