fuzz_test_util.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "test/fuzz/fuzz_test_util.h"
  15. #include "gtest/gtest.h"
  16. #include <fstream>
  17. #include <iostream>
  18. #include "source/opt/def_use_manager.h"
  19. #include "tools/io.h"
  20. namespace spvtools {
  21. namespace fuzz {
  22. const spvtools::MessageConsumer kConsoleMessageConsumer =
  23. [](spv_message_level_t level, const char*, const spv_position_t& position,
  24. const char* message) -> void {
  25. switch (level) {
  26. case SPV_MSG_FATAL:
  27. case SPV_MSG_INTERNAL_ERROR:
  28. case SPV_MSG_ERROR:
  29. std::cerr << "error: line " << position.index << ": " << message
  30. << std::endl;
  31. break;
  32. case SPV_MSG_WARNING:
  33. std::cout << "warning: line " << position.index << ": " << message
  34. << std::endl;
  35. break;
  36. case SPV_MSG_INFO:
  37. std::cout << "info: line " << position.index << ": " << message
  38. << std::endl;
  39. break;
  40. default:
  41. break;
  42. }
  43. };
  44. bool IsEqual(const spv_target_env env,
  45. const std::vector<uint32_t>& expected_binary,
  46. const std::vector<uint32_t>& actual_binary) {
  47. if (expected_binary == actual_binary) {
  48. return true;
  49. }
  50. SpirvTools t(env);
  51. std::string expected_disassembled;
  52. std::string actual_disassembled;
  53. if (!t.Disassemble(expected_binary, &expected_disassembled,
  54. kFuzzDisassembleOption)) {
  55. return false;
  56. }
  57. if (!t.Disassemble(actual_binary, &actual_disassembled,
  58. kFuzzDisassembleOption)) {
  59. return false;
  60. }
  61. // Using expect gives us a string diff if the strings are not the same.
  62. EXPECT_EQ(expected_disassembled, actual_disassembled);
  63. // We then return the result of the equality comparison, to be used by an
  64. // assertion in the test root function.
  65. return expected_disassembled == actual_disassembled;
  66. }
  67. bool IsEqual(const spv_target_env env, const std::string& expected_text,
  68. const std::vector<uint32_t>& actual_binary) {
  69. std::vector<uint32_t> expected_binary;
  70. SpirvTools t(env);
  71. if (!t.Assemble(expected_text, &expected_binary, kFuzzAssembleOption)) {
  72. return false;
  73. }
  74. return IsEqual(env, expected_binary, actual_binary);
  75. }
  76. bool IsEqual(const spv_target_env env, const std::string& expected_text,
  77. const opt::IRContext* actual_ir) {
  78. std::vector<uint32_t> actual_binary;
  79. actual_ir->module()->ToBinary(&actual_binary, false);
  80. return IsEqual(env, expected_text, actual_binary);
  81. }
  82. bool IsEqual(const spv_target_env env, const opt::IRContext* ir_1,
  83. const opt::IRContext* ir_2) {
  84. std::vector<uint32_t> binary_1;
  85. ir_1->module()->ToBinary(&binary_1, false);
  86. std::vector<uint32_t> binary_2;
  87. ir_2->module()->ToBinary(&binary_2, false);
  88. return IsEqual(env, binary_1, binary_2);
  89. }
  90. bool IsEqual(const spv_target_env env, const std::vector<uint32_t>& binary_1,
  91. const opt::IRContext* ir_2) {
  92. std::vector<uint32_t> binary_2;
  93. ir_2->module()->ToBinary(&binary_2, false);
  94. return IsEqual(env, binary_1, binary_2);
  95. }
  96. std::string ToString(spv_target_env env, const opt::IRContext* ir) {
  97. std::vector<uint32_t> binary;
  98. ir->module()->ToBinary(&binary, false);
  99. return ToString(env, binary);
  100. }
  101. std::string ToString(spv_target_env env, const std::vector<uint32_t>& binary) {
  102. SpirvTools t(env);
  103. std::string result;
  104. t.Disassemble(binary, &result, kFuzzDisassembleOption);
  105. return result;
  106. }
  107. void DumpShader(opt::IRContext* context, const char* filename) {
  108. std::vector<uint32_t> binary;
  109. context->module()->ToBinary(&binary, false);
  110. DumpShader(binary, filename);
  111. }
  112. void DumpShader(const std::vector<uint32_t>& binary, const char* filename) {
  113. auto write_file_succeeded =
  114. WriteFile(filename, "wb", &binary[0], binary.size());
  115. if (!write_file_succeeded) {
  116. std::cerr << "Failed to dump shader" << std::endl;
  117. }
  118. }
  119. void DumpTransformationsBinary(
  120. const protobufs::TransformationSequence& transformations,
  121. const char* filename) {
  122. std::ofstream transformations_file;
  123. transformations_file.open(filename, std::ios::out | std::ios::binary);
  124. transformations.SerializeToOstream(&transformations_file);
  125. transformations_file.close();
  126. }
  127. void DumpTransformationsJson(
  128. const protobufs::TransformationSequence& transformations,
  129. const char* filename) {
  130. std::string json_string;
  131. auto json_options = google::protobuf::util::JsonPrintOptions();
  132. json_options.add_whitespace = true;
  133. auto json_generation_status = google::protobuf::util::MessageToJsonString(
  134. transformations, &json_string, json_options);
  135. if (json_generation_status.ok()) {
  136. std::ofstream transformations_json_file(filename);
  137. transformations_json_file << json_string;
  138. transformations_json_file.close();
  139. }
  140. }
  141. void ApplyAndCheckFreshIds(
  142. const Transformation& transformation, opt::IRContext* ir_context,
  143. TransformationContext* transformation_context,
  144. const std::unordered_set<uint32_t>& issued_overflow_ids) {
  145. // To ensure that we cover all ToMessage and message-based constructor methods
  146. // in our tests, we turn this into a message and back into a transformation,
  147. // and use the reconstructed transformation in the rest of the function.
  148. auto message = transformation.ToMessage();
  149. auto reconstructed_transformation = Transformation::FromMessage(message);
  150. opt::analysis::DefUseManager::IdToDefMap before_transformation =
  151. ir_context->get_def_use_mgr()->id_to_defs();
  152. reconstructed_transformation->Apply(ir_context, transformation_context);
  153. opt::analysis::DefUseManager::IdToDefMap after_transformation =
  154. ir_context->get_def_use_mgr()->id_to_defs();
  155. std::unordered_set<uint32_t> fresh_ids_for_transformation =
  156. reconstructed_transformation->GetFreshIds();
  157. for (auto& entry : after_transformation) {
  158. uint32_t id = entry.first;
  159. bool introduced_by_transformation_message =
  160. fresh_ids_for_transformation.count(id);
  161. bool introduced_by_overflow_ids = issued_overflow_ids.count(id);
  162. ASSERT_FALSE(introduced_by_transformation_message &&
  163. introduced_by_overflow_ids);
  164. if (before_transformation.count(entry.first)) {
  165. ASSERT_FALSE(introduced_by_transformation_message ||
  166. introduced_by_overflow_ids);
  167. } else {
  168. ASSERT_TRUE(introduced_by_transformation_message ||
  169. introduced_by_overflow_ids);
  170. }
  171. }
  172. }
  173. } // namespace fuzz
  174. } // namespace spvtools