scalar_analysis.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef SOURCE_OPT_SCALAR_ANALYSIS_H_
  15. #define SOURCE_OPT_SCALAR_ANALYSIS_H_
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/basic_block.h"
#include "source/opt/instruction.h"
#include "source/opt/scalar_analysis_nodes.h"
  26. namespace spvtools {
  27. namespace opt {
  28. class IRContext;
  29. class Loop;
  30. // Manager for the Scalar Evolution analysis. Creates and maintains a DAG of
  31. // scalar operations generated from analysing the use def graph from incoming
  32. // instructions. Each node is hashed as it is added so like node (for instance,
  33. // two induction variables i=0,i++ and j=0,j++) become the same node. After
  34. // creating a DAG with AnalyzeInstruction it can the be simplified into a more
  35. // usable form with SimplifyExpression.
  36. class ScalarEvolutionAnalysis {
  37. public:
  38. explicit ScalarEvolutionAnalysis(IRContext* context);
  39. // Create a unary negative node on |operand|.
  40. SENode* CreateNegation(SENode* operand);
  41. // Creates a subtraction between the two operands by adding |operand_1| to the
  42. // negation of |operand_2|.
  43. SENode* CreateSubtraction(SENode* operand_1, SENode* operand_2);
  44. // Create an addition node between two operands. The |simplify| when set will
  45. // allow the function to return an SEConstant instead of an addition if the
  46. // two input operands are also constant.
  47. SENode* CreateAddNode(SENode* operand_1, SENode* operand_2);
  48. // Create a multiply node between two operands.
  49. SENode* CreateMultiplyNode(SENode* operand_1, SENode* operand_2);
  50. // Create a node representing a constant integer.
  51. SENode* CreateConstant(int64_t integer);
  52. // Create a value unknown node, such as a load.
  53. SENode* CreateValueUnknownNode(const Instruction* inst);
  54. // Create a CantComputeNode. Used to exit out of analysis.
  55. SENode* CreateCantComputeNode();
  56. // Create a new recurrent node with |offset| and |coefficient|, with respect
  57. // to |loop|.
  58. SENode* CreateRecurrentExpression(const Loop* loop, SENode* offset,
  59. SENode* coefficient);
  60. // Construct the DAG by traversing use def chain of |inst|.
  61. SENode* AnalyzeInstruction(const Instruction* inst);
  62. // Simplify the |node| by grouping like terms or if contains a recurrent
  63. // expression, rewrite the graph so the whole DAG (from |node| down) is in
  64. // terms of that recurrent expression.
  65. //
  66. // For example.
  67. // Induction variable i=0, i++ would produce Rec(0,1) so i+1 could be
  68. // transformed into Rec(1,1).
  69. //
  70. // X+X*2+Y-Y+34-17 would be transformed into 3*X + 17, where X and Y are
  71. // ValueUnknown nodes (such as a load instruction).
  72. SENode* SimplifyExpression(SENode* node);
  73. // Add |prospective_node| into the cache and return a raw pointer to it. If
  74. // |prospective_node| is already in the cache just return the raw pointer.
  75. SENode* GetCachedOrAdd(std::unique_ptr<SENode> prospective_node);
  76. // Checks that the graph starting from |node| is invariant to the |loop|.
  77. bool IsLoopInvariant(const Loop* loop, const SENode* node) const;
  78. // Sets |is_gt_zero| to true if |node| represent a value always strictly
  79. // greater than 0. The result of |is_gt_zero| is valid only if the function
  80. // returns true.
  81. bool IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const;
  82. // Sets |is_ge_zero| to true if |node| represent a value greater or equals to
  83. // 0. The result of |is_ge_zero| is valid only if the function returns true.
  84. bool IsAlwaysGreaterOrEqualToZero(SENode* node, bool* is_ge_zero) const;
  85. // Find the recurrent term belonging to |loop| in the graph starting from
  86. // |node| and return the coefficient of that recurrent term. Constant zero
  87. // will be returned if no recurrent could be found. |node| should be in
  88. // simplest form.
  89. SENode* GetCoefficientFromRecurrentTerm(SENode* node, const Loop* loop);
  90. // Return a rebuilt graph starting from |node| with the recurrent expression
  91. // belonging to |loop| being zeroed out. Returned node will be simplified.
  92. SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const Loop* loop);
  93. // Return the recurrent term belonging to |loop| if it appears in the graph
  94. // starting at |node| or null if it doesn't.
  95. SERecurrentNode* GetRecurrentTerm(SENode* node, const Loop* loop);
  96. SENode* UpdateChildNode(SENode* parent, SENode* child, SENode* new_child);
  97. // The loops in |loop_pair| will be considered the same when constructing
  98. // SERecurrentNode objects. This enables analysing dependencies that will be
  99. // created during loop fusion.
  100. void AddLoopsToPretendAreTheSame(
  101. const std::pair<const Loop*, const Loop*>& loop_pair) {
  102. pretend_equal_[std::get<1>(loop_pair)] = std::get<0>(loop_pair);
  103. }
  104. private:
  105. SENode* AnalyzeConstant(const Instruction* inst);
  106. // Handles both addition and subtraction. If the |instruction| is OpISub
  107. // then the resulting node will be op1+(-op2) otherwise if it is OpIAdd then
  108. // the result will be op1+op2. |instruction| must be OpIAdd or OpISub.
  109. SENode* AnalyzeAddOp(const Instruction* instruction);
  110. SENode* AnalyzeMultiplyOp(const Instruction* multiply);
  111. SENode* AnalyzePhiInstruction(const Instruction* phi);
  112. IRContext* context_;
  113. // A map of instructions to SENodes. This is used to track recurrent
  114. // expressions as they are added when analyzing instructions. Recurrent
  115. // expressions come from phi nodes which by nature can include recursion so we
  116. // check if nodes have already been built when analyzing instructions.
  117. std::map<const Instruction*, SENode*> recurrent_node_map_;
  118. // On creation we create and cache the CantCompute node so we not need to
  119. // perform a needless create step.
  120. SENode* cached_cant_compute_;
  121. // Helper functor to allow two unique_ptr to nodes to be compare. Only
  122. // needed
  123. // for the unordered_set implementation.
  124. struct NodePointersEquality {
  125. bool operator()(const std::unique_ptr<SENode>& lhs,
  126. const std::unique_ptr<SENode>& rhs) const {
  127. return *lhs == *rhs;
  128. }
  129. };
  130. // Cache of nodes. All pointers to the nodes are references to the memory
  131. // managed by they set.
  132. std::unordered_set<std::unique_ptr<SENode>, SENodeHash, NodePointersEquality>
  133. node_cache_;
  134. // Loops that should be considered the same for performing analysis for loop
  135. // fusion.
  136. std::map<const Loop*, const Loop*> pretend_equal_;
  137. };
  138. // Wrapping class to manipulate SENode pointer using + - * / operators.
  139. class SExpression {
  140. public:
  141. // Implicit on purpose !
  142. SExpression(SENode* node)
  143. : node_(node->GetParentAnalysis()->SimplifyExpression(node)),
  144. scev_(node->GetParentAnalysis()) {}
  145. inline operator SENode*() const { return node_; }
  146. inline SENode* operator->() const { return node_; }
  147. const SENode& operator*() const { return *node_; }
  148. inline ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() const {
  149. return scev_;
  150. }
  151. inline SExpression operator+(SENode* rhs) const;
  152. template <typename T,
  153. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  154. inline SExpression operator+(T integer) const;
  155. inline SExpression operator+(SExpression rhs) const;
  156. inline SExpression operator-() const;
  157. inline SExpression operator-(SENode* rhs) const;
  158. template <typename T,
  159. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  160. inline SExpression operator-(T integer) const;
  161. inline SExpression operator-(SExpression rhs) const;
  162. inline SExpression operator*(SENode* rhs) const;
  163. template <typename T,
  164. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  165. inline SExpression operator*(T integer) const;
  166. inline SExpression operator*(SExpression rhs) const;
  167. template <typename T,
  168. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  169. inline std::pair<SExpression, int64_t> operator/(T integer) const;
  170. // Try to perform a division. Returns the pair <this.node_ / rhs, division
  171. // remainder>. If it fails to simplify it, the function returns a
  172. // CanNotCompute node.
  173. std::pair<SExpression, int64_t> operator/(SExpression rhs) const;
  174. private:
  175. SENode* node_;
  176. ScalarEvolutionAnalysis* scev_;
  177. };
  178. inline SExpression SExpression::operator+(SENode* rhs) const {
  179. return scev_->CreateAddNode(node_, rhs);
  180. }
  181. template <typename T,
  182. typename std::enable_if<std::is_integral<T>::value, int>::type>
  183. inline SExpression SExpression::operator+(T integer) const {
  184. return *this + scev_->CreateConstant(integer);
  185. }
  186. inline SExpression SExpression::operator+(SExpression rhs) const {
  187. return *this + rhs.node_;
  188. }
  189. inline SExpression SExpression::operator-() const {
  190. return scev_->CreateNegation(node_);
  191. }
  192. inline SExpression SExpression::operator-(SENode* rhs) const {
  193. return *this + scev_->CreateNegation(rhs);
  194. }
  195. template <typename T,
  196. typename std::enable_if<std::is_integral<T>::value, int>::type>
  197. inline SExpression SExpression::operator-(T integer) const {
  198. return *this - scev_->CreateConstant(integer);
  199. }
  200. inline SExpression SExpression::operator-(SExpression rhs) const {
  201. return *this - rhs.node_;
  202. }
  203. inline SExpression SExpression::operator*(SENode* rhs) const {
  204. return scev_->CreateMultiplyNode(node_, rhs);
  205. }
  206. template <typename T,
  207. typename std::enable_if<std::is_integral<T>::value, int>::type>
  208. inline SExpression SExpression::operator*(T integer) const {
  209. return *this * scev_->CreateConstant(integer);
  210. }
  211. inline SExpression SExpression::operator*(SExpression rhs) const {
  212. return *this * rhs.node_;
  213. }
  214. template <typename T,
  215. typename std::enable_if<std::is_integral<T>::value, int>::type>
  216. inline std::pair<SExpression, int64_t> SExpression::operator/(T integer) const {
  217. return *this / scev_->CreateConstant(integer);
  218. }
  219. template <typename T,
  220. typename std::enable_if<std::is_integral<T>::value, int>::type>
  221. inline SExpression operator+(T lhs, SExpression rhs) {
  222. return rhs + lhs;
  223. }
  224. inline SExpression operator+(SENode* lhs, SExpression rhs) { return rhs + lhs; }
  225. template <typename T,
  226. typename std::enable_if<std::is_integral<T>::value, int>::type>
  227. inline SExpression operator-(T lhs, SExpression rhs) {
  228. // NOLINTNEXTLINE(whitespace/braces)
  229. return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} -
  230. rhs;
  231. }
  232. inline SExpression operator-(SENode* lhs, SExpression rhs) {
  233. // NOLINTNEXTLINE(whitespace/braces)
  234. return SExpression{lhs} - rhs;
  235. }
  236. template <typename T,
  237. typename std::enable_if<std::is_integral<T>::value, int>::type>
  238. inline SExpression operator*(T lhs, SExpression rhs) {
  239. return rhs * lhs;
  240. }
  241. inline SExpression operator*(SENode* lhs, SExpression rhs) { return rhs * lhs; }
  242. template <typename T,
  243. typename std::enable_if<std::is_integral<T>::value, int>::type>
  244. inline std::pair<SExpression, int64_t> operator/(T lhs, SExpression rhs) {
  245. // NOLINTNEXTLINE(whitespace/braces)
  246. return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} /
  247. rhs;
  248. }
  249. inline std::pair<SExpression, int64_t> operator/(SENode* lhs, SExpression rhs) {
  250. // NOLINTNEXTLINE(whitespace/braces)
  251. return SExpression{lhs} / rhs;
  252. }
  253. } // namespace opt
  254. } // namespace spvtools
  255. #endif // SOURCE_OPT_SCALAR_ANALYSIS_H_