CostModel.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  //===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
  //
  //                     The LLVM Compiler Infrastructure
  //
  // This file is distributed under the University of Illinois Open Source
  // License. See LICENSE.TXT for details.
  //
  //===----------------------------------------------------------------------===//
  //
  // This file defines the cost model analysis. It provides a very basic cost
  // estimation for LLVM-IR. This analysis uses the services of the codegen
  // to approximate the cost of any IR instruction when lowered to machine
  // instructions. The cost results are unitless and the cost number represents
  // the throughput of the machine assuming that all loads hit the cache, all
  // branches are predicted, etc. The cost numbers can be added in order to
  // compare two or more transformation alternatives.
  //
  //===----------------------------------------------------------------------===//
  19. #include "llvm/ADT/STLExtras.h"
  20. #include "llvm/Analysis/Passes.h"
  21. #include "llvm/Analysis/TargetTransformInfo.h"
  22. #include "llvm/IR/Function.h"
  23. #include "llvm/IR/Instructions.h"
  24. #include "llvm/IR/IntrinsicInst.h"
  25. #include "llvm/IR/Value.h"
  26. #include "llvm/Pass.h"
  27. #include "llvm/Support/CommandLine.h"
  28. #include "llvm/Support/Debug.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. using namespace llvm;
  31. #define CM_NAME "cost-model"
  32. #define DEBUG_TYPE CM_NAME
  33. static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
  34. cl::Hidden,
  35. cl::desc("Recognize reduction patterns."));
  36. namespace {
  37. class CostModelAnalysis : public FunctionPass {
  38. public:
  39. static char ID; // Class identification, replacement for typeinfo
  40. CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) {
  41. initializeCostModelAnalysisPass(
  42. *PassRegistry::getPassRegistry());
  43. }
  44. /// Returns the expected cost of the instruction.
  45. /// Returns -1 if the cost is unknown.
  46. /// Note, this method does not cache the cost calculation and it
  47. /// can be expensive in some cases.
  48. unsigned getInstructionCost(const Instruction *I) const;
  49. private:
  50. void getAnalysisUsage(AnalysisUsage &AU) const override;
  51. bool runOnFunction(Function &F) override;
  52. void print(raw_ostream &OS, const Module*) const override;
  53. /// The function that we analyze.
  54. Function *F;
  55. /// Target information.
  56. const TargetTransformInfo *TTI;
  57. };
  58. } // End of anonymous namespace
  59. // Register this pass.
  60. char CostModelAnalysis::ID = 0;
  61. static const char cm_name[] = "Cost Model Analysis";
  62. INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
  63. INITIALIZE_PASS_END (CostModelAnalysis, CM_NAME, cm_name, false, true)
  64. FunctionPass *llvm::createCostModelAnalysisPass() {
  65. return new CostModelAnalysis();
  66. }
  67. void
  68. CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  69. AU.setPreservesAll();
  70. }
  71. bool
  72. CostModelAnalysis::runOnFunction(Function &F) {
  73. this->F = &F;
  74. auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
  75. TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;
  76. return false;
  77. }
  78. static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) {
  79. for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
  80. if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i))
  81. return false;
  82. return true;
  83. }
  84. static bool isAlternateVectorMask(SmallVectorImpl<int> &Mask) {
  85. bool isAlternate = true;
  86. unsigned MaskSize = Mask.size();
  87. // Example: shufflevector A, B, <0,5,2,7>
  88. for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
  89. if (Mask[i] < 0)
  90. continue;
  91. isAlternate = Mask[i] == (int)((i & 1) ? MaskSize + i : i);
  92. }
  93. if (isAlternate)
  94. return true;
  95. isAlternate = true;
  96. // Example: shufflevector A, B, <4,1,6,3>
  97. for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
  98. if (Mask[i] < 0)
  99. continue;
  100. isAlternate = Mask[i] == (int)((i & 1) ? i : MaskSize + i);
  101. }
  102. return isAlternate;
  103. }
  104. static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) {
  105. TargetTransformInfo::OperandValueKind OpInfo =
  106. TargetTransformInfo::OK_AnyValue;
  107. // Check for a splat of a constant or for a non uniform vector of constants.
  108. if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
  109. OpInfo = TargetTransformInfo::OK_NonUniformConstantValue;
  110. if (cast<Constant>(V)->getSplatValue() != nullptr)
  111. OpInfo = TargetTransformInfo::OK_UniformConstantValue;
  112. }
  113. return OpInfo;
  114. }
  115. static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
  116. unsigned Level) {
  117. // We don't need a shuffle if we just want to have element 0 in position 0 of
  118. // the vector.
  119. if (!SI && Level == 0 && IsLeft)
  120. return true;
  121. else if (!SI)
  122. return false;
  123. SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1);
  124. // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether
  125. // we look at the left or right side.
  126. for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2)
  127. Mask[i] = val;
  128. SmallVector<int, 16> ActualMask = SI->getShuffleMask();
  129. if (Mask != ActualMask)
  130. return false;
  131. return true;
  132. }
  133. static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp,
  134. unsigned Level, unsigned NumLevels) {
  135. // Match one level of pairwise operations.
  136. // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
  137. // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
  138. // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
  139. // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
  140. // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
  141. if (BinOp == nullptr)
  142. return false;
  143. assert(BinOp->getType()->isVectorTy() && "Expecting a vector type");
  144. unsigned Opcode = BinOp->getOpcode();
  145. Value *L = BinOp->getOperand(0);
  146. Value *R = BinOp->getOperand(1);
  147. ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L);
  148. if (!LS && Level)
  149. return false;
  150. ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R);
  151. if (!RS && Level)
  152. return false;
  153. // On level 0 we can omit one shufflevector instruction.
  154. if (!Level && !RS && !LS)
  155. return false;
  156. // Shuffle inputs must match.
  157. Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
  158. Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
  159. Value *NextLevelOp = nullptr;
  160. if (NextLevelOpR && NextLevelOpL) {
  161. // If we have two shuffles their operands must match.
  162. if (NextLevelOpL != NextLevelOpR)
  163. return false;
  164. NextLevelOp = NextLevelOpL;
  165. } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
  166. // On the first level we can omit the shufflevector <0, undef,...>. So the
  167. // input to the other shufflevector <1, undef> must match with one of the
  168. // inputs to the current binary operation.
  169. // Example:
  170. // %NextLevelOpL = shufflevector %R, <1, undef ...>
  171. // %BinOp = fadd %NextLevelOpL, %R
  172. if (NextLevelOpL && NextLevelOpL != R)
  173. return false;
  174. else if (NextLevelOpR && NextLevelOpR != L)
  175. return false;
  176. NextLevelOp = NextLevelOpL ? R : L;
  177. } else
  178. return false;
  179. // Check that the next levels binary operation exists and matches with the
  180. // current one.
  181. BinaryOperator *NextLevelBinOp = nullptr;
  182. if (Level + 1 != NumLevels) {
  183. if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp)))
  184. return false;
  185. else if (NextLevelBinOp->getOpcode() != Opcode)
  186. return false;
  187. }
  188. // Shuffle mask for pairwise operation must match.
  189. if (matchPairwiseShuffleMask(LS, true, Level)) {
  190. if (!matchPairwiseShuffleMask(RS, false, Level))
  191. return false;
  192. } else if (matchPairwiseShuffleMask(RS, true, Level)) {
  193. if (!matchPairwiseShuffleMask(LS, false, Level))
  194. return false;
  195. } else
  196. return false;
  197. if (++Level == NumLevels)
  198. return true;
  199. // Match next level.
  200. return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels);
  201. }
  202. static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
  203. unsigned &Opcode, Type *&Ty) {
  204. if (!EnableReduxCost)
  205. return false;
  206. // Need to extract the first element.
  207. ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
  208. unsigned Idx = ~0u;
  209. if (CI)
  210. Idx = CI->getZExtValue();
  211. if (Idx != 0)
  212. return false;
  213. BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
  214. if (!RdxStart)
  215. return false;
  216. Type *VecTy = ReduxRoot->getOperand(0)->getType();
  217. unsigned NumVecElems = VecTy->getVectorNumElements();
  218. if (!isPowerOf2_32(NumVecElems))
  219. return false;
  220. // We look for a sequence of shuffle,shuffle,add triples like the following
  221. // that builds a pairwise reduction tree.
  222. //
  223. // (X0, X1, X2, X3)
  224. // (X0 + X1, X2 + X3, undef, undef)
  225. // ((X0 + X1) + (X2 + X3), undef, undef, undef)
  226. //
  227. // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
  228. // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
  229. // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
  230. // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
  231. // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
  232. // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
  233. // <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
  234. // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
  235. // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
  236. // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
  237. // %r = extractelement <4 x float> %bin.rdx8, i32 0
  238. if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)))
  239. return false;
  240. Opcode = RdxStart->getOpcode();
  241. Ty = VecTy;
  242. return true;
  243. }
  244. static std::pair<Value *, ShuffleVectorInst *>
  245. getShuffleAndOtherOprd(BinaryOperator *B) {
  246. Value *L = B->getOperand(0);
  247. Value *R = B->getOperand(1);
  248. ShuffleVectorInst *S = nullptr;
  249. if ((S = dyn_cast<ShuffleVectorInst>(L)))
  250. return std::make_pair(R, S);
  251. S = dyn_cast<ShuffleVectorInst>(R);
  252. return std::make_pair(L, S);
  253. }
  254. static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
  255. unsigned &Opcode, Type *&Ty) {
  256. if (!EnableReduxCost)
  257. return false;
  258. // Need to extract the first element.
  259. ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
  260. unsigned Idx = ~0u;
  261. if (CI)
  262. Idx = CI->getZExtValue();
  263. if (Idx != 0)
  264. return false;
  265. BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
  266. if (!RdxStart)
  267. return false;
  268. unsigned RdxOpcode = RdxStart->getOpcode();
  269. Type *VecTy = ReduxRoot->getOperand(0)->getType();
  270. unsigned NumVecElems = VecTy->getVectorNumElements();
  271. if (!isPowerOf2_32(NumVecElems))
  272. return false;
  273. // We look for a sequence of shuffles and adds like the following matching one
  274. // fadd, shuffle vector pair at a time.
  275. //
  276. // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
  277. // <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
  278. // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
  279. // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
  280. // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
  281. // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
  282. // %r = extractelement <4 x float> %bin.rdx8, i32 0
  283. unsigned MaskStart = 1;
  284. Value *RdxOp = RdxStart;
  285. SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
  286. unsigned NumVecElemsRemain = NumVecElems;
  287. while (NumVecElemsRemain - 1) {
  288. // Check for the right reduction operation.
  289. BinaryOperator *BinOp;
  290. if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp)))
  291. return false;
  292. if (BinOp->getOpcode() != RdxOpcode)
  293. return false;
  294. Value *NextRdxOp;
  295. ShuffleVectorInst *Shuffle;
  296. std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp);
  297. // Check the current reduction operation and the shuffle use the same value.
  298. if (Shuffle == nullptr)
  299. return false;
  300. if (Shuffle->getOperand(0) != NextRdxOp)
  301. return false;
  302. // Check that shuffle masks matches.
  303. for (unsigned j = 0; j != MaskStart; ++j)
  304. ShuffleMask[j] = MaskStart + j;
  305. // Fill the rest of the mask with -1 for undef.
  306. std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1);
  307. SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
  308. if (ShuffleMask != Mask)
  309. return false;
  310. RdxOp = NextRdxOp;
  311. NumVecElemsRemain /= 2;
  312. MaskStart *= 2;
  313. }
  314. Opcode = RdxOpcode;
  315. Ty = VecTy;
  316. return true;
  317. }
  318. unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
  319. if (!TTI)
  320. return -1;
  321. switch (I->getOpcode()) {
  322. case Instruction::GetElementPtr:{
  323. Type *ValTy = I->getOperand(0)->getType()->getPointerElementType();
  324. return TTI->getAddressComputationCost(ValTy);
  325. }
  326. case Instruction::Ret:
  327. case Instruction::PHI:
  328. case Instruction::Br: {
  329. return TTI->getCFInstrCost(I->getOpcode());
  330. }
  331. case Instruction::Add:
  332. case Instruction::FAdd:
  333. case Instruction::Sub:
  334. case Instruction::FSub:
  335. case Instruction::Mul:
  336. case Instruction::FMul:
  337. case Instruction::UDiv:
  338. case Instruction::SDiv:
  339. case Instruction::FDiv:
  340. case Instruction::URem:
  341. case Instruction::SRem:
  342. case Instruction::FRem:
  343. case Instruction::Shl:
  344. case Instruction::LShr:
  345. case Instruction::AShr:
  346. case Instruction::And:
  347. case Instruction::Or:
  348. case Instruction::Xor: {
  349. TargetTransformInfo::OperandValueKind Op1VK =
  350. getOperandInfo(I->getOperand(0));
  351. TargetTransformInfo::OperandValueKind Op2VK =
  352. getOperandInfo(I->getOperand(1));
  353. return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK,
  354. Op2VK);
  355. }
  356. case Instruction::Select: {
  357. const SelectInst *SI = cast<SelectInst>(I);
  358. Type *CondTy = SI->getCondition()->getType();
  359. return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
  360. }
  361. case Instruction::ICmp:
  362. case Instruction::FCmp: {
  363. Type *ValTy = I->getOperand(0)->getType();
  364. return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
  365. }
  366. case Instruction::Store: {
  367. const StoreInst *SI = cast<StoreInst>(I);
  368. Type *ValTy = SI->getValueOperand()->getType();
  369. return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
  370. SI->getAlignment(),
  371. SI->getPointerAddressSpace());
  372. }
  373. case Instruction::Load: {
  374. const LoadInst *LI = cast<LoadInst>(I);
  375. return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
  376. LI->getAlignment(),
  377. LI->getPointerAddressSpace());
  378. }
  379. case Instruction::ZExt:
  380. case Instruction::SExt:
  381. case Instruction::FPToUI:
  382. case Instruction::FPToSI:
  383. case Instruction::FPExt:
  384. case Instruction::PtrToInt:
  385. case Instruction::IntToPtr:
  386. case Instruction::SIToFP:
  387. case Instruction::UIToFP:
  388. case Instruction::Trunc:
  389. case Instruction::FPTrunc:
  390. case Instruction::BitCast:
  391. case Instruction::AddrSpaceCast: {
  392. Type *SrcTy = I->getOperand(0)->getType();
  393. return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
  394. }
  395. case Instruction::ExtractElement: {
  396. const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
  397. ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
  398. unsigned Idx = -1;
  399. if (CI)
  400. Idx = CI->getZExtValue();
  401. // Try to match a reduction sequence (series of shufflevector and vector
  402. // adds followed by a extractelement).
  403. unsigned ReduxOpCode;
  404. Type *ReduxType;
  405. if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType))
  406. return TTI->getReductionCost(ReduxOpCode, ReduxType, false);
  407. else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType))
  408. return TTI->getReductionCost(ReduxOpCode, ReduxType, true);
  409. return TTI->getVectorInstrCost(I->getOpcode(),
  410. EEI->getOperand(0)->getType(), Idx);
  411. }
  412. case Instruction::InsertElement: {
  413. const InsertElementInst * IE = cast<InsertElementInst>(I);
  414. ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
  415. unsigned Idx = -1;
  416. if (CI)
  417. Idx = CI->getZExtValue();
  418. return TTI->getVectorInstrCost(I->getOpcode(),
  419. IE->getType(), Idx);
  420. }
  421. case Instruction::ShuffleVector: {
  422. const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
  423. Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
  424. unsigned NumVecElems = VecTypOp0->getVectorNumElements();
  425. SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
  426. if (NumVecElems == Mask.size()) {
  427. if (isReverseVectorMask(Mask))
  428. return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0,
  429. 0, nullptr);
  430. if (isAlternateVectorMask(Mask))
  431. return TTI->getShuffleCost(TargetTransformInfo::SK_Alternate,
  432. VecTypOp0, 0, nullptr);
  433. }
  434. return -1;
  435. }
  436. case Instruction::Call:
  437. if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
  438. SmallVector<Type*, 4> Tys;
  439. for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
  440. Tys.push_back(II->getArgOperand(J)->getType());
  441. return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
  442. Tys);
  443. }
  444. return -1;
  445. default:
  446. // We don't have any information on this instruction.
  447. return -1;
  448. }
  449. }
  450. void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
  451. if (!F)
  452. return;
  453. for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) {
  454. for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) {
  455. Instruction *Inst = it;
  456. unsigned Cost = getInstructionCost(Inst);
  457. if (Cost != (unsigned)-1)
  458. OS << "Cost Model: Found an estimated cost of " << Cost;
  459. else
  460. OS << "Cost Model: Unknown cost";
  461. OS << " for instruction: "<< *Inst << "\n";
  462. }
  463. }
  464. }