scalar_analysis_nodes.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
  15. #define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
  #include <algorithm>
  #include <cassert>
  #include <cstdint>
  #include <iosfwd>
  #include <memory>
  #include <string>
  #include <vector>

  #include "source/opt/tree_iterator.h"
  21. namespace spvtools {
  22. namespace opt {
  // Forward declarations. Loop and ScalarEvolutionAnalysis are only referenced
  // through pointers in this header. The SENode subclasses must be declared
  // before SENode because SENode declares a virtual As<Subclass>() cast for each
  // of them.
  class Loop;
  class ScalarEvolutionAnalysis;

  class SEAddNode;
  class SECantCompute;
  class SEConstantNode;
  class SEMultiplyNode;
  class SENegative;
  class SERecurrentNode;
  class SEValueUnknown;
  32. // Abstract class representing a node in the scalar evolution DAG. Each node
  33. // contains a vector of pointers to its children and each subclass of SENode
  34. // implements GetType and an As method to allow casting. SENodes can be hashed
  35. // using the SENodeHash functor. The vector of children is sorted when a node is
  36. // added. This is important as it allows the hash of X+Y to be the same as Y+X.
  37. class SENode {
  38. public:
  39. enum SENodeType {
  40. Constant,
  41. RecurrentAddExpr,
  42. Add,
  43. Multiply,
  44. Negative,
  45. ValueUnknown,
  46. CanNotCompute
  47. };
  48. using ChildContainerType = std::vector<SENode*>;
  49. explicit SENode(ScalarEvolutionAnalysis* parent_analysis)
  50. : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {}
  51. virtual SENodeType GetType() const = 0;
  52. virtual ~SENode() {}
  53. virtual inline void AddChild(SENode* child) {
  54. // If this is a constant node, assert.
  55. if (AsSEConstantNode()) {
  56. assert(false && "Trying to add a child node to a constant!");
  57. }
  58. // Find the first point in the vector where |child| is greater than the node
  59. // currently in the vector.
  60. auto find_first_less_than = [child](const SENode* node) {
  61. return child->unique_id_ <= node->unique_id_;
  62. };
  63. auto position = std::find_if_not(children_.begin(), children_.end(),
  64. find_first_less_than);
  65. // Children are sorted so the hashing and equality operator will be the same
  66. // for a node with the same children. X+Y should be the same as Y+X.
  67. children_.insert(position, child);
  68. }
  69. // Get the type as an std::string. This is used to represent the node in the
  70. // dot output and is used to hash the type as well.
  71. std::string AsString() const;
  72. // Dump the SENode and its immediate children, if |recurse| is true then it
  73. // will recurse through all children to print the DAG starting from this node
  74. // as a root.
  75. void DumpDot(std::ostream& out, bool recurse = false) const;
  76. // Checks if two nodes are the same by hashing them.
  77. bool operator==(const SENode& other) const;
  78. // Checks if two nodes are not the same by comparing the hashes.
  79. bool operator!=(const SENode& other) const;
  80. // Return the child node at |index|.
  81. inline SENode* GetChild(size_t index) { return children_[index]; }
  82. inline const SENode* GetChild(size_t index) const { return children_[index]; }
  83. // Iterator to iterate over the child nodes.
  84. using iterator = ChildContainerType::iterator;
  85. using const_iterator = ChildContainerType::const_iterator;
  86. // Iterate over immediate child nodes.
  87. iterator begin() { return children_.begin(); }
  88. iterator end() { return children_.end(); }
  89. // Constant overloads for iterating over immediate child nodes.
  90. const_iterator begin() const { return children_.cbegin(); }
  91. const_iterator end() const { return children_.cend(); }
  92. const_iterator cbegin() { return children_.cbegin(); }
  93. const_iterator cend() { return children_.cend(); }
  94. // Collect all the recurrent nodes in this SENode
  95. std::vector<SERecurrentNode*> CollectRecurrentNodes() {
  96. std::vector<SERecurrentNode*> recurrent_nodes{};
  97. if (auto recurrent_node = AsSERecurrentNode()) {
  98. recurrent_nodes.push_back(recurrent_node);
  99. }
  100. for (auto child : GetChildren()) {
  101. auto child_recurrent_nodes = child->CollectRecurrentNodes();
  102. recurrent_nodes.insert(recurrent_nodes.end(),
  103. child_recurrent_nodes.begin(),
  104. child_recurrent_nodes.end());
  105. }
  106. return recurrent_nodes;
  107. }
  108. // Collect all the value unknown nodes in this SENode
  109. std::vector<SEValueUnknown*> CollectValueUnknownNodes() {
  110. std::vector<SEValueUnknown*> value_unknown_nodes{};
  111. if (auto value_unknown_node = AsSEValueUnknown()) {
  112. value_unknown_nodes.push_back(value_unknown_node);
  113. }
  114. for (auto child : GetChildren()) {
  115. auto child_value_unknown_nodes = child->CollectValueUnknownNodes();
  116. value_unknown_nodes.insert(value_unknown_nodes.end(),
  117. child_value_unknown_nodes.begin(),
  118. child_value_unknown_nodes.end());
  119. }
  120. return value_unknown_nodes;
  121. }
  122. // Iterator to iterate over the entire DAG. Even though we are using the tree
  123. // iterator it should still be safe to iterate over. However, nodes with
  124. // multiple parents will be visited multiple times, unlike in a tree.
  125. using dag_iterator = TreeDFIterator<SENode>;
  126. using const_dag_iterator = TreeDFIterator<const SENode>;
  127. // Iterate over all child nodes in the graph.
  128. dag_iterator graph_begin() { return dag_iterator(this); }
  129. dag_iterator graph_end() { return dag_iterator(); }
  130. const_dag_iterator graph_begin() const { return graph_cbegin(); }
  131. const_dag_iterator graph_end() const { return graph_cend(); }
  132. const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); }
  133. const_dag_iterator graph_cend() const { return const_dag_iterator(); }
  134. // Return the vector of immediate children.
  135. const ChildContainerType& GetChildren() const { return children_; }
  136. ChildContainerType& GetChildren() { return children_; }
  137. // Return true if this node is a can't compute node.
  138. bool IsCantCompute() const { return GetType() == CanNotCompute; }
  139. // Implements a casting method for each type.
  140. // clang-format off
  141. #define DeclareCastMethod(target) \
  142. virtual target* As##target() { return nullptr; } \
  143. virtual const target* As##target() const { return nullptr; }
  144. DeclareCastMethod(SEConstantNode)
  145. DeclareCastMethod(SERecurrentNode)
  146. DeclareCastMethod(SEAddNode)
  147. DeclareCastMethod(SEMultiplyNode)
  148. DeclareCastMethod(SENegative)
  149. DeclareCastMethod(SEValueUnknown)
  150. DeclareCastMethod(SECantCompute)
  151. #undef DeclareCastMethod
  152. // Get the analysis which has this node in its cache.
  153. inline ScalarEvolutionAnalysis* GetParentAnalysis() const {
  154. return parent_analysis_;
  155. }
  156. protected:
  157. ChildContainerType children_;
  158. ScalarEvolutionAnalysis* parent_analysis_;
  159. // The unique id of this node, assigned on creation by incrementing the static
  160. // node count.
  161. uint32_t unique_id_;
  162. // The number of nodes created.
  163. static uint32_t NumberOfNodes;
  164. };
  165. // clang-format on
  166. // Function object to handle the hashing of SENodes. Hashing algorithm hashes
  167. // the type (as a string), the literal value of any constants, and the child
  168. // pointers which are assumed to be unique.
  169. struct SENodeHash {
  170. size_t operator()(const std::unique_ptr<SENode>& node) const;
  171. size_t operator()(const SENode* node) const;
  172. };
  173. // A node representing a constant integer.
  174. class SEConstantNode : public SENode {
  175. public:
  176. SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value)
  177. : SENode(parent_analysis), literal_value_(value) {}
  178. SENodeType GetType() const final { return Constant; }
  179. int64_t FoldToSingleValue() const { return literal_value_; }
  180. SEConstantNode* AsSEConstantNode() override { return this; }
  181. const SEConstantNode* AsSEConstantNode() const override { return this; }
  182. inline void AddChild(SENode*) final {
  183. assert(false && "Attempting to add a child to a constant node!");
  184. }
  185. protected:
  186. int64_t literal_value_;
  187. };
  188. // A node representing a recurrent expression in the code. A recurrent
  189. // expression is an expression whose value can be expressed as a linear
  190. // expression of the loop iterations. Such as an induction variable. The actual
  191. // value of a recurrent expression is coefficent_ * iteration + offset_, hence
  192. // an induction variable i=0, i++ becomes a recurrent expression with an offset
  193. // of zero and a coefficient of one.
  194. class SERecurrentNode : public SENode {
  195. public:
  196. SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop)
  197. : SENode(parent_analysis), loop_(loop) {}
  198. SENodeType GetType() const final { return RecurrentAddExpr; }
  199. inline void AddCoefficient(SENode* child) {
  200. coefficient_ = child;
  201. SENode::AddChild(child);
  202. }
  203. inline void AddOffset(SENode* child) {
  204. offset_ = child;
  205. SENode::AddChild(child);
  206. }
  207. inline const SENode* GetCoefficient() const { return coefficient_; }
  208. inline SENode* GetCoefficient() { return coefficient_; }
  209. inline const SENode* GetOffset() const { return offset_; }
  210. inline SENode* GetOffset() { return offset_; }
  211. // Return the loop which this recurrent expression is recurring within.
  212. const Loop* GetLoop() const { return loop_; }
  213. SERecurrentNode* AsSERecurrentNode() override { return this; }
  214. const SERecurrentNode* AsSERecurrentNode() const override { return this; }
  215. private:
  216. SENode* coefficient_;
  217. SENode* offset_;
  218. const Loop* loop_;
  219. };
  220. // A node representing an addition operation between child nodes.
  221. class SEAddNode : public SENode {
  222. public:
  223. explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis)
  224. : SENode(parent_analysis) {}
  225. SENodeType GetType() const final { return Add; }
  226. SEAddNode* AsSEAddNode() override { return this; }
  227. const SEAddNode* AsSEAddNode() const override { return this; }
  228. };
  229. // A node representing a multiply operation between child nodes.
  230. class SEMultiplyNode : public SENode {
  231. public:
  232. explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis)
  233. : SENode(parent_analysis) {}
  234. SENodeType GetType() const final { return Multiply; }
  235. SEMultiplyNode* AsSEMultiplyNode() override { return this; }
  236. const SEMultiplyNode* AsSEMultiplyNode() const override { return this; }
  237. };
  238. // A node representing a unary negative operation.
  239. class SENegative : public SENode {
  240. public:
  241. explicit SENegative(ScalarEvolutionAnalysis* parent_analysis)
  242. : SENode(parent_analysis) {}
  243. SENodeType GetType() const final { return Negative; }
  244. SENegative* AsSENegative() override { return this; }
  245. const SENegative* AsSENegative() const override { return this; }
  246. };
  247. // A node representing a value which we do not know the value of, such as a load
  248. // instruction.
  249. class SEValueUnknown : public SENode {
  250. public:
  251. // SEValueUnknowns must come from an instruction |unique_id| is the unique id
  252. // of that instruction. This is so we cancompare value unknowns and have a
  253. // unique value unknown for each instruction.
  254. SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id)
  255. : SENode(parent_analysis), result_id_(result_id) {}
  256. SENodeType GetType() const final { return ValueUnknown; }
  257. SEValueUnknown* AsSEValueUnknown() override { return this; }
  258. const SEValueUnknown* AsSEValueUnknown() const override { return this; }
  259. inline uint32_t ResultId() const { return result_id_; }
  260. private:
  261. uint32_t result_id_;
  262. };
  263. // A node which we cannot reason about at all.
  264. class SECantCompute : public SENode {
  265. public:
  266. explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis)
  267. : SENode(parent_analysis) {}
  268. SENodeType GetType() const final { return CanNotCompute; }
  269. SECantCompute* AsSECantCompute() override { return this; }
  270. const SECantCompute* AsSECantCompute() const override { return this; }
  271. };
  272. } // namespace opt
  273. } // namespace spvtools
  274. #endif // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_