// lcs.h
  // Copyright (c) 2022 Google LLC.
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //     http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  14. #ifndef SOURCE_DIFF_LCS_H_
  15. #define SOURCE_DIFF_LCS_H_
  16. #include <algorithm>
  17. #include <cassert>
  18. #include <cstddef>
  19. #include <cstdint>
  20. #include <functional>
  21. #include <stack>
  22. #include <vector>
  namespace spvtools {
  namespace diff {

  // The result of a diff: one flag per element of a sequence, set to true when
  // that element was matched with an element of the other sequence.
  using DiffMatch = std::vector<bool>;

  // Helper class to find the longest common subsequence between two function
  // bodies.
  //
  // `Sequence` must provide size(), empty() and operator[].  Only references to
  // `src` and `dst` are stored, so both sequences must outlive this object.
  template <typename Sequence>
  class LongestCommonSubsequence {
   public:
    // Allocates a src.size() x dst.size() memoization table in which every
    // entry starts out as "not yet calculated".
    LongestCommonSubsequence(const Sequence& src, const Sequence& dst)
        : src_(src),
          dst_(dst),
          table_(src.size(), std::vector<DiffMatchEntry>(dst.size())) {}

    // Given two sequences, it creates a matching between them.  The elements
    // are simply marked as matched in src and dst, with any unmatched element
    // in src implying a removal and any unmatched element in dst implying an
    // addition.
    //
    // `match` decides whether a src element and a dst element are considered
    // equal.  `src_match_result` and `dst_match_result` must be non-null; their
    // previous contents are discarded and they are resized to src.size() and
    // dst.size() respectively.
    //
    // Returns the length of the longest common subsequence.
    template <typename T>
    uint32_t Get(std::function<bool(T src_elem, T dst_elem)> match,
                 DiffMatch* src_match_result, DiffMatch* dst_match_result);

   private:
    // A position in the LCS table.  Entry (i, j) describes the LCS of the
    // suffixes src[i:] and dst[j:].
    struct DiffMatchIndex {
      uint32_t src_offset;
      uint32_t dst_offset;
    };

    // One memoized cell of the LCS table, packed into 32 bits.
    struct DiffMatchEntry {
      DiffMatchEntry() : best_match_length(0), matched(false), valid(false) {}

      // Length of the LCS of src[i:] and dst[j:].
      uint32_t best_match_length : 30;
      // Whether src[i] and dst[j] matched.  This is an optimization to avoid
      // calling the `match` function again when walking the LCS table.
      uint32_t matched : 1;
      // Use for the recursive algorithm to know if the contents of this entry
      // are valid.
      uint32_t valid : 1;
    };

    // Fills the entries of `table_` that are needed to know the LCS of the
    // full sequences.  `match` is taken by reference so that forwarding it
    // from Get() does not copy the std::function a second time.
    template <typename T>
    void CalculateLCS(const std::function<bool(T src_elem, T dst_elem)>& match);

    // Walks the filled table from (0, 0) and marks the elements that are part
    // of the LCS in the two result vectors.
    void RetrieveMatch(DiffMatch* src_match_result,
                       DiffMatch* dst_match_result) const;

    // Returns whether `index` refers to an existing table entry.
    bool IsInBound(DiffMatchIndex index) const {
      return index.src_offset < src_.size() && index.dst_offset < dst_.size();
    }

    // Returns whether the (in-bound) entry at `index` has been filled in.
    bool IsCalculated(DiffMatchIndex index) const {
      assert(IsInBound(index));
      return table_[index.src_offset][index.dst_offset].valid;
    }

    // Returns whether the length at `index` can be read right away: either the
    // entry is filled in, or it is out of bound (and implicitly 0).
    bool IsCalculatedOrOutOfBound(DiffMatchIndex index) const {
      return !IsInBound(index) || IsCalculated(index);
    }

    // Returns the memoized LCS length at `index`; out-of-bound indices denote
    // an empty suffix and yield 0.
    uint32_t GetMemoizedLength(DiffMatchIndex index) const {
      if (!IsInBound(index)) {
        return 0;
      }
      assert(IsCalculated(index));
      return table_[index.src_offset][index.dst_offset].best_match_length;
    }

    // Returns whether the elements at `index` were recorded as matching.
    bool IsMatched(DiffMatchIndex index) const {
      assert(IsCalculated(index));
      return table_[index.src_offset][index.dst_offset].matched;
    }

    // Fills in the entry at `index`.  Each entry may be filled in only once.
    void MarkMatched(DiffMatchIndex index, uint32_t best_match_length,
                     bool matched) {
      assert(IsInBound(index));
      DiffMatchEntry& entry = table_[index.src_offset][index.dst_offset];
      assert(!entry.valid);
      // The length lives in a 30-bit field: mask explicitly, then verify that
      // nothing was truncated.
      entry.best_match_length = best_match_length & 0x3FFFFFFF;
      assert(entry.best_match_length == best_match_length);
      entry.matched = matched;
      entry.valid = true;
    }

    const Sequence& src_;
    const Sequence& dst_;
    std::vector<std::vector<DiffMatchEntry>> table_;
  };

  template <typename Sequence>
  template <typename T>
  uint32_t LongestCommonSubsequence<Sequence>::Get(
      std::function<bool(T src_elem, T dst_elem)> match,
      DiffMatch* src_match_result, DiffMatch* dst_match_result) {
    assert(src_match_result != nullptr);
    assert(dst_match_result != nullptr);

    CalculateLCS(match);
    RetrieveMatch(src_match_result, dst_match_result);
    // (0, 0) is out of bound when either sequence is empty, which yields 0.
    return GetMemoizedLength({0, 0});
  }

  template <typename Sequence>
  template <typename T>
  void LongestCommonSubsequence<Sequence>::CalculateLCS(
      const std::function<bool(T src_elem, T dst_elem)>& match) {
    // The LCS algorithm is simple.  Given sequences s and d, with a:b depicting
    // a range in python syntax:
    //
    //     lcs(s[i:], d[j:]) =
    //         lcs(s[i+1:], d[j+1:]) + 1                        if s[i] == d[j]
    //         max(lcs(s[i+1:], d[j:]), lcs(s[i:], d[j+1:]))    o.w.
    //
    // Once the LCS table is filled according to the above, it can be walked and
    // the best match retrieved.
    //
    // This is a recursive function with memoization, which avoids filling table
    // entries where unnecessary.  This makes the best case O(N) instead of
    // O(N^2).  The implementation uses a std::stack to avoid stack overflow on
    // long sequences.
    if (src_.empty() || dst_.empty()) {
      return;
    }

    std::stack<DiffMatchIndex> to_calculate;
    to_calculate.push({0, 0});

    while (!to_calculate.empty()) {
      const DiffMatchIndex current = to_calculate.top();
      to_calculate.pop();
      assert(IsInBound(current));

      // If already calculated through another path, ignore it.
      if (IsCalculated(current)) {
        continue;
      }

      if (match(src_[current.src_offset], dst_[current.dst_offset])) {
        // If the current elements match, advance both indices and calculate the
        // LCS if not already.  Visit `current` again afterwards, so its
        // corresponding entry will be updated.
        const DiffMatchIndex next = {current.src_offset + 1,
                                     current.dst_offset + 1};
        if (IsCalculatedOrOutOfBound(next)) {
          MarkMatched(current, GetMemoizedLength(next) + 1, true);
        } else {
          to_calculate.push(current);
          to_calculate.push(next);
        }
        continue;
      }

      // We've reached a pair of elements that don't match.  Calculate the LCS
      // for both cases of either being left unmatched and take the max.  Visit
      // `current` again afterwards, so its corresponding entry will be updated.
      const DiffMatchIndex next_src = {current.src_offset + 1,
                                       current.dst_offset};
      const DiffMatchIndex next_dst = {current.src_offset,
                                       current.dst_offset + 1};
      const bool next_src_ready = IsCalculatedOrOutOfBound(next_src);
      const bool next_dst_ready = IsCalculatedOrOutOfBound(next_dst);

      if (next_src_ready && next_dst_ready) {
        const uint32_t best_match_length =
            std::max(GetMemoizedLength(next_src), GetMemoizedLength(next_dst));
        MarkMatched(current, best_match_length, false);
        continue;
      }

      // `current` goes below its pending dependencies on the stack, so it is
      // revisited only after they have been processed.
      to_calculate.push(current);
      if (!next_src_ready) {
        to_calculate.push(next_src);
      }
      if (!next_dst_ready) {
        to_calculate.push(next_dst);
      }
    }
  }

  template <typename Sequence>
  void LongestCommonSubsequence<Sequence>::RetrieveMatch(
      DiffMatch* src_match_result, DiffMatch* dst_match_result) const {
    src_match_result->clear();
    dst_match_result->clear();

    src_match_result->resize(src_.size(), false);
    dst_match_result->resize(dst_.size(), false);

    DiffMatchIndex current = {0, 0};
    while (IsInBound(current)) {
      if (IsMatched(current)) {
        (*src_match_result)[current.src_offset++] = true;
        (*dst_match_result)[current.dst_offset++] = true;
        continue;
      }

      // No match here: follow whichever neighbor holds the longer LCS.  Ties
      // advance in src.
      if (GetMemoizedLength({current.src_offset + 1, current.dst_offset}) >=
          GetMemoizedLength({current.src_offset, current.dst_offset + 1})) {
        ++current.src_offset;
      } else {
        ++current.dst_offset;
      }
    }
  }

  }  // namespace diff
  }  // namespace spvtools
  195. #endif // SOURCE_DIFF_LCS_H_