lcs.h 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. // Copyright (c) 2022 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef SOURCE_DIFF_LCS_H_
  15. #define SOURCE_DIFF_LCS_H_
  16. #include <algorithm>
  17. #include <cassert>
  18. #include <cstddef>
  19. #include <functional>
  20. #include <vector>
  21. namespace spvtools {
  22. namespace diff {
  23. // The result of a diff.
  24. using DiffMatch = std::vector<bool>;
  25. // Helper class to find the longest common subsequence between two function
  26. // bodies.
  27. template <typename Sequence>
  28. class LongestCommonSubsequence {
  29. public:
  30. LongestCommonSubsequence(const Sequence& src, const Sequence& dst)
  31. : src_(src),
  32. dst_(dst),
  33. table_(src.size(), std::vector<DiffMatchEntry>(dst.size())) {}
  34. // Given two sequences, it creates a matching between them. The elements are
  35. // simply marked as matched in src and dst, with any unmatched element in src
  36. // implying a removal and any unmatched element in dst implying an addition.
  37. //
  38. // Returns the length of the longest common subsequence.
  39. template <typename T>
  40. size_t Get(std::function<bool(T src_elem, T dst_elem)> match,
  41. DiffMatch* src_match_result, DiffMatch* dst_match_result);
  42. private:
  43. template <typename T>
  44. size_t CalculateLCS(size_t src_start, size_t dst_start,
  45. std::function<bool(T src_elem, T dst_elem)> match);
  46. void RetrieveMatch(DiffMatch* src_match_result, DiffMatch* dst_match_result);
  47. bool IsInBound(size_t src_index, size_t dst_index) {
  48. return src_index < src_.size() && dst_index < dst_.size();
  49. }
  50. bool IsCalculated(size_t src_index, size_t dst_index) {
  51. assert(IsInBound(src_index, dst_index));
  52. return table_[src_index][dst_index].valid;
  53. }
  54. size_t GetMemoizedLength(size_t src_index, size_t dst_index) {
  55. if (!IsInBound(src_index, dst_index)) {
  56. return 0;
  57. }
  58. assert(IsCalculated(src_index, dst_index));
  59. return table_[src_index][dst_index].best_match_length;
  60. }
  61. bool IsMatched(size_t src_index, size_t dst_index) {
  62. assert(IsCalculated(src_index, dst_index));
  63. return table_[src_index][dst_index].matched;
  64. }
  65. const Sequence& src_;
  66. const Sequence& dst_;
  67. struct DiffMatchEntry {
  68. size_t best_match_length = 0;
  69. // Whether src[i] and dst[j] matched. This is an optimization to avoid
  70. // calling the `match` function again when walking the LCS table.
  71. bool matched = false;
  72. // Use for the recursive algorithm to know if the contents of this entry are
  73. // valid.
  74. bool valid = false;
  75. };
  76. std::vector<std::vector<DiffMatchEntry>> table_;
  77. };
  78. template <typename Sequence>
  79. template <typename T>
  80. size_t LongestCommonSubsequence<Sequence>::Get(
  81. std::function<bool(T src_elem, T dst_elem)> match,
  82. DiffMatch* src_match_result, DiffMatch* dst_match_result) {
  83. size_t best_match_length = CalculateLCS(0, 0, match);
  84. RetrieveMatch(src_match_result, dst_match_result);
  85. return best_match_length;
  86. }
  87. template <typename Sequence>
  88. template <typename T>
  89. size_t LongestCommonSubsequence<Sequence>::CalculateLCS(
  90. size_t src_start, size_t dst_start,
  91. std::function<bool(T src_elem, T dst_elem)> match) {
  92. // The LCS algorithm is simple. Given sequences s and d, with a:b depicting a
  93. // range in python syntax:
  94. //
  95. // lcs(s[i:], d[j:]) =
  96. // lcs(s[i+1:], d[j+1:]) + 1 if s[i] == d[j]
  97. // max(lcs(s[i+1:], d[j:]), lcs(s[i:], d[j+1:])) o.w.
  98. //
  99. // Once the LCS table is filled according to the above, it can be walked and
  100. // the best match retrieved.
  101. //
  102. // This is a recursive function with memoization, which avoids filling table
  103. // entries where unnecessary. This makes the best case O(N) instead of
  104. // O(N^2).
  105. // To avoid unnecessary recursion on long sequences, process a whole strip of
  106. // matching elements in one go.
  107. size_t src_cur = src_start;
  108. size_t dst_cur = dst_start;
  109. while (IsInBound(src_cur, dst_cur) && !IsCalculated(src_cur, dst_cur) &&
  110. match(src_[src_cur], dst_[dst_cur])) {
  111. ++src_cur;
  112. ++dst_cur;
  113. }
  114. // We've reached a pair of elements that don't match. Recursively determine
  115. // which one should be left unmatched.
  116. size_t best_match_length = 0;
  117. if (IsInBound(src_cur, dst_cur)) {
  118. if (IsCalculated(src_cur, dst_cur)) {
  119. best_match_length = GetMemoizedLength(src_cur, dst_cur);
  120. } else {
  121. best_match_length = std::max(CalculateLCS(src_cur + 1, dst_cur, match),
  122. CalculateLCS(src_cur, dst_cur + 1, match));
  123. // Fill the table with this information
  124. DiffMatchEntry& entry = table_[src_cur][dst_cur];
  125. assert(!entry.valid);
  126. entry.best_match_length = best_match_length;
  127. entry.valid = true;
  128. }
  129. }
  130. // Go over the matched strip and update the table as well.
  131. assert(src_cur - src_start == dst_cur - dst_start);
  132. size_t contiguous_match_len = src_cur - src_start;
  133. for (size_t i = 0; i < contiguous_match_len; ++i) {
  134. --src_cur;
  135. --dst_cur;
  136. assert(IsInBound(src_cur, dst_cur));
  137. DiffMatchEntry& entry = table_[src_cur][dst_cur];
  138. assert(!entry.valid);
  139. entry.best_match_length = ++best_match_length;
  140. entry.matched = true;
  141. entry.valid = true;
  142. }
  143. return best_match_length;
  144. }
  145. template <typename Sequence>
  146. void LongestCommonSubsequence<Sequence>::RetrieveMatch(
  147. DiffMatch* src_match_result, DiffMatch* dst_match_result) {
  148. src_match_result->clear();
  149. dst_match_result->clear();
  150. src_match_result->resize(src_.size(), false);
  151. dst_match_result->resize(dst_.size(), false);
  152. size_t src_cur = 0;
  153. size_t dst_cur = 0;
  154. while (IsInBound(src_cur, dst_cur)) {
  155. if (IsMatched(src_cur, dst_cur)) {
  156. (*src_match_result)[src_cur++] = true;
  157. (*dst_match_result)[dst_cur++] = true;
  158. continue;
  159. }
  160. if (GetMemoizedLength(src_cur + 1, dst_cur) >=
  161. GetMemoizedLength(src_cur, dst_cur + 1)) {
  162. ++src_cur;
  163. } else {
  164. ++dst_cur;
  165. }
  166. }
  167. }
  168. } // namespace diff
  169. } // namespace spvtools
  170. #endif // SOURCE_DIFF_LCS_H_