equivalence_relation.h 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef SOURCE_FUZZ_EQUIVALENCE_RELATION_H_
  15. #define SOURCE_FUZZ_EQUIVALENCE_RELATION_H_
  16. #include <algorithm>
  17. #include <cassert>
  18. #include <memory>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include <vector>
  22. #include "source/util/make_unique.h"
  23. namespace spvtools {
  24. namespace fuzz {
  25. // A class for representing an equivalence relation on objects of type |T|,
  26. // which should be a value type. The type |T| is required to have a copy
  27. // constructor, and |PointerHashT| and |PointerEqualsT| must be functors
  28. // providing hashing and equality testing functionality for pointers to objects
  29. // of type |T|.
  30. //
  31. // A disjoint-set (a.k.a. union-find or merge-find) data structure is used to
  32. // represent the equivalence relation. Path compression is used. Union by
  33. // rank/size is not used.
  34. //
  35. // Each disjoint set is represented as a tree, rooted at the representative
  36. // of the set.
  37. //
  38. // Getting the representative of a value simply requires chasing parent pointers
  39. // from the value until you reach the root.
  40. //
  41. // Checking equivalence of two elements requires checking that the
  42. // representatives are equal.
  43. //
  44. // Traversing the tree rooted at a value's representative visits the value's
  45. // equivalence class.
  46. //
  47. // |PointerHashT| and |PointerEqualsT| are used to define *equality* between
  48. // values, and otherwise are *not* used to define the equivalence relation
  49. // (except that equal values are equivalent). The equivalence relation is
  50. // constructed by repeatedly adding pairs of (typically non-equal) values that
  51. // are deemed to be equivalent.
  52. //
  53. // For example in an equivalence relation on integers, 1 and 5 might be added
  54. // as equivalent, so that IsEquivalent(1, 5) holds, because they represent
  55. // IDs in a SPIR-V binary that are known to contain the same value at run time,
  56. // but clearly 1 != 5. Since 1 and 1 are equal, IsEquivalent(1, 1) will also
  57. // hold.
  58. //
  59. // Each unique (up to equality) value added to the relation is copied into
  60. // |owned_values_|, so there is one canonical memory address per unique value.
  61. // Uniqueness is ensured by storing (and checking) a set of pointers to these
  62. // values in |value_set_|, which uses |PointerHashT| and |PointerEqualsT|.
  63. //
  64. // |parent_| and |children_| encode the equivalence relation, i.e., the trees.
  65. template <typename T, typename PointerHashT, typename PointerEqualsT>
  66. class EquivalenceRelation {
  67. public:
  68. // Requires that |value1| and |value2| are already registered in the
  69. // equivalence relation. Merges the equivalence classes associated with
  70. // |value1| and |value2|.
  71. void MakeEquivalent(const T& value1, const T& value2) {
  72. assert(Exists(value1) &&
  73. "Precondition: value1 must already be registered.");
  74. assert(Exists(value2) &&
  75. "Precondition: value2 must already be registered.");
  76. // Look up canonical pointers to each of the values in the value pool.
  77. const T* value1_ptr = *value_set_.find(&value1);
  78. const T* value2_ptr = *value_set_.find(&value2);
  79. // If the values turn out to be identical, they are already in the same
  80. // equivalence class so there is nothing to do.
  81. if (value1_ptr == value2_ptr) {
  82. return;
  83. }
  84. // Find the representative for each value's equivalence class, and if they
  85. // are not already in the same class, make one the parent of the other.
  86. const T* representative1 = Find(value1_ptr);
  87. const T* representative2 = Find(value2_ptr);
  88. assert(representative1 && "Representatives should never be null.");
  89. assert(representative2 && "Representatives should never be null.");
  90. if (representative1 != representative2) {
  91. parent_[representative1] = representative2;
  92. children_[representative2].push_back(representative1);
  93. }
  94. }
  95. // Requires that |value| is not known to the equivalence relation. Registers
  96. // it in its own equivalence class and returns a pointer to the equivalence
  97. // class representative.
  98. const T* Register(const T& value) {
  99. assert(!Exists(value));
  100. // This relies on T having a copy constructor.
  101. auto unique_pointer_to_value = MakeUnique<T>(value);
  102. auto pointer_to_value = unique_pointer_to_value.get();
  103. owned_values_.push_back(std::move(unique_pointer_to_value));
  104. value_set_.insert(pointer_to_value);
  105. // Initially say that the value is its own parent and that it has no
  106. // children.
  107. assert(pointer_to_value && "Representatives should never be null.");
  108. parent_[pointer_to_value] = pointer_to_value;
  109. children_[pointer_to_value] = std::vector<const T*>();
  110. return pointer_to_value;
  111. }
  112. // Returns exactly one representative per equivalence class.
  113. std::vector<const T*> GetEquivalenceClassRepresentatives() const {
  114. std::vector<const T*> result;
  115. for (auto& value : owned_values_) {
  116. if (parent_[value.get()] == value.get()) {
  117. result.push_back(value.get());
  118. }
  119. }
  120. return result;
  121. }
  122. // Returns pointers to all values in the equivalence class of |value|, which
  123. // must already be part of the equivalence relation.
  124. std::vector<const T*> GetEquivalenceClass(const T& value) const {
  125. assert(Exists(value));
  126. std::vector<const T*> result;
  127. // Traverse the tree of values rooted at the representative of the
  128. // equivalence class to which |value| belongs, and collect up all the values
  129. // that are encountered. This constitutes the whole equivalence class.
  130. std::vector<const T*> stack;
  131. stack.push_back(Find(*value_set_.find(&value)));
  132. while (!stack.empty()) {
  133. const T* item = stack.back();
  134. result.push_back(item);
  135. stack.pop_back();
  136. for (auto child : children_[item]) {
  137. stack.push_back(child);
  138. }
  139. }
  140. return result;
  141. }
  142. // Returns true if and only if |value1| and |value2| are in the same
  143. // equivalence class. Both values must already be known to the equivalence
  144. // relation.
  145. bool IsEquivalent(const T& value1, const T& value2) const {
  146. return Find(&value1) == Find(&value2);
  147. }
  148. // Returns all values known to be part of the equivalence relation.
  149. std::vector<const T*> GetAllKnownValues() const {
  150. std::vector<const T*> result;
  151. for (auto& value : owned_values_) {
  152. result.push_back(value.get());
  153. }
  154. return result;
  155. }
  156. // Returns true if and only if |value| is known to be part of the equivalence
  157. // relation.
  158. bool Exists(const T& value) const {
  159. return value_set_.find(&value) != value_set_.end();
  160. }
  161. // Returns the representative of the equivalence class of |value|, which must
  162. // already be known to the equivalence relation. This is the 'Find' operation
  163. // in a classic union-find data structure.
  164. const T* Find(const T* value) const {
  165. assert(Exists(*value));
  166. // Get the canonical pointer to the value from the value pool.
  167. const T* known_value = *value_set_.find(value);
  168. assert(parent_[known_value] && "Every known value should have a parent.");
  169. // Compute the result by chasing parents until we find a value that is its
  170. // own parent.
  171. const T* result = known_value;
  172. while (parent_[result] != result) {
  173. result = parent_[result];
  174. }
  175. assert(result && "Representatives should never be null.");
  176. // At this point, |result| is the representative of the equivalence class.
  177. // Now perform the 'path compression' optimization by doing another pass up
  178. // the parent chain, setting the parent of each node to be the
  179. // representative, and rewriting children correspondingly.
  180. const T* current = known_value;
  181. while (parent_[current] != result) {
  182. const T* next = parent_[current];
  183. parent_[current] = result;
  184. children_[result].push_back(current);
  185. auto child_iterator =
  186. std::find(children_[next].begin(), children_[next].end(), current);
  187. assert(child_iterator != children_[next].end() &&
  188. "'next' is the parent of 'current', so 'current' should be a "
  189. "child of 'next'");
  190. children_[next].erase(child_iterator);
  191. current = next;
  192. }
  193. return result;
  194. }
  195. private:
  196. // Maps every value to a parent. The representative of an equivalence class
  197. // is its own parent. A value's representative can be found by walking its
  198. // chain of ancestors.
  199. //
  200. // Mutable because the intuitively const method, 'Find', performs path
  201. // compression.
  202. mutable std::unordered_map<const T*, const T*> parent_;
  203. // Stores the children of each value. This allows the equivalence class of
  204. // a value to be calculated by traversing all descendents of the class's
  205. // representative.
  206. //
  207. // Mutable because the intuitively const method, 'Find', performs path
  208. // compression.
  209. mutable std::unordered_map<const T*, std::vector<const T*>> children_;
  210. // The values known to the equivalence relation are allocated in
  211. // |owned_values_|, and |value_pool_| provides (via |PointerHashT| and
  212. // |PointerEqualsT|) a means for mapping a value of interest to a pointer
  213. // into an equivalent value in |owned_values_|.
  214. std::unordered_set<const T*, PointerHashT, PointerEqualsT> value_set_;
  215. std::vector<std::unique_ptr<T>> owned_values_;
  216. };
  217. } // namespace fuzz
  218. } // namespace spvtools
  219. #endif // SOURCE_FUZZ_EQUIVALENCE_RELATION_H_