force_render_red.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/force_render_red.h"
  15. #include "source/fuzz/fact_manager.h"
  16. #include "source/fuzz/instruction_descriptor.h"
  17. #include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
  18. #include "source/fuzz/transformation_context.h"
  19. #include "source/fuzz/transformation_replace_constant_with_uniform.h"
  20. #include "source/fuzz/uniform_buffer_element_descriptor.h"
  21. #include "source/opt/build_module.h"
  22. #include "source/opt/ir_context.h"
  23. #include "source/opt/types.h"
  24. #include "source/util/make_unique.h"
  25. #include "tools/util/cli_consumer.h"
  26. #include <algorithm>
  27. #include <utility>
  28. namespace spvtools {
  29. namespace fuzz {
  30. namespace {
  31. // Helper method to find the fragment shader entry point, complaining if there
  32. // is no shader or if there is no fragment entry point.
  33. opt::Function* FindFragmentShaderEntryPoint(opt::IRContext* ir_context,
  34. MessageConsumer message_consumer) {
  35. // Check that this is a fragment shader
  36. bool found_capability_shader = false;
  37. for (auto& capability : ir_context->capabilities()) {
  38. assert(capability.opcode() == SpvOpCapability);
  39. if (capability.GetSingleWordInOperand(0) == SpvCapabilityShader) {
  40. found_capability_shader = true;
  41. break;
  42. }
  43. }
  44. if (!found_capability_shader) {
  45. message_consumer(
  46. SPV_MSG_ERROR, nullptr, {},
  47. "Forcing of red rendering requires the Shader capability.");
  48. return nullptr;
  49. }
  50. opt::Instruction* fragment_entry_point = nullptr;
  51. for (auto& entry_point : ir_context->module()->entry_points()) {
  52. if (entry_point.GetSingleWordInOperand(0) == SpvExecutionModelFragment) {
  53. fragment_entry_point = &entry_point;
  54. break;
  55. }
  56. }
  57. if (fragment_entry_point == nullptr) {
  58. message_consumer(SPV_MSG_ERROR, nullptr, {},
  59. "Forcing of red rendering requires an entry point with "
  60. "the Fragment execution model.");
  61. return nullptr;
  62. }
  63. for (auto& function : *ir_context->module()) {
  64. if (function.result_id() ==
  65. fragment_entry_point->GetSingleWordInOperand(1)) {
  66. return &function;
  67. }
  68. }
  69. assert(
  70. false &&
  71. "A valid module must have a function associate with each entry point.");
  72. return nullptr;
  73. }
  74. // Helper method to check that there is a single vec4 output variable and get a
  75. // pointer to it.
  76. opt::Instruction* FindVec4OutputVariable(opt::IRContext* ir_context,
  77. MessageConsumer message_consumer) {
  78. opt::Instruction* output_variable = nullptr;
  79. for (auto& inst : ir_context->types_values()) {
  80. if (inst.opcode() == SpvOpVariable &&
  81. inst.GetSingleWordInOperand(0) == SpvStorageClassOutput) {
  82. if (output_variable != nullptr) {
  83. message_consumer(SPV_MSG_ERROR, nullptr, {},
  84. "Only one output variable can be handled at present; "
  85. "found multiple.");
  86. return nullptr;
  87. }
  88. output_variable = &inst;
  89. // Do not break, as we want to check for multiple output variables.
  90. }
  91. }
  92. if (output_variable == nullptr) {
  93. message_consumer(SPV_MSG_ERROR, nullptr, {},
  94. "No output variable to which to write red was found.");
  95. return nullptr;
  96. }
  97. auto output_variable_base_type = ir_context->get_type_mgr()
  98. ->GetType(output_variable->type_id())
  99. ->AsPointer()
  100. ->pointee_type()
  101. ->AsVector();
  102. if (!output_variable_base_type ||
  103. output_variable_base_type->element_count() != 4 ||
  104. !output_variable_base_type->element_type()->AsFloat()) {
  105. message_consumer(SPV_MSG_ERROR, nullptr, {},
  106. "The output variable must have type vec4.");
  107. return nullptr;
  108. }
  109. return output_variable;
  110. }
  111. // Helper to get the ids of float constants 0.0 and 1.0, creating them if
  112. // necessary.
  113. std::pair<uint32_t, uint32_t> FindOrCreateFloatZeroAndOne(
  114. opt::IRContext* ir_context, opt::analysis::Float* float_type) {
  115. float one = 1.0;
  116. uint32_t one_as_uint;
  117. memcpy(&one_as_uint, &one, sizeof(float));
  118. std::vector<uint32_t> zero_bytes = {0};
  119. std::vector<uint32_t> one_bytes = {one_as_uint};
  120. auto constant_zero = ir_context->get_constant_mgr()->RegisterConstant(
  121. MakeUnique<opt::analysis::FloatConstant>(float_type, zero_bytes));
  122. auto constant_one = ir_context->get_constant_mgr()->RegisterConstant(
  123. MakeUnique<opt::analysis::FloatConstant>(float_type, one_bytes));
  124. auto constant_zero_id = ir_context->get_constant_mgr()
  125. ->GetDefiningInstruction(constant_zero)
  126. ->result_id();
  127. auto constant_one_id = ir_context->get_constant_mgr()
  128. ->GetDefiningInstruction(constant_one)
  129. ->result_id();
  130. return std::pair<uint32_t, uint32_t>(constant_zero_id, constant_one_id);
  131. }
  132. std::unique_ptr<TransformationReplaceConstantWithUniform>
  133. MakeConstantUniformReplacement(opt::IRContext* ir_context,
  134. const FactManager& fact_manager,
  135. uint32_t constant_id,
  136. uint32_t greater_than_instruction,
  137. uint32_t in_operand_index) {
  138. return MakeUnique<TransformationReplaceConstantWithUniform>(
  139. MakeIdUseDescriptor(constant_id,
  140. MakeInstructionDescriptor(greater_than_instruction,
  141. SpvOpFOrdGreaterThan, 0),
  142. in_operand_index),
  143. fact_manager.GetUniformDescriptorsForConstant(ir_context, constant_id)[0],
  144. ir_context->TakeNextId(), ir_context->TakeNextId());
  145. }
  146. } // namespace
  147. bool ForceRenderRed(
  148. const spv_target_env& target_env, spv_validator_options validator_options,
  149. const std::vector<uint32_t>& binary_in,
  150. const spvtools::fuzz::protobufs::FactSequence& initial_facts,
  151. std::vector<uint32_t>* binary_out) {
  152. auto message_consumer = spvtools::utils::CLIMessageConsumer;
  153. spvtools::SpirvTools tools(target_env);
  154. if (!tools.IsValid()) {
  155. message_consumer(SPV_MSG_ERROR, nullptr, {},
  156. "Failed to create SPIRV-Tools interface; stopping.");
  157. return false;
  158. }
  159. // Initial binary should be valid.
  160. if (!tools.Validate(&binary_in[0], binary_in.size(), validator_options)) {
  161. message_consumer(SPV_MSG_ERROR, nullptr, {},
  162. "Initial binary is invalid; stopping.");
  163. return false;
  164. }
  165. // Build the module from the input binary.
  166. std::unique_ptr<opt::IRContext> ir_context = BuildModule(
  167. target_env, message_consumer, binary_in.data(), binary_in.size());
  168. assert(ir_context);
  169. // Set up a fact manager with any given initial facts.
  170. FactManager fact_manager;
  171. for (auto& fact : initial_facts.fact()) {
  172. fact_manager.AddFact(fact, ir_context.get());
  173. }
  174. TransformationContext transformation_context(&fact_manager,
  175. validator_options);
  176. auto entry_point_function =
  177. FindFragmentShaderEntryPoint(ir_context.get(), message_consumer);
  178. auto output_variable =
  179. FindVec4OutputVariable(ir_context.get(), message_consumer);
  180. if (entry_point_function == nullptr || output_variable == nullptr) {
  181. return false;
  182. }
  183. opt::analysis::Float temp_float_type(32);
  184. opt::analysis::Float* float_type = ir_context->get_type_mgr()
  185. ->GetRegisteredType(&temp_float_type)
  186. ->AsFloat();
  187. std::pair<uint32_t, uint32_t> zero_one_float_ids =
  188. FindOrCreateFloatZeroAndOne(ir_context.get(), float_type);
  189. // Make the new exit block
  190. auto new_exit_block_id = ir_context->TakeNextId();
  191. {
  192. auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
  193. new_exit_block_id,
  194. opt::Instruction::OperandList());
  195. auto new_exit_block = MakeUnique<opt::BasicBlock>(std::move(label));
  196. new_exit_block->AddInstruction(MakeUnique<opt::Instruction>(
  197. ir_context.get(), SpvOpReturn, 0, 0, opt::Instruction::OperandList()));
  198. entry_point_function->AddBasicBlock(std::move(new_exit_block));
  199. }
  200. // Make the new entry block
  201. {
  202. auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
  203. ir_context->TakeNextId(),
  204. opt::Instruction::OperandList());
  205. auto new_entry_block = MakeUnique<opt::BasicBlock>(std::move(label));
  206. // Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing
  207. // the colour red.
  208. opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}};
  209. opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}};
  210. opt::Instruction::OperandList op_composite_construct_operands = {
  211. one_float, zero_float, zero_float, one_float};
  212. auto temp_vec4 = opt::analysis::Vector(float_type, 4);
  213. auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4);
  214. auto red = MakeUnique<opt::Instruction>(
  215. ir_context.get(), SpvOpCompositeConstruct, vec4_id,
  216. ir_context->TakeNextId(), op_composite_construct_operands);
  217. auto red_id = red->result_id();
  218. new_entry_block->AddInstruction(std::move(red));
  219. // Make an instruction to store red into the output color.
  220. opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID,
  221. {output_variable->result_id()}};
  222. opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}};
  223. opt::Instruction::OperandList op_store_operands = {variable_to_store_into,
  224. value_to_be_stored};
  225. new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
  226. ir_context.get(), SpvOpStore, 0, 0, op_store_operands));
  227. // We are going to attempt to construct 'false' as an expression of the form
  228. // 'literal1 > literal2'. If we succeed, we will later replace each literal
  229. // with a uniform of the same value - we can only do that replacement once
  230. // we have added the entry block to the module.
  231. std::unique_ptr<TransformationReplaceConstantWithUniform>
  232. first_greater_then_operand_replacement = nullptr;
  233. std::unique_ptr<TransformationReplaceConstantWithUniform>
  234. second_greater_then_operand_replacement = nullptr;
  235. uint32_t id_guaranteed_to_be_false = 0;
  236. opt::analysis::Bool temp_bool_type;
  237. opt::analysis::Bool* registered_bool_type =
  238. ir_context->get_type_mgr()
  239. ->GetRegisteredType(&temp_bool_type)
  240. ->AsBool();
  241. auto float_type_id = ir_context->get_type_mgr()->GetId(float_type);
  242. auto types_for_which_uniforms_are_known =
  243. fact_manager.GetTypesForWhichUniformValuesAreKnown();
  244. // Check whether we have any float uniforms.
  245. if (std::find(types_for_which_uniforms_are_known.begin(),
  246. types_for_which_uniforms_are_known.end(),
  247. float_type_id) != types_for_which_uniforms_are_known.end()) {
  248. // We have at least one float uniform; let's see whether we have at least
  249. // two.
  250. auto available_constants =
  251. fact_manager.GetConstantsAvailableFromUniformsForType(
  252. ir_context.get(), float_type_id);
  253. if (available_constants.size() > 1) {
  254. // Grab the float constants associated with the first two known float
  255. // uniforms.
  256. auto first_constant =
  257. ir_context->get_constant_mgr()
  258. ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
  259. available_constants[0]))
  260. ->AsFloatConstant();
  261. auto second_constant =
  262. ir_context->get_constant_mgr()
  263. ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
  264. available_constants[1]))
  265. ->AsFloatConstant();
  266. // Now work out which of the two constants is larger than the other.
  267. uint32_t larger_constant_index = 0;
  268. uint32_t smaller_constant_index = 0;
  269. if (first_constant->GetFloat() > second_constant->GetFloat()) {
  270. larger_constant_index = 0;
  271. smaller_constant_index = 1;
  272. } else if (first_constant->GetFloat() < second_constant->GetFloat()) {
  273. larger_constant_index = 1;
  274. smaller_constant_index = 0;
  275. }
  276. // Only proceed with these constants if they have turned out to be
  277. // distinct.
  278. if (larger_constant_index != smaller_constant_index) {
  279. // We are in a position to create 'false' as 'literal1 > literal2', so
  280. // reserve an id for this computation; this id will end up being
  281. // guaranteed to be 'false'.
  282. id_guaranteed_to_be_false = ir_context->TakeNextId();
  283. auto smaller_constant = available_constants[smaller_constant_index];
  284. auto larger_constant = available_constants[larger_constant_index];
  285. opt::Instruction::OperandList greater_than_operands = {
  286. {SPV_OPERAND_TYPE_ID, {smaller_constant}},
  287. {SPV_OPERAND_TYPE_ID, {larger_constant}}};
  288. new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
  289. ir_context.get(), SpvOpFOrdGreaterThan,
  290. ir_context->get_type_mgr()->GetId(registered_bool_type),
  291. id_guaranteed_to_be_false, greater_than_operands));
  292. first_greater_then_operand_replacement =
  293. MakeConstantUniformReplacement(ir_context.get(), fact_manager,
  294. smaller_constant,
  295. id_guaranteed_to_be_false, 0);
  296. second_greater_then_operand_replacement =
  297. MakeConstantUniformReplacement(ir_context.get(), fact_manager,
  298. larger_constant,
  299. id_guaranteed_to_be_false, 1);
  300. }
  301. }
  302. }
  303. if (id_guaranteed_to_be_false == 0) {
  304. auto constant_false = ir_context->get_constant_mgr()->RegisterConstant(
  305. MakeUnique<opt::analysis::BoolConstant>(registered_bool_type, false));
  306. id_guaranteed_to_be_false = ir_context->get_constant_mgr()
  307. ->GetDefiningInstruction(constant_false)
  308. ->result_id();
  309. }
  310. opt::Operand false_condition = {SPV_OPERAND_TYPE_ID,
  311. {id_guaranteed_to_be_false}};
  312. opt::Operand then_block = {SPV_OPERAND_TYPE_ID,
  313. {entry_point_function->entry()->id()}};
  314. opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}};
  315. opt::Instruction::OperandList op_branch_conditional_operands = {
  316. false_condition, then_block, else_block};
  317. new_entry_block->AddInstruction(
  318. MakeUnique<opt::Instruction>(ir_context.get(), SpvOpBranchConditional,
  319. 0, 0, op_branch_conditional_operands));
  320. entry_point_function->InsertBasicBlockBefore(
  321. std::move(new_entry_block), entry_point_function->entry().get());
  322. for (auto& replacement : {first_greater_then_operand_replacement.get(),
  323. second_greater_then_operand_replacement.get()}) {
  324. if (replacement) {
  325. assert(replacement->IsApplicable(ir_context.get(),
  326. transformation_context));
  327. replacement->Apply(ir_context.get(), &transformation_context);
  328. }
  329. }
  330. }
  331. // Write out the module as a binary.
  332. ir_context->module()->ToBinary(binary_out, false);
  333. return true;
  334. }
  335. } // namespace fuzz
  336. } // namespace spvtools