transformation_add_function.cpp 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_add_function.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_message.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationAddFunction::TransformationAddFunction(
  20. protobufs::TransformationAddFunction message)
  21. : message_(std::move(message)) {}
  22. TransformationAddFunction::TransformationAddFunction(
  23. const std::vector<protobufs::Instruction>& instructions) {
  24. for (auto& instruction : instructions) {
  25. *message_.add_instruction() = instruction;
  26. }
  27. message_.set_is_livesafe(false);
  28. }
  29. TransformationAddFunction::TransformationAddFunction(
  30. const std::vector<protobufs::Instruction>& instructions,
  31. uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
  32. const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
  33. uint32_t kill_unreachable_return_value_id,
  34. const std::vector<protobufs::AccessChainClampingInfo>&
  35. access_chain_clampers) {
  36. for (auto& instruction : instructions) {
  37. *message_.add_instruction() = instruction;
  38. }
  39. message_.set_is_livesafe(true);
  40. message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
  41. message_.set_loop_limit_constant_id(loop_limit_constant_id);
  42. for (auto& loop_limiter : loop_limiters) {
  43. *message_.add_loop_limiter_info() = loop_limiter;
  44. }
  45. message_.set_kill_unreachable_return_value_id(
  46. kill_unreachable_return_value_id);
  47. for (auto& access_clamper : access_chain_clampers) {
  48. *message_.add_access_chain_clamping_info() = access_clamper;
  49. }
  50. }
  51. bool TransformationAddFunction::IsApplicable(
  52. opt::IRContext* ir_context,
  53. const TransformationContext& transformation_context) const {
  54. // This transformation may use a lot of ids, all of which need to be fresh
  55. // and distinct. This set tracks them.
  56. std::set<uint32_t> ids_used_by_this_transformation;
  57. // Ensure that all result ids in the new function are fresh and distinct.
  58. for (auto& instruction : message_.instruction()) {
  59. if (instruction.result_id()) {
  60. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  61. instruction.result_id(), ir_context,
  62. &ids_used_by_this_transformation)) {
  63. return false;
  64. }
  65. }
  66. }
  67. if (message_.is_livesafe()) {
  68. // Ensure that all ids provided for making the function livesafe are fresh
  69. // and distinct.
  70. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  71. message_.loop_limiter_variable_id(), ir_context,
  72. &ids_used_by_this_transformation)) {
  73. return false;
  74. }
  75. for (auto& loop_limiter_info : message_.loop_limiter_info()) {
  76. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  77. loop_limiter_info.load_id(), ir_context,
  78. &ids_used_by_this_transformation)) {
  79. return false;
  80. }
  81. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  82. loop_limiter_info.increment_id(), ir_context,
  83. &ids_used_by_this_transformation)) {
  84. return false;
  85. }
  86. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  87. loop_limiter_info.compare_id(), ir_context,
  88. &ids_used_by_this_transformation)) {
  89. return false;
  90. }
  91. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  92. loop_limiter_info.logical_op_id(), ir_context,
  93. &ids_used_by_this_transformation)) {
  94. return false;
  95. }
  96. }
  97. for (auto& access_chain_clamping_info :
  98. message_.access_chain_clamping_info()) {
  99. for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
  100. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  101. pair.first(), ir_context, &ids_used_by_this_transformation)) {
  102. return false;
  103. }
  104. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  105. pair.second(), ir_context, &ids_used_by_this_transformation)) {
  106. return false;
  107. }
  108. }
  109. }
  110. }
  111. // Because checking all the conditions for a function to be valid is a big
  112. // job that the SPIR-V validator can already do, a "try it and see" approach
  113. // is taken here.
  114. // We first clone the current module, so that we can try adding the new
  115. // function without risking wrecking |ir_context|.
  116. auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
  117. // We try to add a function to the cloned module, which may fail if
  118. // |message_.instruction| is not sufficiently well-formed.
  119. if (!TryToAddFunction(cloned_module.get())) {
  120. return false;
  121. }
  122. // Check whether the cloned module is still valid after adding the function.
  123. // If it is not, the transformation is not applicable.
  124. if (!fuzzerutil::IsValid(cloned_module.get(),
  125. transformation_context.GetValidatorOptions(),
  126. fuzzerutil::kSilentMessageConsumer)) {
  127. return false;
  128. }
  129. if (message_.is_livesafe()) {
  130. if (!TryToMakeFunctionLivesafe(cloned_module.get(),
  131. transformation_context)) {
  132. return false;
  133. }
  134. // After making the function livesafe, we check validity of the module
  135. // again. This is because the turning of OpKill, OpUnreachable and OpReturn
  136. // instructions into branches changes control flow graph reachability, which
  137. // has the potential to make the module invalid when it was otherwise valid.
  138. // It is simpler to rely on the validator to guard against this than to
  139. // consider all scenarios when making a function livesafe.
  140. if (!fuzzerutil::IsValid(cloned_module.get(),
  141. transformation_context.GetValidatorOptions(),
  142. fuzzerutil::kSilentMessageConsumer)) {
  143. return false;
  144. }
  145. }
  146. return true;
  147. }
  148. void TransformationAddFunction::Apply(
  149. opt::IRContext* ir_context,
  150. TransformationContext* transformation_context) const {
  151. // Add the function to the module. As the transformation is applicable, this
  152. // should succeed.
  153. bool success = TryToAddFunction(ir_context);
  154. assert(success && "The function should be successfully added.");
  155. (void)(success); // Keep release builds happy (otherwise they may complain
  156. // that |success| is not used).
  157. if (message_.is_livesafe()) {
  158. // Make the function livesafe, which also should succeed.
  159. success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
  160. assert(success && "It should be possible to make the function livesafe.");
  161. (void)(success); // Keep release builds happy.
  162. }
  163. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  164. assert(spv::Op(message_.instruction(0).opcode()) == spv::Op::OpFunction &&
  165. "The first instruction of an 'add function' transformation must be "
  166. "OpFunction.");
  167. if (message_.is_livesafe()) {
  168. // Inform the fact manager that the function is livesafe.
  169. transformation_context->GetFactManager()->AddFactFunctionIsLivesafe(
  170. message_.instruction(0).result_id());
  171. } else {
  172. // Inform the fact manager that all blocks in the function are dead.
  173. for (auto& inst : message_.instruction()) {
  174. if (spv::Op(inst.opcode()) == spv::Op::OpLabel) {
  175. transformation_context->GetFactManager()->AddFactBlockIsDead(
  176. inst.result_id());
  177. }
  178. }
  179. }
  180. // Record the fact that all pointer parameters and variables declared in the
  181. // function should be regarded as having irrelevant values. This allows other
  182. // passes to store arbitrarily to such variables, and to pass them freely as
  183. // parameters to other functions knowing that it is OK if they get
  184. // over-written.
  185. for (auto& instruction : message_.instruction()) {
  186. switch (spv::Op(instruction.opcode())) {
  187. case spv::Op::OpFunctionParameter:
  188. if (ir_context->get_def_use_mgr()
  189. ->GetDef(instruction.result_type_id())
  190. ->opcode() == spv::Op::OpTypePointer) {
  191. transformation_context->GetFactManager()
  192. ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
  193. }
  194. break;
  195. case spv::Op::OpVariable:
  196. transformation_context->GetFactManager()
  197. ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
  198. break;
  199. default:
  200. break;
  201. }
  202. }
  203. }
  204. protobufs::Transformation TransformationAddFunction::ToMessage() const {
  205. protobufs::Transformation result;
  206. *result.mutable_add_function() = message_;
  207. return result;
  208. }
  209. bool TransformationAddFunction::TryToAddFunction(
  210. opt::IRContext* ir_context) const {
  211. // This function returns false if |message_.instruction| was not well-formed
  212. // enough to actually create a function and add it to |ir_context|.
  213. // A function must have at least some instructions.
  214. if (message_.instruction().empty()) {
  215. return false;
  216. }
  217. // A function must start with OpFunction.
  218. auto function_begin = message_.instruction(0);
  219. if (spv::Op(function_begin.opcode()) != spv::Op::OpFunction) {
  220. return false;
  221. }
  222. // Make a function, headed by the OpFunction instruction.
  223. std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
  224. InstructionFromMessage(ir_context, function_begin));
  225. // Keeps track of which instruction protobuf message we are currently
  226. // considering.
  227. uint32_t instruction_index = 1;
  228. const auto num_instructions =
  229. static_cast<uint32_t>(message_.instruction().size());
  230. // Iterate through all function parameter instructions, adding parameters to
  231. // the new function.
  232. while (instruction_index < num_instructions &&
  233. spv::Op(message_.instruction(instruction_index).opcode()) ==
  234. spv::Op::OpFunctionParameter) {
  235. new_function->AddParameter(InstructionFromMessage(
  236. ir_context, message_.instruction(instruction_index)));
  237. instruction_index++;
  238. }
  239. // After the parameters, there needs to be a label.
  240. if (instruction_index == num_instructions ||
  241. spv::Op(message_.instruction(instruction_index).opcode()) !=
  242. spv::Op::OpLabel) {
  243. return false;
  244. }
  245. // Iterate through the instructions block by block until the end of the
  246. // function is reached.
  247. while (instruction_index < num_instructions &&
  248. spv::Op(message_.instruction(instruction_index).opcode()) !=
  249. spv::Op::OpFunctionEnd) {
  250. // Invariant: we should always be at a label instruction at this point.
  251. assert(spv::Op(message_.instruction(instruction_index).opcode()) ==
  252. spv::Op::OpLabel);
  253. // Make a basic block using the label instruction.
  254. std::unique_ptr<opt::BasicBlock> block =
  255. MakeUnique<opt::BasicBlock>(InstructionFromMessage(
  256. ir_context, message_.instruction(instruction_index)));
  257. // Consider successive instructions until we hit another label or the end
  258. // of the function, adding each such instruction to the block.
  259. instruction_index++;
  260. while (instruction_index < num_instructions &&
  261. spv::Op(message_.instruction(instruction_index).opcode()) !=
  262. spv::Op::OpFunctionEnd &&
  263. spv::Op(message_.instruction(instruction_index).opcode()) !=
  264. spv::Op::OpLabel) {
  265. block->AddInstruction(InstructionFromMessage(
  266. ir_context, message_.instruction(instruction_index)));
  267. instruction_index++;
  268. }
  269. // Add the block to the new function.
  270. new_function->AddBasicBlock(std::move(block));
  271. }
  272. // Having considered all the blocks, we should be at the last instruction and
  273. // it needs to be OpFunctionEnd.
  274. if (instruction_index != num_instructions - 1 ||
  275. spv::Op(message_.instruction(instruction_index).opcode()) !=
  276. spv::Op::OpFunctionEnd) {
  277. return false;
  278. }
  279. // Set the function's final instruction, add the function to the module and
  280. // report success.
  281. new_function->SetFunctionEnd(InstructionFromMessage(
  282. ir_context, message_.instruction(instruction_index)));
  283. ir_context->AddFunction(std::move(new_function));
  284. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  285. return true;
  286. }
  287. bool TransformationAddFunction::TryToMakeFunctionLivesafe(
  288. opt::IRContext* ir_context,
  289. const TransformationContext& transformation_context) const {
  290. assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
  291. // Get a pointer to the added function.
  292. opt::Function* added_function = nullptr;
  293. for (auto& function : *ir_context->module()) {
  294. if (function.result_id() == message_.instruction(0).result_id()) {
  295. added_function = &function;
  296. break;
  297. }
  298. }
  299. assert(added_function && "The added function should have been found.");
  300. if (!TryToAddLoopLimiters(ir_context, added_function)) {
  301. // Adding loop limiters did not work; bail out.
  302. return false;
  303. }
  304. // Consider all the instructions in the function, and:
  305. // - attempt to replace OpKill and OpUnreachable with return instructions
  306. // - attempt to clamp access chains to be within bounds
  307. // - check that OpFunctionCall instructions are only to livesafe functions
  308. for (auto& block : *added_function) {
  309. for (auto& inst : block) {
  310. switch (inst.opcode()) {
  311. case spv::Op::OpKill:
  312. case spv::Op::OpUnreachable:
  313. if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
  314. &inst)) {
  315. return false;
  316. }
  317. break;
  318. case spv::Op::OpAccessChain:
  319. case spv::Op::OpInBoundsAccessChain:
  320. if (!TryToClampAccessChainIndices(ir_context, &inst)) {
  321. return false;
  322. }
  323. break;
  324. case spv::Op::OpFunctionCall:
  325. // A livesafe function my only call other livesafe functions.
  326. if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
  327. inst.GetSingleWordInOperand(0))) {
  328. return false;
  329. }
  330. default:
  331. break;
  332. }
  333. }
  334. }
  335. return true;
  336. }
  337. uint32_t TransformationAddFunction::GetBackEdgeBlockId(
  338. opt::IRContext* ir_context, uint32_t loop_header_block_id) {
  339. const auto* loop_header_block =
  340. ir_context->cfg()->block(loop_header_block_id);
  341. assert(loop_header_block && "|loop_header_block_id| is invalid");
  342. for (auto pred : ir_context->cfg()->preds(loop_header_block_id)) {
  343. if (ir_context->GetDominatorAnalysis(loop_header_block->GetParent())
  344. ->Dominates(loop_header_block_id, pred)) {
  345. return pred;
  346. }
  347. }
  348. return 0;
  349. }
  350. bool TransformationAddFunction::TryToAddLoopLimiters(
  351. opt::IRContext* ir_context, opt::Function* added_function) const {
  352. // Collect up all the loop headers so that we can subsequently add loop
  353. // limiting logic.
  354. std::vector<opt::BasicBlock*> loop_headers;
  355. for (auto& block : *added_function) {
  356. if (block.IsLoopHeader()) {
  357. loop_headers.push_back(&block);
  358. }
  359. }
  360. if (loop_headers.empty()) {
  361. // There are no loops, so no need to add any loop limiters.
  362. return true;
  363. }
  364. // Check that the module contains appropriate ingredients for declaring and
  365. // manipulating a loop limiter.
  366. auto loop_limit_constant_id_instr =
  367. ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  368. if (!loop_limit_constant_id_instr ||
  369. loop_limit_constant_id_instr->opcode() != spv::Op::OpConstant) {
  370. // The loop limit constant id instruction must exist and have an
  371. // appropriate opcode.
  372. return false;
  373. }
  374. auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
  375. loop_limit_constant_id_instr->type_id());
  376. if (loop_limit_type->opcode() != spv::Op::OpTypeInt ||
  377. loop_limit_type->GetSingleWordInOperand(0) != 32) {
  378. // The type of the loop limit constant must be 32-bit integer. It
  379. // doesn't actually matter whether the integer is signed or not.
  380. return false;
  381. }
  382. // Find the id of the "unsigned int" type.
  383. opt::analysis::Integer unsigned_int_type(32, false);
  384. uint32_t unsigned_int_type_id =
  385. ir_context->get_type_mgr()->GetId(&unsigned_int_type);
  386. if (!unsigned_int_type_id) {
  387. // Unsigned int is not available; we need this type in order to add loop
  388. // limiters.
  389. return false;
  390. }
  391. auto registered_unsigned_int_type =
  392. ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
  393. // Look for 0 of type unsigned int.
  394. opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
  395. {0});
  396. auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
  397. if (!registered_zero) {
  398. // We need 0 in order to be able to initialize loop limiters.
  399. return false;
  400. }
  401. uint32_t zero_id = ir_context->get_constant_mgr()
  402. ->GetDefiningInstruction(registered_zero)
  403. ->result_id();
  404. // Look for 1 of type unsigned int.
  405. opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
  406. {1});
  407. auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
  408. if (!registered_one) {
  409. // We need 1 in order to be able to increment loop limiters.
  410. return false;
  411. }
  412. uint32_t one_id = ir_context->get_constant_mgr()
  413. ->GetDefiningInstruction(registered_one)
  414. ->result_id();
  415. // Look for pointer-to-unsigned int type.
  416. opt::analysis::Pointer pointer_to_unsigned_int_type(
  417. registered_unsigned_int_type, spv::StorageClass::Function);
  418. uint32_t pointer_to_unsigned_int_type_id =
  419. ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  420. if (!pointer_to_unsigned_int_type_id) {
  421. // We need pointer-to-unsigned int in order to declare the loop limiter
  422. // variable.
  423. return false;
  424. }
  425. // Look for bool type.
  426. opt::analysis::Bool bool_type;
  427. uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  428. if (!bool_type_id) {
  429. // We need bool in order to compare the loop limiter's value with the loop
  430. // limit constant.
  431. return false;
  432. }
  433. // Declare the loop limiter variable at the start of the function's entry
  434. // block, via an instruction of the form:
  435. // %loop_limiter_var = spv::Op::OpVariable %ptr_to_uint Function %zero
  436. added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
  437. ir_context, spv::Op::OpVariable, pointer_to_unsigned_int_type_id,
  438. message_.loop_limiter_variable_id(),
  439. opt::Instruction::OperandList({{SPV_OPERAND_TYPE_STORAGE_CLASS,
  440. {uint32_t(spv::StorageClass::Function)}},
  441. {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  442. // Update the module's id bound since we have added the loop limiter
  443. // variable id.
  444. fuzzerutil::UpdateModuleIdBound(ir_context,
  445. message_.loop_limiter_variable_id());
  446. // Consider each loop in turn.
  447. for (auto loop_header : loop_headers) {
  448. // Look for the loop's back-edge block. This is a predecessor of the loop
  449. // header that is dominated by the loop header.
  450. const auto back_edge_block_id =
  451. GetBackEdgeBlockId(ir_context, loop_header->id());
  452. if (!back_edge_block_id) {
  453. // The loop's back-edge block must be unreachable. This means that the
  454. // loop cannot iterate, so there is no need to make it lifesafe; we can
  455. // move on from this loop.
  456. continue;
  457. }
  458. // If the loop's merge block is unreachable, then there are no constraints
  459. // on where the merge block appears in relation to the blocks of the loop.
  460. // This means we need to be careful when adding a branch from the back-edge
  461. // block to the merge block: the branch might make the loop merge reachable,
  462. // and it might then be dominated by the loop header and possibly by other
  463. // blocks in the loop. Since a block needs to appear before those blocks it
  464. // strictly dominates, this could make the module invalid. To avoid this
  465. // problem we bail out in the case where the loop header does not dominate
  466. // the loop merge.
  467. if (!ir_context->GetDominatorAnalysis(added_function)
  468. ->Dominates(loop_header->id(), loop_header->MergeBlockId())) {
  469. return false;
  470. }
  471. // Go through the sequence of loop limiter infos and find the one
  472. // corresponding to this loop.
  473. bool found = false;
  474. protobufs::LoopLimiterInfo loop_limiter_info;
  475. for (auto& info : message_.loop_limiter_info()) {
  476. if (info.loop_header_id() == loop_header->id()) {
  477. loop_limiter_info = info;
  478. found = true;
  479. break;
  480. }
  481. }
  482. if (!found) {
  483. // We don't have loop limiter info for this loop header.
  484. return false;
  485. }
  486. // The back-edge block either has the form:
  487. //
  488. // (1)
  489. //
  490. // %l = OpLabel
  491. // ... instructions ...
  492. // OpBranch %loop_header
  493. //
  494. // (2)
  495. //
  496. // %l = OpLabel
  497. // ... instructions ...
  498. // OpBranchConditional %c %loop_header %loop_merge
  499. //
  500. // (3)
  501. //
  502. // %l = OpLabel
  503. // ... instructions ...
  504. // OpBranchConditional %c %loop_merge %loop_header
  505. //
  506. // We turn these into the following:
  507. //
  508. // (1)
  509. //
  510. // %l = OpLabel
  511. // ... instructions ...
  512. // %t1 = OpLoad %uint32 %loop_limiter
  513. // %t2 = OpIAdd %uint32 %t1 %one
  514. // OpStore %loop_limiter %t2
  515. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  516. // OpBranchConditional %t3 %loop_merge %loop_header
  517. //
  518. // (2)
  519. //
  520. // %l = OpLabel
  521. // ... instructions ...
  522. // %t1 = OpLoad %uint32 %loop_limiter
  523. // %t2 = OpIAdd %uint32 %t1 %one
  524. // OpStore %loop_limiter %t2
  525. // %t3 = OpULessThan %bool %t1 %loop_limit
  526. // %t4 = OpLogicalAnd %bool %c %t3
  527. // OpBranchConditional %t4 %loop_header %loop_merge
  528. //
  529. // (3)
  530. //
  531. // %l = OpLabel
  532. // ... instructions ...
  533. // %t1 = OpLoad %uint32 %loop_limiter
  534. // %t2 = OpIAdd %uint32 %t1 %one
  535. // OpStore %loop_limiter %t2
  536. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  537. // %t4 = OpLogicalOr %bool %c %t3
  538. // OpBranchConditional %t4 %loop_merge %loop_header
  539. auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
  540. auto back_edge_block_terminator = back_edge_block->terminator();
  541. bool compare_using_greater_than_equal;
  542. if (back_edge_block_terminator->opcode() == spv::Op::OpBranch) {
  543. compare_using_greater_than_equal = true;
  544. } else {
  545. assert(back_edge_block_terminator->opcode() ==
  546. spv::Op::OpBranchConditional);
  547. assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
  548. loop_header->id() &&
  549. back_edge_block_terminator->GetSingleWordInOperand(2) ==
  550. loop_header->MergeBlockId()) ||
  551. (back_edge_block_terminator->GetSingleWordInOperand(2) ==
  552. loop_header->id() &&
  553. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  554. loop_header->MergeBlockId())) &&
  555. "A back edge edge block must branch to"
  556. " either the loop header or merge");
  557. compare_using_greater_than_equal =
  558. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  559. loop_header->MergeBlockId();
  560. }
  561. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  562. // Add a load from the loop limiter variable, of the form:
  563. // %t1 = OpLoad %uint32 %loop_limiter
  564. new_instructions.push_back(MakeUnique<opt::Instruction>(
  565. ir_context, spv::Op::OpLoad, unsigned_int_type_id,
  566. loop_limiter_info.load_id(),
  567. opt::Instruction::OperandList(
  568. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
  569. // Increment the loaded value:
  570. // %t2 = OpIAdd %uint32 %t1 %one
  571. new_instructions.push_back(MakeUnique<opt::Instruction>(
  572. ir_context, spv::Op::OpIAdd, unsigned_int_type_id,
  573. loop_limiter_info.increment_id(),
  574. opt::Instruction::OperandList(
  575. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  576. {SPV_OPERAND_TYPE_ID, {one_id}}})));
  577. // Store the incremented value back to the loop limiter variable:
  578. // OpStore %loop_limiter %t2
  579. new_instructions.push_back(MakeUnique<opt::Instruction>(
  580. ir_context, spv::Op::OpStore, 0, 0,
  581. opt::Instruction::OperandList(
  582. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
  583. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
  584. // Compare the loaded value with the loop limit; either:
  585. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  586. // or
  587. // %t3 = OpULessThan %bool %t1 %loop_limit
  588. new_instructions.push_back(MakeUnique<opt::Instruction>(
  589. ir_context,
  590. compare_using_greater_than_equal ? spv::Op::OpUGreaterThanEqual
  591. : spv::Op::OpULessThan,
  592. bool_type_id, loop_limiter_info.compare_id(),
  593. opt::Instruction::OperandList(
  594. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  595. {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
  596. if (back_edge_block_terminator->opcode() == spv::Op::OpBranchConditional) {
  597. new_instructions.push_back(MakeUnique<opt::Instruction>(
  598. ir_context,
  599. compare_using_greater_than_equal ? spv::Op::OpLogicalOr
  600. : spv::Op::OpLogicalAnd,
  601. bool_type_id, loop_limiter_info.logical_op_id(),
  602. opt::Instruction::OperandList(
  603. {{SPV_OPERAND_TYPE_ID,
  604. {back_edge_block_terminator->GetSingleWordInOperand(0)}},
  605. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
  606. }
  607. // Add the new instructions at the end of the back edge block, before the
  608. // terminator and any loop merge instruction (as the back edge block can
  609. // be the loop header).
  610. if (back_edge_block->GetLoopMergeInst()) {
  611. back_edge_block->GetLoopMergeInst()->InsertBefore(
  612. std::move(new_instructions));
  613. } else {
  614. back_edge_block_terminator->InsertBefore(std::move(new_instructions));
  615. }
  616. if (back_edge_block_terminator->opcode() == spv::Op::OpBranchConditional) {
  617. back_edge_block_terminator->SetInOperand(
  618. 0, {loop_limiter_info.logical_op_id()});
  619. } else {
  620. assert(back_edge_block_terminator->opcode() == spv::Op::OpBranch &&
  621. "Back-edge terminator must be OpBranch or OpBranchConditional");
  622. // Check that, if the merge block starts with OpPhi instructions, suitable
  623. // ids have been provided to give these instructions a value corresponding
  624. // to the new incoming edge from the back edge block.
  625. auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
  626. if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
  627. merge_block,
  628. loop_limiter_info.phi_id())) {
  629. return false;
  630. }
  631. // Augment OpPhi instructions at the loop merge with the given ids.
  632. uint32_t phi_index = 0;
  633. for (auto& inst : *merge_block) {
  634. if (inst.opcode() != spv::Op::OpPhi) {
  635. break;
  636. }
  637. assert(phi_index <
  638. static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
  639. "There should be at least one phi id per OpPhi instruction.");
  640. inst.AddOperand(
  641. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
  642. inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
  643. phi_index++;
  644. }
  645. // Add the new edge, by changing OpBranch to OpBranchConditional.
  646. back_edge_block_terminator->SetOpcode(spv::Op::OpBranchConditional);
  647. back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
  648. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
  649. {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
  650. {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
  651. }
  652. // Update the module's id bound with respect to the various ids that
  653. // have been used for loop limiter manipulation.
  654. fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
  655. fuzzerutil::UpdateModuleIdBound(ir_context,
  656. loop_limiter_info.increment_id());
  657. fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
  658. fuzzerutil::UpdateModuleIdBound(ir_context,
  659. loop_limiter_info.logical_op_id());
  660. }
  661. return true;
  662. }
  663. bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
  664. opt::IRContext* ir_context, opt::Function* added_function,
  665. opt::Instruction* kill_or_unreachable_inst) const {
  666. assert((kill_or_unreachable_inst->opcode() == spv::Op::OpKill ||
  667. kill_or_unreachable_inst->opcode() == spv::Op::OpUnreachable) &&
  668. "Precondition: instruction must be OpKill or OpUnreachable.");
  669. // Get the function's return type.
  670. auto function_return_type_inst =
  671. ir_context->get_def_use_mgr()->GetDef(added_function->type_id());
  672. if (function_return_type_inst->opcode() == spv::Op::OpTypeVoid) {
  673. // The function has void return type, so change this instruction to
  674. // OpReturn.
  675. kill_or_unreachable_inst->SetOpcode(spv::Op::OpReturn);
  676. } else {
  677. // The function has non-void return type, so change this instruction
  678. // to OpReturnValue, using the value id provided with the
  679. // transformation.
  680. // We first check that the id, %id, provided with the transformation
  681. // specifically to turn OpKill and OpUnreachable instructions into
  682. // OpReturnValue %id has the same type as the function's return type.
  683. if (ir_context->get_def_use_mgr()
  684. ->GetDef(message_.kill_unreachable_return_value_id())
  685. ->type_id() != function_return_type_inst->result_id()) {
  686. return false;
  687. }
  688. kill_or_unreachable_inst->SetOpcode(spv::Op::OpReturnValue);
  689. kill_or_unreachable_inst->SetInOperands(
  690. {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
  691. }
  692. return true;
  693. }
  694. bool TransformationAddFunction::TryToClampAccessChainIndices(
  695. opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
  696. assert((access_chain_inst->opcode() == spv::Op::OpAccessChain ||
  697. access_chain_inst->opcode() == spv::Op::OpInBoundsAccessChain) &&
  698. "Precondition: instruction must be OpAccessChain or "
  699. "OpInBoundsAccessChain.");
  700. // Find the AccessChainClampingInfo associated with this access chain.
  701. const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
  702. nullptr;
  703. for (auto& clamping_info : message_.access_chain_clamping_info()) {
  704. if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
  705. access_chain_clamping_info = &clamping_info;
  706. break;
  707. }
  708. }
  709. if (!access_chain_clamping_info) {
  710. // No access chain clamping information was found; the function cannot be
  711. // made livesafe.
  712. return false;
  713. }
  714. // Check that there is a (compare_id, select_id) pair for every
  715. // index associated with the instruction.
  716. if (static_cast<uint32_t>(
  717. access_chain_clamping_info->compare_and_select_ids().size()) !=
  718. access_chain_inst->NumInOperands() - 1) {
  719. return false;
  720. }
  721. // Walk the access chain, clamping each index to be within bounds if it is
  722. // not a constant.
  723. auto base_object = ir_context->get_def_use_mgr()->GetDef(
  724. access_chain_inst->GetSingleWordInOperand(0));
  725. assert(base_object && "The base object must exist.");
  726. auto pointer_type =
  727. ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
  728. assert(pointer_type && pointer_type->opcode() == spv::Op::OpTypePointer &&
  729. "The base object must have pointer type.");
  730. auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
  731. pointer_type->GetSingleWordInOperand(1));
  732. // Consider each index input operand in turn (operand 0 is the base object).
  733. for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
  734. index++) {
  735. // We are going to turn:
  736. //
  737. // %result = OpAccessChain %type %object ... %index ...
  738. //
  739. // into:
  740. //
  741. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  742. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  743. // %result = OpAccessChain %type %object ... %t2 ...
  744. //
  745. // ... unless %index is already a constant.
  746. // Get the bound for the composite being indexed into; e.g. the number of
  747. // columns of matrix or the size of an array.
  748. uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
  749. *should_be_composite_type, ir_context);
  750. // Get the instruction associated with the index and figure out its integer
  751. // type.
  752. const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
  753. auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
  754. auto index_type_inst =
  755. ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
  756. assert(index_type_inst->opcode() == spv::Op::OpTypeInt);
  757. assert(index_type_inst->GetSingleWordInOperand(0) == 32);
  758. opt::analysis::Integer* index_int_type =
  759. ir_context->get_type_mgr()
  760. ->GetType(index_type_inst->result_id())
  761. ->AsInteger();
  762. if (index_inst->opcode() != spv::Op::OpConstant ||
  763. index_inst->GetSingleWordInOperand(0) >= bound) {
  764. // The index is either non-constant or an out-of-bounds constant, so we
  765. // need to clamp it.
  766. assert(should_be_composite_type->opcode() != spv::Op::OpTypeStruct &&
  767. "Access chain indices into structures are required to be "
  768. "constants.");
  769. opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
  770. if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
  771. // We do not have an integer constant whose value is |bound| -1.
  772. return false;
  773. }
  774. opt::analysis::Bool bool_type;
  775. uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  776. if (!bool_type_id) {
  777. // Bool type is not declared; we cannot do a comparison.
  778. return false;
  779. }
  780. uint32_t bound_minus_one_id =
  781. ir_context->get_constant_mgr()
  782. ->GetDefiningInstruction(&bound_minus_one)
  783. ->result_id();
  784. uint32_t compare_id =
  785. access_chain_clamping_info->compare_and_select_ids(index - 1).first();
  786. uint32_t select_id =
  787. access_chain_clamping_info->compare_and_select_ids(index - 1)
  788. .second();
  789. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  790. // Compare the index with the bound via an instruction of the form:
  791. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  792. new_instructions.push_back(MakeUnique<opt::Instruction>(
  793. ir_context, spv::Op::OpULessThanEqual, bool_type_id, compare_id,
  794. opt::Instruction::OperandList(
  795. {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  796. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  797. // Select the index if in-bounds, otherwise one less than the bound:
  798. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  799. new_instructions.push_back(MakeUnique<opt::Instruction>(
  800. ir_context, spv::Op::OpSelect, index_type_inst->result_id(),
  801. select_id,
  802. opt::Instruction::OperandList(
  803. {{SPV_OPERAND_TYPE_ID, {compare_id}},
  804. {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  805. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  806. // Add the new instructions before the access chain
  807. access_chain_inst->InsertBefore(std::move(new_instructions));
  808. // Replace %index with %t2.
  809. access_chain_inst->SetInOperand(index, {select_id});
  810. fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
  811. fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
  812. }
  813. should_be_composite_type =
  814. FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
  815. }
  816. return true;
  817. }
  818. opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
  819. opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
  820. uint32_t index_id) {
  821. uint32_t sub_object_type_id;
  822. switch (composite_type_inst.opcode()) {
  823. case spv::Op::OpTypeArray:
  824. case spv::Op::OpTypeRuntimeArray:
  825. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  826. break;
  827. case spv::Op::OpTypeMatrix:
  828. case spv::Op::OpTypeVector:
  829. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  830. break;
  831. case spv::Op::OpTypeStruct: {
  832. auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
  833. assert(index_inst->opcode() == spv::Op::OpConstant);
  834. assert(ir_context->get_def_use_mgr()
  835. ->GetDef(index_inst->type_id())
  836. ->opcode() == spv::Op::OpTypeInt);
  837. assert(ir_context->get_def_use_mgr()
  838. ->GetDef(index_inst->type_id())
  839. ->GetSingleWordInOperand(0) == 32);
  840. uint32_t index_value = index_inst->GetSingleWordInOperand(0);
  841. sub_object_type_id =
  842. composite_type_inst.GetSingleWordInOperand(index_value);
  843. break;
  844. }
  845. default:
  846. assert(false && "Unknown composite type.");
  847. sub_object_type_id = 0;
  848. break;
  849. }
  850. assert(sub_object_type_id && "No sub-object found.");
  851. return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id);
  852. }
  853. std::unordered_set<uint32_t> TransformationAddFunction::GetFreshIds() const {
  854. std::unordered_set<uint32_t> result;
  855. for (auto& instruction : message_.instruction()) {
  856. result.insert(instruction.result_id());
  857. }
  858. if (message_.is_livesafe()) {
  859. result.insert(message_.loop_limiter_variable_id());
  860. for (auto& loop_limiter_info : message_.loop_limiter_info()) {
  861. result.insert(loop_limiter_info.load_id());
  862. result.insert(loop_limiter_info.increment_id());
  863. result.insert(loop_limiter_info.compare_id());
  864. result.insert(loop_limiter_info.logical_op_id());
  865. }
  866. for (auto& access_chain_clamping_info :
  867. message_.access_chain_clamping_info()) {
  868. for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
  869. result.insert(pair.first());
  870. result.insert(pair.second());
  871. }
  872. }
  873. }
  874. return result;
  875. }
  876. } // namespace fuzz
  877. } // namespace spvtools