scalar_analysis_simplification.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/scalar_analysis.h"
  15. #include <functional>
  16. #include <map>
  17. #include <memory>
  18. #include <set>
  19. #include <unordered_set>
  20. #include <utility>
  21. #include <vector>
  22. // Simplifies scalar analysis DAGs.
  23. //
  24. // 1. Given a node passed to SimplifyExpression we first simplify the graph by
  25. // calling SimplifyPolynomial. This groups like nodes following basic arithmetic
  26. // rules, so multiple adds of the same load instruction could be grouped into a
  27. // single multiply of that instruction. SimplifyPolynomial will traverse the DAG
  28. // and build up an accumulator buffer for each class of instruction it finds.
  29. // For example take the loop:
  30. // for (i=0; i<N; i++) { i+B+23+4+B+C; }
  31. // In this example the expression "i+B+23+4+B+C" has four classes of
  32. // instruction, induction variable i, the two value unknowns B and C, and the
  33. // constants. The accumulator buffer is then used to rebuild the graph using
  34. // the accumulation of each type. This example would then be folded into
  35. // i+2*B+C+27.
  36. //
  37. // This new graph contains a single add node (or if only one type found then
  38. // just that node) with each of the like terms (or multiplication node) as a
  39. // child.
  40. //
  41. // 2. FoldRecurrentAddExpressions is then called on this new DAG. This will take
  42. // RecurrentAddExpressions which are with respect to the same loop and fold them
  43. // into a single new RecurrentAddExpression with respect to that same loop. An
  44. // expression can have multiple RecurrentAddExpression's with respect to
  45. // different loops in the case of nested loops. These expressions cannot be
  46. // folded further. For example:
  47. //
  48. // for (i=0; i<N;i++) for(j=0,k=1; j<N;++j,++k)
  49. //
  50. // The 'j' and 'k' are RecurrentAddExpression with respect to the second loop
  51. // and 'i' to the first. If 'j' and 'k' are used in an expression together then
  52. // they will be folded into a new RecurrentAddExpression with respect to the
  53. // second loop in that expression.
  54. //
  55. //
  56. // 3. If the DAG now only contains a single RecurrentAddExpression we can now
  57. // perform a final optimization SimplifyRecurrentAddExpression. This will
  58. // transform the entire DAG into a RecurrentAddExpression. Additions to the
  59. // RecurrentAddExpression are added to the offset field and multiplications to
  60. // the coefficient.
  61. //
  62. namespace spvtools {
  63. namespace opt {
  64. // Implementation of the functions which are used to simplify the graph. Graphs
  65. // of unknowns, multiplies, additions, and constants can be turned into a linear
  66. // add node with each term as a child. For instance a large graph built from, X
  67. // + X*2 + Y - Y*3 + 4 - 1, would become a single add expression with the
  68. // children X*3, -Y*2, and the constant 3. Graphs containing a recurrent
  69. // expression will be simplified to represent the entire graph around a single
  70. // recurrent expression. So for an induction variable (i=0, i++) if you add 1 to
  71. // i in an expression we can rewrite the graph of that expression to be a single
  72. // recurrent expression of (i=1,i++).
  73. class SENodeSimplifyImpl {
  74. public:
  75. SENodeSimplifyImpl(ScalarEvolutionAnalysis* analysis,
  76. SENode* node_to_simplify)
  77. : analysis_(*analysis),
  78. node_(node_to_simplify),
  79. constant_accumulator_(0) {}
  80. // Return the result of the simplification.
  81. SENode* Simplify();
  82. private:
  83. // Recursively descend through the graph to build up the accumulator objects
  84. // which are used to flatten the graph. |child| is the node currenty being
  85. // traversed and the |negation| flag is used to signify that this operation
  86. // was preceded by a unary negative operation and as such the result should be
  87. // negated.
  88. void GatherAccumulatorsFromChildNodes(SENode* new_node, SENode* child,
  89. bool negation);
  90. // Given a |multiply| node add to the accumulators for the term type within
  91. // the |multiply| expression. Will return true if the accumulators could be
  92. // calculated successfully. If the |multiply| is in any form other than
  93. // unknown*constant then we return false. |negation| signifies that the
  94. // operation was preceded by a unary negative.
  95. bool AccumulatorsFromMultiply(SENode* multiply, bool negation);
  96. SERecurrentNode* UpdateCoefficient(SERecurrentNode* recurrent,
  97. int64_t coefficient_update) const;
  98. // If the graph contains a recurrent expression, ie, an expression with the
  99. // loop iterations as a term in the expression, then the whole expression
  100. // can be rewritten to be a recurrent expression.
  101. SENode* SimplifyRecurrentAddExpression(SERecurrentNode* node);
  102. // Simplify the whole graph by linking like terms together in a single flat
  103. // add node. So X*2 + Y -Y + 3 +6 would become X*2 + 9. Where X and Y are a
  104. // ValueUnknown node (i.e, a load) or a recurrent expression.
  105. SENode* SimplifyPolynomial();
  106. // Each recurrent expression is an expression with respect to a specific loop.
  107. // If we have two different recurrent terms with respect to the same loop in a
  108. // single expression then we can fold those terms into a single new term.
  109. // For instance:
  110. //
  111. // induction i = 0, i++
  112. // temp = i*10
  113. // array[i+temp]
  114. //
  115. // We can fold the i + temp into a single expression. Rec(0,1) + Rec(0,10) can
  116. // become Rec(0,11).
  117. SENode* FoldRecurrentAddExpressions(SENode*);
  118. // We can eliminate recurrent expressions which have a coefficient of zero by
  119. // replacing them with their offset value. We are able to do this because a
  120. // recurrent expression represents the equation coefficient*iterations +
  121. // offset.
  122. SENode* EliminateZeroCoefficientRecurrents(SENode* node);
  123. // A reference the the analysis which requested the simplification.
  124. ScalarEvolutionAnalysis& analysis_;
  125. // The node being simplified.
  126. SENode* node_;
  127. // An accumulator of the net result of all the constant operations performed
  128. // in a graph.
  129. int64_t constant_accumulator_;
  130. // An accumulator for each of the non constant terms in the graph.
  131. std::map<SENode*, int64_t> accumulators_;
  132. };
  133. // From a |multiply| build up the accumulator objects.
  134. bool SENodeSimplifyImpl::AccumulatorsFromMultiply(SENode* multiply,
  135. bool negation) {
  136. if (multiply->GetChildren().size() != 2 ||
  137. multiply->GetType() != SENode::Multiply)
  138. return false;
  139. SENode* operand_1 = multiply->GetChild(0);
  140. SENode* operand_2 = multiply->GetChild(1);
  141. SENode* value_unknown = nullptr;
  142. SENode* constant = nullptr;
  143. // Work out which operand is the unknown value.
  144. if (operand_1->GetType() == SENode::ValueUnknown ||
  145. operand_1->GetType() == SENode::RecurrentAddExpr)
  146. value_unknown = operand_1;
  147. else if (operand_2->GetType() == SENode::ValueUnknown ||
  148. operand_2->GetType() == SENode::RecurrentAddExpr)
  149. value_unknown = operand_2;
  150. // Work out which operand is the constant coefficient.
  151. if (operand_1->GetType() == SENode::Constant)
  152. constant = operand_1;
  153. else if (operand_2->GetType() == SENode::Constant)
  154. constant = operand_2;
  155. // If the expression is not a variable multiplied by a constant coefficient,
  156. // exit out.
  157. if (!(value_unknown && constant)) {
  158. return false;
  159. }
  160. int64_t sign = negation ? -1 : 1;
  161. auto iterator = accumulators_.find(value_unknown);
  162. int64_t new_value = constant->AsSEConstantNode()->FoldToSingleValue() * sign;
  163. // Add the result of the multiplication to the accumulators.
  164. if (iterator != accumulators_.end()) {
  165. (*iterator).second += new_value;
  166. } else {
  167. accumulators_.insert({value_unknown, new_value});
  168. }
  169. return true;
  170. }
  171. SENode* SENodeSimplifyImpl::Simplify() {
  172. // We only handle graphs with an addition, multiplication, or negation, at the
  173. // root.
  174. if (node_->GetType() != SENode::Add && node_->GetType() != SENode::Multiply &&
  175. node_->GetType() != SENode::Negative)
  176. return node_;
  177. SENode* simplified_polynomial = SimplifyPolynomial();
  178. SERecurrentNode* recurrent_expr = nullptr;
  179. node_ = simplified_polynomial;
  180. // Fold recurrent expressions which are with respect to the same loop into a
  181. // single recurrent expression.
  182. simplified_polynomial = FoldRecurrentAddExpressions(simplified_polynomial);
  183. simplified_polynomial =
  184. EliminateZeroCoefficientRecurrents(simplified_polynomial);
  185. // Traverse the immediate children of the new node to find the recurrent
  186. // expression. If there is more than one there is nothing further we can do.
  187. for (SENode* child : simplified_polynomial->GetChildren()) {
  188. if (child->GetType() == SENode::RecurrentAddExpr) {
  189. recurrent_expr = child->AsSERecurrentNode();
  190. }
  191. }
  192. // We need to count the number of unique recurrent expressions in the DAG to
  193. // ensure there is only one.
  194. for (auto child_iterator = simplified_polynomial->graph_begin();
  195. child_iterator != simplified_polynomial->graph_end(); ++child_iterator) {
  196. if (child_iterator->GetType() == SENode::RecurrentAddExpr &&
  197. recurrent_expr != child_iterator->AsSERecurrentNode()) {
  198. return simplified_polynomial;
  199. }
  200. }
  201. if (recurrent_expr) {
  202. return SimplifyRecurrentAddExpression(recurrent_expr);
  203. }
  204. return simplified_polynomial;
  205. }
  206. // Traverse the graph to build up the accumulator objects.
  207. void SENodeSimplifyImpl::GatherAccumulatorsFromChildNodes(SENode* new_node,
  208. SENode* child,
  209. bool negation) {
  210. int32_t sign = negation ? -1 : 1;
  211. if (child->GetType() == SENode::Constant) {
  212. // Collect all the constants and add them together.
  213. constant_accumulator_ +=
  214. child->AsSEConstantNode()->FoldToSingleValue() * sign;
  215. } else if (child->GetType() == SENode::ValueUnknown ||
  216. child->GetType() == SENode::RecurrentAddExpr) {
  217. // To rebuild the graph of X+X+X*2 into 4*X we count the occurrences of X
  218. // and create a new node of count*X after. X can either be a ValueUnknown or
  219. // a RecurrentAddExpr. The count for each X is stored in the accumulators_
  220. // map.
  221. auto iterator = accumulators_.find(child);
  222. // If we've encountered this term before add to the accumulator for it.
  223. if (iterator == accumulators_.end())
  224. accumulators_.insert({child, sign});
  225. else
  226. iterator->second += sign;
  227. } else if (child->GetType() == SENode::Multiply) {
  228. if (!AccumulatorsFromMultiply(child, negation)) {
  229. new_node->AddChild(child);
  230. }
  231. } else if (child->GetType() == SENode::Add) {
  232. for (SENode* next_child : *child) {
  233. GatherAccumulatorsFromChildNodes(new_node, next_child, negation);
  234. }
  235. } else if (child->GetType() == SENode::Negative) {
  236. SENode* negated_node = child->GetChild(0);
  237. GatherAccumulatorsFromChildNodes(new_node, negated_node, !negation);
  238. } else {
  239. // If we can't work out how to fold the expression just add it back into
  240. // the graph.
  241. new_node->AddChild(child);
  242. }
  243. }
  244. SERecurrentNode* SENodeSimplifyImpl::UpdateCoefficient(
  245. SERecurrentNode* recurrent, int64_t coefficient_update) const {
  246. std::unique_ptr<SERecurrentNode> new_recurrent_node{new SERecurrentNode(
  247. recurrent->GetParentAnalysis(), recurrent->GetLoop())};
  248. SENode* new_coefficient = analysis_.CreateMultiplyNode(
  249. recurrent->GetCoefficient(),
  250. analysis_.CreateConstant(coefficient_update));
  251. // See if the node can be simplified.
  252. SENode* simplified = analysis_.SimplifyExpression(new_coefficient);
  253. if (simplified->GetType() != SENode::CanNotCompute)
  254. new_coefficient = simplified;
  255. if (coefficient_update < 0) {
  256. new_recurrent_node->AddOffset(
  257. analysis_.CreateNegation(recurrent->GetOffset()));
  258. } else {
  259. new_recurrent_node->AddOffset(recurrent->GetOffset());
  260. }
  261. new_recurrent_node->AddCoefficient(new_coefficient);
  262. return analysis_.GetCachedOrAdd(std::move(new_recurrent_node))
  263. ->AsSERecurrentNode();
  264. }
  265. // Simplify all the terms in the polynomial function.
  266. SENode* SENodeSimplifyImpl::SimplifyPolynomial() {
  267. std::unique_ptr<SENode> new_add{new SEAddNode(node_->GetParentAnalysis())};
  268. // Traverse the graph and gather the accumulators from it.
  269. GatherAccumulatorsFromChildNodes(new_add.get(), node_, false);
  270. // Fold all the constants into a single constant node.
  271. if (constant_accumulator_ != 0) {
  272. new_add->AddChild(analysis_.CreateConstant(constant_accumulator_));
  273. }
  274. for (auto& pair : accumulators_) {
  275. SENode* term = pair.first;
  276. int64_t count = pair.second;
  277. // We can eliminate the term completely.
  278. if (count == 0) continue;
  279. if (count == 1) {
  280. new_add->AddChild(term);
  281. } else if (count == -1 && term->GetType() != SENode::RecurrentAddExpr) {
  282. // If the count is -1 we can just add a negative version of that node,
  283. // unless it is a recurrent expression as we would rather the negative
  284. // goes on the recurrent expressions children. This makes it easier to
  285. // work with in other places.
  286. new_add->AddChild(analysis_.CreateNegation(term));
  287. } else {
  288. // Output value unknown terms as count*term and output recurrent
  289. // expression terms as rec(offset, coefficient + count) offset and
  290. // coefficient are the same as in the original expression.
  291. if (term->GetType() == SENode::ValueUnknown) {
  292. SENode* count_as_constant = analysis_.CreateConstant(count);
  293. new_add->AddChild(
  294. analysis_.CreateMultiplyNode(count_as_constant, term));
  295. } else {
  296. assert(term->GetType() == SENode::RecurrentAddExpr &&
  297. "We only handle value unknowns or recurrent expressions");
  298. // Create a new recurrent expression by adding the count to the
  299. // coefficient of the old one.
  300. new_add->AddChild(UpdateCoefficient(term->AsSERecurrentNode(), count));
  301. }
  302. }
  303. }
  304. // If there is only one term in the addition left just return that term.
  305. if (new_add->GetChildren().size() == 1) {
  306. return new_add->GetChild(0);
  307. }
  308. // If there are no terms left in the addition just return 0.
  309. if (new_add->GetChildren().size() == 0) {
  310. return analysis_.CreateConstant(0);
  311. }
  312. return analysis_.GetCachedOrAdd(std::move(new_add));
  313. }
  314. SENode* SENodeSimplifyImpl::FoldRecurrentAddExpressions(SENode* root) {
  315. std::unique_ptr<SEAddNode> new_node{new SEAddNode(&analysis_)};
  316. // A mapping of loops to the list of recurrent expressions which are with
  317. // respect to those loops.
  318. std::map<const Loop*, std::vector<std::pair<SERecurrentNode*, bool>>>
  319. loops_to_recurrent{};
  320. bool has_multiple_same_loop_recurrent_terms = false;
  321. for (SENode* child : *root) {
  322. bool negation = false;
  323. if (child->GetType() == SENode::Negative) {
  324. child = child->GetChild(0);
  325. negation = true;
  326. }
  327. if (child->GetType() == SENode::RecurrentAddExpr) {
  328. const Loop* loop = child->AsSERecurrentNode()->GetLoop();
  329. SERecurrentNode* rec = child->AsSERecurrentNode();
  330. if (loops_to_recurrent.find(loop) == loops_to_recurrent.end()) {
  331. loops_to_recurrent[loop] = {std::make_pair(rec, negation)};
  332. } else {
  333. loops_to_recurrent[loop].push_back(std::make_pair(rec, negation));
  334. has_multiple_same_loop_recurrent_terms = true;
  335. }
  336. } else {
  337. new_node->AddChild(child);
  338. }
  339. }
  340. if (!has_multiple_same_loop_recurrent_terms) return root;
  341. for (auto pair : loops_to_recurrent) {
  342. std::vector<std::pair<SERecurrentNode*, bool>>& recurrent_expressions =
  343. pair.second;
  344. const Loop* loop = pair.first;
  345. std::unique_ptr<SENode> new_coefficient{new SEAddNode(&analysis_)};
  346. std::unique_ptr<SENode> new_offset{new SEAddNode(&analysis_)};
  347. for (auto node_pair : recurrent_expressions) {
  348. SERecurrentNode* node = node_pair.first;
  349. bool negative = node_pair.second;
  350. if (!negative) {
  351. new_coefficient->AddChild(node->GetCoefficient());
  352. new_offset->AddChild(node->GetOffset());
  353. } else {
  354. new_coefficient->AddChild(
  355. analysis_.CreateNegation(node->GetCoefficient()));
  356. new_offset->AddChild(analysis_.CreateNegation(node->GetOffset()));
  357. }
  358. }
  359. std::unique_ptr<SERecurrentNode> new_recurrent{
  360. new SERecurrentNode(&analysis_, loop)};
  361. SENode* new_coefficient_simplified =
  362. analysis_.SimplifyExpression(new_coefficient.get());
  363. SENode* new_offset_simplified =
  364. analysis_.SimplifyExpression(new_offset.get());
  365. if (new_coefficient_simplified->GetType() == SENode::Constant &&
  366. new_coefficient_simplified->AsSEConstantNode()->FoldToSingleValue() ==
  367. 0) {
  368. return new_offset_simplified;
  369. }
  370. new_recurrent->AddCoefficient(new_coefficient_simplified);
  371. new_recurrent->AddOffset(new_offset_simplified);
  372. new_node->AddChild(analysis_.GetCachedOrAdd(std::move(new_recurrent)));
  373. }
  374. // If we only have one child in the add just return that.
  375. if (new_node->GetChildren().size() == 1) {
  376. return new_node->GetChild(0);
  377. }
  378. return analysis_.GetCachedOrAdd(std::move(new_node));
  379. }
  380. SENode* SENodeSimplifyImpl::EliminateZeroCoefficientRecurrents(SENode* node) {
  381. if (node->GetType() != SENode::Add) return node;
  382. bool has_change = false;
  383. std::vector<SENode*> new_children{};
  384. for (SENode* child : *node) {
  385. if (child->GetType() == SENode::RecurrentAddExpr) {
  386. SENode* coefficient = child->AsSERecurrentNode()->GetCoefficient();
  387. // If coefficient is zero then we can eliminate the recurrent expression
  388. // entirely and just return the offset as the recurrent expression is
  389. // representing the equation coefficient*iterations + offset.
  390. if (coefficient->GetType() == SENode::Constant &&
  391. coefficient->AsSEConstantNode()->FoldToSingleValue() == 0) {
  392. new_children.push_back(child->AsSERecurrentNode()->GetOffset());
  393. has_change = true;
  394. } else {
  395. new_children.push_back(child);
  396. }
  397. } else {
  398. new_children.push_back(child);
  399. }
  400. }
  401. if (!has_change) return node;
  402. std::unique_ptr<SENode> new_add{new SEAddNode(node_->GetParentAnalysis())};
  403. for (SENode* child : new_children) {
  404. new_add->AddChild(child);
  405. }
  406. return analysis_.GetCachedOrAdd(std::move(new_add));
  407. }
  408. SENode* SENodeSimplifyImpl::SimplifyRecurrentAddExpression(
  409. SERecurrentNode* recurrent_expr) {
  410. const std::vector<SENode*>& children = node_->GetChildren();
  411. std::unique_ptr<SERecurrentNode> recurrent_node{new SERecurrentNode(
  412. recurrent_expr->GetParentAnalysis(), recurrent_expr->GetLoop())};
  413. // Create and simplify the new offset node.
  414. std::unique_ptr<SENode> new_offset{
  415. new SEAddNode(recurrent_expr->GetParentAnalysis())};
  416. new_offset->AddChild(recurrent_expr->GetOffset());
  417. for (SENode* child : children) {
  418. if (child->GetType() != SENode::RecurrentAddExpr) {
  419. new_offset->AddChild(child);
  420. }
  421. }
  422. // Simplify the new offset.
  423. SENode* simplified_child = analysis_.SimplifyExpression(new_offset.get());
  424. // If the child can be simplified, add the simplified form otherwise, add it
  425. // via the usual caching mechanism.
  426. if (simplified_child->GetType() != SENode::CanNotCompute) {
  427. recurrent_node->AddOffset(simplified_child);
  428. } else {
  429. recurrent_expr->AddOffset(analysis_.GetCachedOrAdd(std::move(new_offset)));
  430. }
  431. recurrent_node->AddCoefficient(recurrent_expr->GetCoefficient());
  432. return analysis_.GetCachedOrAdd(std::move(recurrent_node));
  433. }
  434. /*
  435. * Scalar Analysis simplification public methods.
  436. */
  437. SENode* ScalarEvolutionAnalysis::SimplifyExpression(SENode* node) {
  438. SENodeSimplifyImpl impl{this, node};
  439. return impl.Simplify();
  440. }
  441. } // namespace opt
  442. } // namespace spvtools