shrinker.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/shrinker.h"
  15. #include <sstream>
  16. #include "source/fuzz/added_function_reducer.h"
  17. #include "source/fuzz/pseudo_random_generator.h"
  18. #include "source/fuzz/replayer.h"
  19. #include "source/opt/build_module.h"
  20. #include "source/opt/ir_context.h"
  21. #include "source/spirv_fuzzer_options.h"
  22. #include "source/util/make_unique.h"
  23. namespace spvtools {
  24. namespace fuzz {
  25. namespace {
  26. // A helper to get the size of a protobuf transformation sequence in a less
  27. // verbose manner.
  28. uint32_t NumRemainingTransformations(
  29. const protobufs::TransformationSequence& transformation_sequence) {
  30. return static_cast<uint32_t>(transformation_sequence.transformation_size());
  31. }
  32. // A helper to return a transformation sequence identical to |transformations|,
  33. // except that a chunk of size |chunk_size| starting from |chunk_index| x
  34. // |chunk_size| is removed (or as many transformations as available if the whole
  35. // chunk is not).
  36. protobufs::TransformationSequence RemoveChunk(
  37. const protobufs::TransformationSequence& transformations,
  38. uint32_t chunk_index, uint32_t chunk_size) {
  39. uint32_t lower = chunk_index * chunk_size;
  40. uint32_t upper = std::min((chunk_index + 1) * chunk_size,
  41. NumRemainingTransformations(transformations));
  42. assert(lower < upper);
  43. assert(upper <= NumRemainingTransformations(transformations));
  44. protobufs::TransformationSequence result;
  45. for (uint32_t j = 0; j < NumRemainingTransformations(transformations); j++) {
  46. if (j >= lower && j < upper) {
  47. continue;
  48. }
  49. protobufs::Transformation transformation =
  50. transformations.transformation()[j];
  51. *result.mutable_transformation()->Add() = transformation;
  52. }
  53. return result;
  54. }
  55. } // namespace
  56. Shrinker::Shrinker(
  57. spv_target_env target_env, MessageConsumer consumer,
  58. const std::vector<uint32_t>& binary_in,
  59. const protobufs::FactSequence& initial_facts,
  60. const protobufs::TransformationSequence& transformation_sequence_in,
  61. const InterestingnessFunction& interestingness_function,
  62. uint32_t step_limit, bool validate_during_replay,
  63. spv_validator_options validator_options)
  64. : target_env_(target_env),
  65. consumer_(std::move(consumer)),
  66. binary_in_(binary_in),
  67. initial_facts_(initial_facts),
  68. transformation_sequence_in_(transformation_sequence_in),
  69. interestingness_function_(interestingness_function),
  70. step_limit_(step_limit),
  71. validate_during_replay_(validate_during_replay),
  72. validator_options_(validator_options) {}
  73. Shrinker::~Shrinker() = default;
  74. Shrinker::ShrinkerResult Shrinker::Run() {
  75. // Check compatibility between the library version being linked with and the
  76. // header files being used.
  77. GOOGLE_PROTOBUF_VERIFY_VERSION;
  78. SpirvTools tools(target_env_);
  79. if (!tools.IsValid()) {
  80. consumer_(SPV_MSG_ERROR, nullptr, {},
  81. "Failed to create SPIRV-Tools interface; stopping.");
  82. return {Shrinker::ShrinkerResultStatus::kFailedToCreateSpirvToolsInterface,
  83. std::vector<uint32_t>(), protobufs::TransformationSequence()};
  84. }
  85. // Initial binary should be valid.
  86. if (!tools.Validate(&binary_in_[0], binary_in_.size(), validator_options_)) {
  87. consumer_(SPV_MSG_INFO, nullptr, {},
  88. "Initial binary is invalid; stopping.");
  89. return {Shrinker::ShrinkerResultStatus::kInitialBinaryInvalid,
  90. std::vector<uint32_t>(), protobufs::TransformationSequence()};
  91. }
  92. // Run a replay of the initial transformation sequence to check that it
  93. // succeeds.
  94. auto initial_replay_result =
  95. Replayer(target_env_, consumer_, binary_in_, initial_facts_,
  96. transformation_sequence_in_,
  97. static_cast<uint32_t>(
  98. transformation_sequence_in_.transformation_size()),
  99. validate_during_replay_, validator_options_)
  100. .Run();
  101. if (initial_replay_result.status !=
  102. Replayer::ReplayerResultStatus::kComplete) {
  103. return {ShrinkerResultStatus::kReplayFailed, std::vector<uint32_t>(),
  104. protobufs::TransformationSequence()};
  105. }
  106. // Get the binary that results from running these transformations, and the
  107. // subsequence of the initial transformations that actually apply (in
  108. // principle this could be a strict subsequence).
  109. std::vector<uint32_t> current_best_binary;
  110. initial_replay_result.transformed_module->module()->ToBinary(
  111. &current_best_binary, false);
  112. protobufs::TransformationSequence current_best_transformations =
  113. std::move(initial_replay_result.applied_transformations);
  114. // Check that the binary produced by applying the initial transformations is
  115. // indeed interesting.
  116. if (!interestingness_function_(current_best_binary, 0)) {
  117. consumer_(SPV_MSG_INFO, nullptr, {},
  118. "Initial binary is not interesting; stopping.");
  119. return {ShrinkerResultStatus::kInitialBinaryNotInteresting,
  120. std::vector<uint32_t>(), protobufs::TransformationSequence()};
  121. }
  122. uint32_t attempt = 0; // Keeps track of the number of shrink attempts that
  123. // have been tried, whether successful or not.
  124. uint32_t chunk_size =
  125. std::max(1u, NumRemainingTransformations(current_best_transformations) /
  126. 2); // The number of contiguous transformations that the
  127. // shrinker will try to remove in one go; starts
  128. // high and decreases during the shrinking process.
  129. // Keep shrinking until we:
  130. // - reach the step limit,
  131. // - run out of transformations to remove, or
  132. // - cannot make the chunk size any smaller.
  133. while (attempt < step_limit_ &&
  134. !current_best_transformations.transformation().empty() &&
  135. chunk_size > 0) {
  136. bool progress_this_round =
  137. false; // Used to decide whether to make the chunk size with which we
  138. // remove transformations smaller. If we managed to remove at
  139. // least one chunk of transformations at a particular chunk
  140. // size, we set this flag so that we do not yet decrease the
  141. // chunk size.
  142. assert(chunk_size <=
  143. NumRemainingTransformations(current_best_transformations) &&
  144. "Chunk size should never exceed the number of transformations that "
  145. "remain.");
  146. // The number of chunks is the ceiling of (#remaining_transformations /
  147. // chunk_size).
  148. const uint32_t num_chunks =
  149. (NumRemainingTransformations(current_best_transformations) +
  150. chunk_size - 1) /
  151. chunk_size;
  152. assert(num_chunks >= 1 && "There should be at least one chunk.");
  153. assert(num_chunks * chunk_size >=
  154. NumRemainingTransformations(current_best_transformations) &&
  155. "All transformations should be in some chunk.");
  156. // We go through the transformations in reverse, in chunks of size
  157. // |chunk_size|, using |chunk_index| to track which chunk to try removing
  158. // next. The loop exits early if we reach the shrinking step limit.
  159. for (int chunk_index = num_chunks - 1;
  160. attempt < step_limit_ && chunk_index >= 0; chunk_index--) {
  161. // Remove a chunk of transformations according to the current index and
  162. // chunk size.
  163. auto transformations_with_chunk_removed =
  164. RemoveChunk(current_best_transformations,
  165. static_cast<uint32_t>(chunk_index), chunk_size);
  166. // Replay the smaller sequence of transformations to get a next binary and
  167. // transformation sequence. Note that the transformations arising from
  168. // replay might be even smaller than the transformations with the chunk
  169. // removed, because removing those transformations might make further
  170. // transformations inapplicable.
  171. auto replay_result =
  172. Replayer(
  173. target_env_, consumer_, binary_in_, initial_facts_,
  174. transformations_with_chunk_removed,
  175. static_cast<uint32_t>(
  176. transformations_with_chunk_removed.transformation_size()),
  177. validate_during_replay_, validator_options_)
  178. .Run();
  179. if (replay_result.status != Replayer::ReplayerResultStatus::kComplete) {
  180. // Replay should not fail; if it does, we need to abort shrinking.
  181. return {ShrinkerResultStatus::kReplayFailed, std::vector<uint32_t>(),
  182. protobufs::TransformationSequence()};
  183. }
  184. assert(
  185. NumRemainingTransformations(replay_result.applied_transformations) >=
  186. chunk_index * chunk_size &&
  187. "Removing this chunk of transformations should not have an effect "
  188. "on earlier chunks.");
  189. std::vector<uint32_t> transformed_binary;
  190. replay_result.transformed_module->module()->ToBinary(&transformed_binary,
  191. false);
  192. if (interestingness_function_(transformed_binary, attempt)) {
  193. // If the binary arising from the smaller transformation sequence is
  194. // interesting, this becomes our current best binary and transformation
  195. // sequence.
  196. current_best_binary = std::move(transformed_binary);
  197. current_best_transformations =
  198. std::move(replay_result.applied_transformations);
  199. progress_this_round = true;
  200. }
  201. // Either way, this was a shrink attempt, so increment our count of shrink
  202. // attempts.
  203. attempt++;
  204. }
  205. if (!progress_this_round) {
  206. // If we didn't manage to remove any chunks at this chunk size, try a
  207. // smaller chunk size.
  208. chunk_size /= 2;
  209. }
  210. // Decrease the chunk size until it becomes no larger than the number of
  211. // remaining transformations.
  212. while (chunk_size >
  213. NumRemainingTransformations(current_best_transformations)) {
  214. chunk_size /= 2;
  215. }
  216. }
  217. // We now use spirv-reduce to minimise the functions associated with any
  218. // AddFunction transformations that remain.
  219. //
  220. // Consider every remaining transformation.
  221. for (uint32_t transformation_index = 0;
  222. attempt < step_limit_ &&
  223. transformation_index <
  224. static_cast<uint32_t>(
  225. current_best_transformations.transformation_size());
  226. transformation_index++) {
  227. // Skip all transformations apart from TransformationAddFunction.
  228. if (!current_best_transformations.transformation(transformation_index)
  229. .has_add_function()) {
  230. continue;
  231. }
  232. // Invoke spirv-reduce on the function encoded in this AddFunction
  233. // transformation. The details of this are rather involved, and so are
  234. // encapsulated in a separate class.
  235. auto added_function_reducer_result =
  236. AddedFunctionReducer(target_env_, consumer_, binary_in_, initial_facts_,
  237. current_best_transformations, transformation_index,
  238. interestingness_function_, validate_during_replay_,
  239. validator_options_, step_limit_, attempt)
  240. .Run();
  241. // Reducing the added function should succeed. If it doesn't, we report
  242. // a shrinking error.
  243. if (added_function_reducer_result.status !=
  244. AddedFunctionReducer::AddedFunctionReducerResultStatus::kComplete) {
  245. return {ShrinkerResultStatus::kAddedFunctionReductionFailed,
  246. std::vector<uint32_t>(), protobufs::TransformationSequence()};
  247. }
  248. assert(current_best_transformations.transformation_size() ==
  249. added_function_reducer_result.applied_transformations
  250. .transformation_size() &&
  251. "The number of transformations should not have changed.");
  252. current_best_binary =
  253. std::move(added_function_reducer_result.transformed_binary);
  254. current_best_transformations =
  255. std::move(added_function_reducer_result.applied_transformations);
  256. // The added function reducer reports how many reduction attempts
  257. // spirv-reduce took when reducing the function. We regard each of these
  258. // as a shrinker attempt.
  259. attempt += added_function_reducer_result.num_reduction_attempts;
  260. }
  261. // Indicate whether shrinking completed or was truncated due to reaching the
  262. // step limit.
  263. //
  264. // Either way, the output from the shrinker is the best binary we saw, and the
  265. // transformations that led to it.
  266. assert(attempt <= step_limit_);
  267. if (attempt == step_limit_) {
  268. std::stringstream strstream;
  269. strstream << "Shrinking did not complete; step limit " << step_limit_
  270. << " was reached.";
  271. consumer_(SPV_MSG_WARNING, nullptr, {}, strstream.str().c_str());
  272. return {Shrinker::ShrinkerResultStatus::kStepLimitReached,
  273. std::move(current_best_binary),
  274. std::move(current_best_transformations)};
  275. }
  276. return {Shrinker::ShrinkerResultStatus::kComplete,
  277. std::move(current_best_binary),
  278. std::move(current_best_transformations)};
  279. }
  280. uint32_t Shrinker::GetIdBound(const std::vector<uint32_t>& binary) const {
  281. // Build the module from the input binary.
  282. std::unique_ptr<opt::IRContext> ir_context =
  283. BuildModule(target_env_, consumer_, binary.data(), binary.size());
  284. assert(ir_context && "Error building module.");
  285. return ir_context->module()->id_bound();
  286. }
  287. } // namespace fuzz
  288. } // namespace spvtools