//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
//                  Set Load/Store Alignments From Assumptions
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
//
//===----------------------------------------------------------------------===//

#define AA_NAME "alignment-from-assumptions"
#define DEBUG_TYPE AA_NAME
#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

namespace {
struct AlignmentFromAssumptions : public FunctionPass {
  static char ID; // Pass identification, replacement for typeid
  AlignmentFromAssumptions() : FunctionPass(ID) {
    initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolution>();
    AU.addRequired<DominatorTreeWrapperPass>();

    AU.setPreservesCFG();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolution>();
  }

  // For memory transfers, we need a common alignment for both the source and
  // destination. If we have a new alignment for only one operand of a transfer
  // instruction, save it in these maps. If we reach the other operand through
  // another assumption later, then we may change the alignment at that point.
  DenseMap<MemTransferInst *, unsigned> NewDestAlignments, NewSrcAlignments;

  ScalarEvolution *SE;
  DominatorTree *DT;

  bool extractAlignmentInfo(CallInst *I, Value *&AAPtr, const SCEV *&AlignSCEV,
                            const SCEV *&OffSCEV);
  bool processAssumption(CallInst *I);
};
}

char AlignmentFromAssumptions::ID = 0;
static const char aip_name[] = "Alignment from assumptions";
INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
                      aip_name, false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
                    aip_name, false, false)

FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
  return new AlignmentFromAssumptions();
}

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be
// reduced to a constant. Using SCEV to compute alignment handles the case
// where DiffSCEV is a recurrence with constant start such that the aligned
// offset is constant. e.g. {16,+,32} % 32 -> 16.
static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
                                    const SCEV *AlignSCEV,
                                    ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
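  // The remainder is built from udiv/mul/sub primitives, i.e. as
  // Diff - (Diff /u Align) * Align, except that the subtraction below is
  // written the other way around, so a constant result is the remainder up
  // to sign. Only the magnitude matters: std::abs is applied before the
  // power-of-two check.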
  const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV);
  const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV);
  const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV);

  DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " <<
                  *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
      dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return (unsigned)
        cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return (unsigned) DiffUnitsAbs;
  }

  return 0;
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                                const SCEV *OffSCEV, Value *Ptr,
                                ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);
  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);

  DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " <<
                  *AlignSCEV << " and offset " << *OffSCEV <<
                  " using diff " << *DiffSCEV << "\n");

  unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
  DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");

  if (NewAlignment) {
    return NewAlignment;
  } else if (const SCEVAddRecExpr *DiffARSCEV =
             dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a
    // constant, but we should try harder: if we assume that a is 32-byte
    // aligned, then in the loop for (i = 0; i < 1024; i += 4) r += a[i];
    // not all of the loads from a are 32-byte aligned, but instead alternate
    // between 32 and 16-byte alignment. As a result, the new alignment will
    // not be a constant, but can still be improved over the default (of 4)
    // to 16.
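    // Worked example: with DiffSCEV = {0,+,16} and a 32-byte alignment
    // assumption, the start below yields alignment 32 and the step yields
    // 16; 16 divides 32, so the result is 16.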
    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    DEBUG(dbgs() << "\ttrying start/inc alignment using start " <<
                    *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in
    // the first iteration, and also the alignment using the per-iteration
    // delta. If these are the same, then use that answer. Otherwise, use the
    // smaller one, but only if it divides the larger one.
    NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
    DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");

    if (!NewAlignment || !NewIncAlignment) {
      return 0;
    } else if (NewAlignment > NewIncAlignment) {
      if (NewAlignment % NewIncAlignment == 0) {
        DEBUG(dbgs() << "\tnew start/inc alignment: " <<
                        NewIncAlignment << "\n");
        return NewIncAlignment;
      }
    } else if (NewIncAlignment > NewAlignment) {
      if (NewIncAlignment % NewAlignment == 0) {
        DEBUG(dbgs() << "\tnew start/inc alignment: " <<
                        NewAlignment << "\n");
        return NewAlignment;
      }
    } else if (NewIncAlignment == NewAlignment) {
      DEBUG(dbgs() << "\tnew start/inc alignment: " <<
                      NewAlignment << "\n");
      return NewAlignment;
    }
  }

  return 0;
}

bool AlignmentFromAssumptions::extractAlignmentInfo(CallInst *I,
                                 Value *&AAPtr, const SCEV *&AlignSCEV,
                                 const SCEV *&OffSCEV) {
  // An alignment assume must be a statement about the least-significant
  // bits of the pointer being zero, possibly with some offset.
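  // Such conditions typically come from __builtin_assume_aligned: e.g.
  // __builtin_assume_aligned(p, 32) is lowered (modulo casts) to
  // assume((ptrtoint(p) & 31) == 0).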
  ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
  if (!ICI)
    return false;

  // This must be an expression of the form: x & m == 0.
  if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
    return false;

  // Swap things around so that the RHS is 0.
  Value *CmpLHS = ICI->getOperand(0);
  Value *CmpRHS = ICI->getOperand(1);
  const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
  const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
  if (CmpLHSSCEV->isZero())
    std::swap(CmpLHS, CmpRHS);
  else if (!CmpRHSSCEV->isZero())
    return false;

  BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
  if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
    return false;

  // Swap things around so that the right operand of the and is a constant
  // (the mask); we cannot deal with variable masks.
  Value *AndLHS = CmpBO->getOperand(0);
  Value *AndRHS = CmpBO->getOperand(1);
  const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
  const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
  if (isa<SCEVConstant>(AndLHSSCEV)) {
    std::swap(AndLHS, AndRHS);
    std::swap(AndLHSSCEV, AndRHSSCEV);
  }

  const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
  if (!MaskSCEV)
    return false;

  // The mask must have some trailing ones (otherwise the condition is
  // trivial and tells us nothing about the alignment of the left operand).
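  // A mask with k trailing ones forces the low k bits of the value to be
  // zero, proving 2^k-byte alignment; set bits above the trailing run prove
  // nothing further about alignment.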
  unsigned TrailingOnes =
    MaskSCEV->getValue()->getValue().countTrailingOnes();
  if (!TrailingOnes)
    return false;

  // Cap the alignment at the maximum with which LLVM can deal (and make sure
  // we don't overflow the shift).
  uint64_t Alignment;
  TrailingOnes = std::min(TrailingOnes,
    unsigned(sizeof(unsigned) * CHAR_BIT - 1));
  Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);

  Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
  AlignSCEV = SE->getConstant(Int64Ty, Alignment);

  // The LHS might be a ptrtoint instruction, or it might be the pointer
  // with an offset.
  AAPtr = nullptr;
  OffSCEV = nullptr;
  if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
    AAPtr = PToI->getPointerOperand();
    OffSCEV = SE->getConstant(Int64Ty, 0);
  } else if (const SCEVAddExpr* AndLHSAddSCEV =
             dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
    // Try to find the ptrtoint; subtract it and the rest is the offset.
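    // e.g. for ((ptrtoint %p) + 8) & 31 == 0, AAPtr becomes %p and OffSCEV
    // becomes the constant 8.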
    for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
         JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
      if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
        if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
          AAPtr = PToI->getPointerOperand();
          OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
          break;
        }
  }

  if (!AAPtr)
    return false;

  // Sign extend the offset to 64 bits (so that it is like all of the other
  // expressions).
  unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
  if (OffSCEVBits < 64)
    OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
  else if (OffSCEVBits > 64)
    return false;

  AAPtr = AAPtr->stripPointerCasts();
  return true;
}

bool AlignmentFromAssumptions::processAssumption(CallInst *ACall) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
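  // Only users at which the assumption is known to hold are queued:
  // isValidAssumeForContext checks, among other things, that the assume
  // call is valid at (e.g. dominates) the candidate instruction.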
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction*, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      if (isValidAssumeForContext(ACall, K, DT))
        WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
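
    // In each case below the alignment is only ever increased: an
    // assumption can refine, but never weaken, what is already known.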
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
        LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlignment()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
        SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlignment()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
        MI->getDest(), SE);

      // For memory transfers, we need a common alignment for both the
      // source and destination. If we have a new alignment for this
      // instruction, but only for one operand, save it. If we reach the
      // other operand through another assumption later, then we may
      // change the alignment at that point.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
          MTI->getSource(), SE);

        DenseMap<MemTransferInst *, unsigned>::iterator DI =
          NewDestAlignments.find(MTI);
        unsigned AltDestAlignment = (DI == NewDestAlignments.end()) ?
                                    0 : DI->second;

        DenseMap<MemTransferInst *, unsigned>::iterator SI =
          NewSrcAlignments.find(MTI);
        unsigned AltSrcAlignment = (SI == NewSrcAlignments.end()) ?
                                   0 : SI->second;

        DEBUG(dbgs() << "\tmem trans: " << NewDestAlignment << " " <<
                        AltDestAlignment << " " << NewSrcAlignment <<
                        " " << AltSrcAlignment << "\n");

        // Of these four alignments, pick the largest possible...
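        // ...that also holds for the other operand: a MemTransferInst here
        // carries a single alignment that must be valid for both the source
        // and the destination, so a candidate is admitted only if the other
        // side has at least that much alignment available.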
        unsigned NewAlignment = 0;
        if (NewDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment))
          NewAlignment = std::max(NewAlignment, NewDestAlignment);
        if (AltDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment))
          NewAlignment = std::max(NewAlignment, AltDestAlignment);
        if (NewSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment))
          NewAlignment = std::max(NewAlignment, NewSrcAlignment);
        if (AltSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment))
          NewAlignment = std::max(NewAlignment, AltSrcAlignment);

        if (NewAlignment > MI->getAlignment()) {
          MI->setAlignment(ConstantInt::get(Type::getInt32Ty(
            MI->getParent()->getContext()), NewAlignment));
          ++NumMemIntAlignChanged;
        }

        NewDestAlignments.insert(std::make_pair(MTI, NewDestAlignment));
        NewSrcAlignments.insert(std::make_pair(MTI, NewSrcAlignment));
      } else if (NewDestAlignment > MI->getAlignment()) {
        assert((!isa<MemIntrinsic>(MI) || isa<MemSetInst>(MI)) &&
               "Unknown memory intrinsic");

        MI->setAlignment(ConstantInt::get(Type::getInt32Ty(
          MI->getParent()->getContext()), NewDestAlignment));
        ++NumMemIntAlignChanged;
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    for (User *UJ : J->users()) {
      Instruction *K = cast<Instruction>(UJ);
      if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
        WorkList.push_back(K);
    }
  }

  return true;
}

bool AlignmentFromAssumptions::runOnFunction(Function &F) {
  bool Changed = false;
  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  SE = &getAnalysis<ScalarEvolution>();
  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  NewDestAlignments.clear();
  NewSrcAlignments.clear();
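
  // Walk every llvm.assume cached for this function; an entry may be null
  // if the corresponding call was deleted after it was cached.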
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH)
      Changed |= processAssumption(cast<CallInst>(AssumeVH));

  return Changed;
}