transformation_add_function.cpp 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_add_function.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_message.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationAddFunction::TransformationAddFunction(
  20. const spvtools::fuzz::protobufs::TransformationAddFunction& message)
  21. : message_(message) {}
  22. TransformationAddFunction::TransformationAddFunction(
  23. const std::vector<protobufs::Instruction>& instructions) {
  24. for (auto& instruction : instructions) {
  25. *message_.add_instruction() = instruction;
  26. }
  27. message_.set_is_livesafe(false);
  28. }
  29. TransformationAddFunction::TransformationAddFunction(
  30. const std::vector<protobufs::Instruction>& instructions,
  31. uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
  32. const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
  33. uint32_t kill_unreachable_return_value_id,
  34. const std::vector<protobufs::AccessChainClampingInfo>&
  35. access_chain_clampers) {
  36. for (auto& instruction : instructions) {
  37. *message_.add_instruction() = instruction;
  38. }
  39. message_.set_is_livesafe(true);
  40. message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
  41. message_.set_loop_limit_constant_id(loop_limit_constant_id);
  42. for (auto& loop_limiter : loop_limiters) {
  43. *message_.add_loop_limiter_info() = loop_limiter;
  44. }
  45. message_.set_kill_unreachable_return_value_id(
  46. kill_unreachable_return_value_id);
  47. for (auto& access_clamper : access_chain_clampers) {
  48. *message_.add_access_chain_clamping_info() = access_clamper;
  49. }
  50. }
  51. bool TransformationAddFunction::IsApplicable(
  52. opt::IRContext* ir_context,
  53. const TransformationContext& transformation_context) const {
  54. // This transformation may use a lot of ids, all of which need to be fresh
  55. // and distinct. This set tracks them.
  56. std::set<uint32_t> ids_used_by_this_transformation;
  57. // Ensure that all result ids in the new function are fresh and distinct.
  58. for (auto& instruction : message_.instruction()) {
  59. if (instruction.result_id()) {
  60. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  61. instruction.result_id(), ir_context,
  62. &ids_used_by_this_transformation)) {
  63. return false;
  64. }
  65. }
  66. }
  67. if (message_.is_livesafe()) {
  68. // Ensure that all ids provided for making the function livesafe are fresh
  69. // and distinct.
  70. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  71. message_.loop_limiter_variable_id(), ir_context,
  72. &ids_used_by_this_transformation)) {
  73. return false;
  74. }
  75. for (auto& loop_limiter_info : message_.loop_limiter_info()) {
  76. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  77. loop_limiter_info.load_id(), ir_context,
  78. &ids_used_by_this_transformation)) {
  79. return false;
  80. }
  81. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  82. loop_limiter_info.increment_id(), ir_context,
  83. &ids_used_by_this_transformation)) {
  84. return false;
  85. }
  86. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  87. loop_limiter_info.compare_id(), ir_context,
  88. &ids_used_by_this_transformation)) {
  89. return false;
  90. }
  91. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  92. loop_limiter_info.logical_op_id(), ir_context,
  93. &ids_used_by_this_transformation)) {
  94. return false;
  95. }
  96. }
  97. for (auto& access_chain_clamping_info :
  98. message_.access_chain_clamping_info()) {
  99. for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
  100. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  101. pair.first(), ir_context, &ids_used_by_this_transformation)) {
  102. return false;
  103. }
  104. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  105. pair.second(), ir_context, &ids_used_by_this_transformation)) {
  106. return false;
  107. }
  108. }
  109. }
  110. }
  111. // Because checking all the conditions for a function to be valid is a big
  112. // job that the SPIR-V validator can already do, a "try it and see" approach
  113. // is taken here.
  114. // We first clone the current module, so that we can try adding the new
  115. // function without risking wrecking |ir_context|.
  116. auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
  117. // We try to add a function to the cloned module, which may fail if
  118. // |message_.instruction| is not sufficiently well-formed.
  119. if (!TryToAddFunction(cloned_module.get())) {
  120. return false;
  121. }
  122. // Check whether the cloned module is still valid after adding the function.
  123. // If it is not, the transformation is not applicable.
  124. if (!fuzzerutil::IsValid(cloned_module.get(),
  125. transformation_context.GetValidatorOptions(),
  126. fuzzerutil::kSilentMessageConsumer)) {
  127. return false;
  128. }
  129. if (message_.is_livesafe()) {
  130. if (!TryToMakeFunctionLivesafe(cloned_module.get(),
  131. transformation_context)) {
  132. return false;
  133. }
  134. // After making the function livesafe, we check validity of the module
  135. // again. This is because the turning of OpKill, OpUnreachable and OpReturn
  136. // instructions into branches changes control flow graph reachability, which
  137. // has the potential to make the module invalid when it was otherwise valid.
  138. // It is simpler to rely on the validator to guard against this than to
  139. // consider all scenarios when making a function livesafe.
  140. if (!fuzzerutil::IsValid(cloned_module.get(),
  141. transformation_context.GetValidatorOptions(),
  142. fuzzerutil::kSilentMessageConsumer)) {
  143. return false;
  144. }
  145. }
  146. return true;
  147. }
  148. void TransformationAddFunction::Apply(
  149. opt::IRContext* ir_context,
  150. TransformationContext* transformation_context) const {
  151. // Add the function to the module. As the transformation is applicable, this
  152. // should succeed.
  153. bool success = TryToAddFunction(ir_context);
  154. assert(success && "The function should be successfully added.");
  155. (void)(success); // Keep release builds happy (otherwise they may complain
  156. // that |success| is not used).
  157. if (message_.is_livesafe()) {
  158. // Make the function livesafe, which also should succeed.
  159. success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
  160. assert(success && "It should be possible to make the function livesafe.");
  161. (void)(success); // Keep release builds happy.
  162. }
  163. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  164. assert(message_.instruction(0).opcode() == SpvOpFunction &&
  165. "The first instruction of an 'add function' transformation must be "
  166. "OpFunction.");
  167. if (message_.is_livesafe()) {
  168. // Inform the fact manager that the function is livesafe.
  169. transformation_context->GetFactManager()->AddFactFunctionIsLivesafe(
  170. message_.instruction(0).result_id());
  171. } else {
  172. // Inform the fact manager that all blocks in the function are dead.
  173. for (auto& inst : message_.instruction()) {
  174. if (inst.opcode() == SpvOpLabel) {
  175. transformation_context->GetFactManager()->AddFactBlockIsDead(
  176. inst.result_id());
  177. }
  178. }
  179. }
  180. // Record the fact that all pointer parameters and variables declared in the
  181. // function should be regarded as having irrelevant values. This allows other
  182. // passes to store arbitrarily to such variables, and to pass them freely as
  183. // parameters to other functions knowing that it is OK if they get
  184. // over-written.
  185. for (auto& instruction : message_.instruction()) {
  186. switch (instruction.opcode()) {
  187. case SpvOpFunctionParameter:
  188. if (ir_context->get_def_use_mgr()
  189. ->GetDef(instruction.result_type_id())
  190. ->opcode() == SpvOpTypePointer) {
  191. transformation_context->GetFactManager()
  192. ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
  193. }
  194. break;
  195. case SpvOpVariable:
  196. transformation_context->GetFactManager()
  197. ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
  198. break;
  199. default:
  200. break;
  201. }
  202. }
  203. }
  204. protobufs::Transformation TransformationAddFunction::ToMessage() const {
  205. protobufs::Transformation result;
  206. *result.mutable_add_function() = message_;
  207. return result;
  208. }
  209. bool TransformationAddFunction::TryToAddFunction(
  210. opt::IRContext* ir_context) const {
  211. // This function returns false if |message_.instruction| was not well-formed
  212. // enough to actually create a function and add it to |ir_context|.
  213. // A function must have at least some instructions.
  214. if (message_.instruction().empty()) {
  215. return false;
  216. }
  217. // A function must start with OpFunction.
  218. auto function_begin = message_.instruction(0);
  219. if (function_begin.opcode() != SpvOpFunction) {
  220. return false;
  221. }
  222. // Make a function, headed by the OpFunction instruction.
  223. std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
  224. InstructionFromMessage(ir_context, function_begin));
  225. // Keeps track of which instruction protobuf message we are currently
  226. // considering.
  227. uint32_t instruction_index = 1;
  228. const auto num_instructions =
  229. static_cast<uint32_t>(message_.instruction().size());
  230. // Iterate through all function parameter instructions, adding parameters to
  231. // the new function.
  232. while (instruction_index < num_instructions &&
  233. message_.instruction(instruction_index).opcode() ==
  234. SpvOpFunctionParameter) {
  235. new_function->AddParameter(InstructionFromMessage(
  236. ir_context, message_.instruction(instruction_index)));
  237. instruction_index++;
  238. }
  239. // After the parameters, there needs to be a label.
  240. if (instruction_index == num_instructions ||
  241. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  242. return false;
  243. }
  244. // Iterate through the instructions block by block until the end of the
  245. // function is reached.
  246. while (instruction_index < num_instructions &&
  247. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  248. // Invariant: we should always be at a label instruction at this point.
  249. assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
  250. // Make a basic block using the label instruction.
  251. std::unique_ptr<opt::BasicBlock> block =
  252. MakeUnique<opt::BasicBlock>(InstructionFromMessage(
  253. ir_context, message_.instruction(instruction_index)));
  254. // Consider successive instructions until we hit another label or the end
  255. // of the function, adding each such instruction to the block.
  256. instruction_index++;
  257. while (instruction_index < num_instructions &&
  258. message_.instruction(instruction_index).opcode() !=
  259. SpvOpFunctionEnd &&
  260. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  261. block->AddInstruction(InstructionFromMessage(
  262. ir_context, message_.instruction(instruction_index)));
  263. instruction_index++;
  264. }
  265. // Add the block to the new function.
  266. new_function->AddBasicBlock(std::move(block));
  267. }
  268. // Having considered all the blocks, we should be at the last instruction and
  269. // it needs to be OpFunctionEnd.
  270. if (instruction_index != num_instructions - 1 ||
  271. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  272. return false;
  273. }
  274. // Set the function's final instruction, add the function to the module and
  275. // report success.
  276. new_function->SetFunctionEnd(InstructionFromMessage(
  277. ir_context, message_.instruction(instruction_index)));
  278. ir_context->AddFunction(std::move(new_function));
  279. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  280. return true;
  281. }
  282. bool TransformationAddFunction::TryToMakeFunctionLivesafe(
  283. opt::IRContext* ir_context,
  284. const TransformationContext& transformation_context) const {
  285. assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
  286. // Get a pointer to the added function.
  287. opt::Function* added_function = nullptr;
  288. for (auto& function : *ir_context->module()) {
  289. if (function.result_id() == message_.instruction(0).result_id()) {
  290. added_function = &function;
  291. break;
  292. }
  293. }
  294. assert(added_function && "The added function should have been found.");
  295. if (!TryToAddLoopLimiters(ir_context, added_function)) {
  296. // Adding loop limiters did not work; bail out.
  297. return false;
  298. }
  299. // Consider all the instructions in the function, and:
  300. // - attempt to replace OpKill and OpUnreachable with return instructions
  301. // - attempt to clamp access chains to be within bounds
  302. // - check that OpFunctionCall instructions are only to livesafe functions
  303. for (auto& block : *added_function) {
  304. for (auto& inst : block) {
  305. switch (inst.opcode()) {
  306. case SpvOpKill:
  307. case SpvOpUnreachable:
  308. if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
  309. &inst)) {
  310. return false;
  311. }
  312. break;
  313. case SpvOpAccessChain:
  314. case SpvOpInBoundsAccessChain:
  315. if (!TryToClampAccessChainIndices(ir_context, &inst)) {
  316. return false;
  317. }
  318. break;
  319. case SpvOpFunctionCall:
  320. // A livesafe function my only call other livesafe functions.
  321. if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
  322. inst.GetSingleWordInOperand(0))) {
  323. return false;
  324. }
  325. default:
  326. break;
  327. }
  328. }
  329. }
  330. return true;
  331. }
  332. uint32_t TransformationAddFunction::GetBackEdgeBlockId(
  333. opt::IRContext* ir_context, uint32_t loop_header_block_id) {
  334. const auto* loop_header_block =
  335. ir_context->cfg()->block(loop_header_block_id);
  336. assert(loop_header_block && "|loop_header_block_id| is invalid");
  337. for (auto pred : ir_context->cfg()->preds(loop_header_block_id)) {
  338. if (ir_context->GetDominatorAnalysis(loop_header_block->GetParent())
  339. ->Dominates(loop_header_block_id, pred)) {
  340. return pred;
  341. }
  342. }
  343. return 0;
  344. }
  345. bool TransformationAddFunction::TryToAddLoopLimiters(
  346. opt::IRContext* ir_context, opt::Function* added_function) const {
  347. // Collect up all the loop headers so that we can subsequently add loop
  348. // limiting logic.
  349. std::vector<opt::BasicBlock*> loop_headers;
  350. for (auto& block : *added_function) {
  351. if (block.IsLoopHeader()) {
  352. loop_headers.push_back(&block);
  353. }
  354. }
  355. if (loop_headers.empty()) {
  356. // There are no loops, so no need to add any loop limiters.
  357. return true;
  358. }
  359. // Check that the module contains appropriate ingredients for declaring and
  360. // manipulating a loop limiter.
  361. auto loop_limit_constant_id_instr =
  362. ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  363. if (!loop_limit_constant_id_instr ||
  364. loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
  365. // The loop limit constant id instruction must exist and have an
  366. // appropriate opcode.
  367. return false;
  368. }
  369. auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
  370. loop_limit_constant_id_instr->type_id());
  371. if (loop_limit_type->opcode() != SpvOpTypeInt ||
  372. loop_limit_type->GetSingleWordInOperand(0) != 32) {
  373. // The type of the loop limit constant must be 32-bit integer. It
  374. // doesn't actually matter whether the integer is signed or not.
  375. return false;
  376. }
  377. // Find the id of the "unsigned int" type.
  378. opt::analysis::Integer unsigned_int_type(32, false);
  379. uint32_t unsigned_int_type_id =
  380. ir_context->get_type_mgr()->GetId(&unsigned_int_type);
  381. if (!unsigned_int_type_id) {
  382. // Unsigned int is not available; we need this type in order to add loop
  383. // limiters.
  384. return false;
  385. }
  386. auto registered_unsigned_int_type =
  387. ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
  388. // Look for 0 of type unsigned int.
  389. opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
  390. {0});
  391. auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
  392. if (!registered_zero) {
  393. // We need 0 in order to be able to initialize loop limiters.
  394. return false;
  395. }
  396. uint32_t zero_id = ir_context->get_constant_mgr()
  397. ->GetDefiningInstruction(registered_zero)
  398. ->result_id();
  399. // Look for 1 of type unsigned int.
  400. opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
  401. {1});
  402. auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
  403. if (!registered_one) {
  404. // We need 1 in order to be able to increment loop limiters.
  405. return false;
  406. }
  407. uint32_t one_id = ir_context->get_constant_mgr()
  408. ->GetDefiningInstruction(registered_one)
  409. ->result_id();
  410. // Look for pointer-to-unsigned int type.
  411. opt::analysis::Pointer pointer_to_unsigned_int_type(
  412. registered_unsigned_int_type, SpvStorageClassFunction);
  413. uint32_t pointer_to_unsigned_int_type_id =
  414. ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  415. if (!pointer_to_unsigned_int_type_id) {
  416. // We need pointer-to-unsigned int in order to declare the loop limiter
  417. // variable.
  418. return false;
  419. }
  420. // Look for bool type.
  421. opt::analysis::Bool bool_type;
  422. uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  423. if (!bool_type_id) {
  424. // We need bool in order to compare the loop limiter's value with the loop
  425. // limit constant.
  426. return false;
  427. }
  428. // Declare the loop limiter variable at the start of the function's entry
  429. // block, via an instruction of the form:
  430. // %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
  431. added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
  432. ir_context, SpvOpVariable, pointer_to_unsigned_int_type_id,
  433. message_.loop_limiter_variable_id(),
  434. opt::Instruction::OperandList(
  435. {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  436. {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  437. // Update the module's id bound since we have added the loop limiter
  438. // variable id.
  439. fuzzerutil::UpdateModuleIdBound(ir_context,
  440. message_.loop_limiter_variable_id());
  441. // Consider each loop in turn.
  442. for (auto loop_header : loop_headers) {
  443. // Look for the loop's back-edge block. This is a predecessor of the loop
  444. // header that is dominated by the loop header.
  445. const auto back_edge_block_id =
  446. GetBackEdgeBlockId(ir_context, loop_header->id());
  447. if (!back_edge_block_id) {
  448. // The loop's back-edge block must be unreachable. This means that the
  449. // loop cannot iterate, so there is no need to make it lifesafe; we can
  450. // move on from this loop.
  451. continue;
  452. }
  453. // If the loop's merge block is unreachable, then there are no constraints
  454. // on where the merge block appears in relation to the blocks of the loop.
  455. // This means we need to be careful when adding a branch from the back-edge
  456. // block to the merge block: the branch might make the loop merge reachable,
  457. // and it might then be dominated by the loop header and possibly by other
  458. // blocks in the loop. Since a block needs to appear before those blocks it
  459. // strictly dominates, this could make the module invalid. To avoid this
  460. // problem we bail out in the case where the loop header does not dominate
  461. // the loop merge.
  462. if (!ir_context->GetDominatorAnalysis(added_function)
  463. ->Dominates(loop_header->id(), loop_header->MergeBlockId())) {
  464. return false;
  465. }
  466. // Go through the sequence of loop limiter infos and find the one
  467. // corresponding to this loop.
  468. bool found = false;
  469. protobufs::LoopLimiterInfo loop_limiter_info;
  470. for (auto& info : message_.loop_limiter_info()) {
  471. if (info.loop_header_id() == loop_header->id()) {
  472. loop_limiter_info = info;
  473. found = true;
  474. break;
  475. }
  476. }
  477. if (!found) {
  478. // We don't have loop limiter info for this loop header.
  479. return false;
  480. }
  481. // The back-edge block either has the form:
  482. //
  483. // (1)
  484. //
  485. // %l = OpLabel
  486. // ... instructions ...
  487. // OpBranch %loop_header
  488. //
  489. // (2)
  490. //
  491. // %l = OpLabel
  492. // ... instructions ...
  493. // OpBranchConditional %c %loop_header %loop_merge
  494. //
  495. // (3)
  496. //
  497. // %l = OpLabel
  498. // ... instructions ...
  499. // OpBranchConditional %c %loop_merge %loop_header
  500. //
  501. // We turn these into the following:
  502. //
  503. // (1)
  504. //
  505. // %l = OpLabel
  506. // ... instructions ...
  507. // %t1 = OpLoad %uint32 %loop_limiter
  508. // %t2 = OpIAdd %uint32 %t1 %one
  509. // OpStore %loop_limiter %t2
  510. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  511. // OpBranchConditional %t3 %loop_merge %loop_header
  512. //
  513. // (2)
  514. //
  515. // %l = OpLabel
  516. // ... instructions ...
  517. // %t1 = OpLoad %uint32 %loop_limiter
  518. // %t2 = OpIAdd %uint32 %t1 %one
  519. // OpStore %loop_limiter %t2
  520. // %t3 = OpULessThan %bool %t1 %loop_limit
  521. // %t4 = OpLogicalAnd %bool %c %t3
  522. // OpBranchConditional %t4 %loop_header %loop_merge
  523. //
  524. // (3)
  525. //
  526. // %l = OpLabel
  527. // ... instructions ...
  528. // %t1 = OpLoad %uint32 %loop_limiter
  529. // %t2 = OpIAdd %uint32 %t1 %one
  530. // OpStore %loop_limiter %t2
  531. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  532. // %t4 = OpLogicalOr %bool %c %t3
  533. // OpBranchConditional %t4 %loop_merge %loop_header
  534. auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
  535. auto back_edge_block_terminator = back_edge_block->terminator();
  536. bool compare_using_greater_than_equal;
  537. if (back_edge_block_terminator->opcode() == SpvOpBranch) {
  538. compare_using_greater_than_equal = true;
  539. } else {
  540. assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
  541. assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
  542. loop_header->id() &&
  543. back_edge_block_terminator->GetSingleWordInOperand(2) ==
  544. loop_header->MergeBlockId()) ||
  545. (back_edge_block_terminator->GetSingleWordInOperand(2) ==
  546. loop_header->id() &&
  547. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  548. loop_header->MergeBlockId())) &&
  549. "A back edge edge block must branch to"
  550. " either the loop header or merge");
  551. compare_using_greater_than_equal =
  552. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  553. loop_header->MergeBlockId();
  554. }
  555. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  556. // Add a load from the loop limiter variable, of the form:
  557. // %t1 = OpLoad %uint32 %loop_limiter
  558. new_instructions.push_back(MakeUnique<opt::Instruction>(
  559. ir_context, SpvOpLoad, unsigned_int_type_id,
  560. loop_limiter_info.load_id(),
  561. opt::Instruction::OperandList(
  562. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
  563. // Increment the loaded value:
  564. // %t2 = OpIAdd %uint32 %t1 %one
  565. new_instructions.push_back(MakeUnique<opt::Instruction>(
  566. ir_context, SpvOpIAdd, unsigned_int_type_id,
  567. loop_limiter_info.increment_id(),
  568. opt::Instruction::OperandList(
  569. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  570. {SPV_OPERAND_TYPE_ID, {one_id}}})));
  571. // Store the incremented value back to the loop limiter variable:
  572. // OpStore %loop_limiter %t2
  573. new_instructions.push_back(MakeUnique<opt::Instruction>(
  574. ir_context, SpvOpStore, 0, 0,
  575. opt::Instruction::OperandList(
  576. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
  577. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
  578. // Compare the loaded value with the loop limit; either:
  579. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  580. // or
  581. // %t3 = OpULessThan %bool %t1 %loop_limit
  582. new_instructions.push_back(MakeUnique<opt::Instruction>(
  583. ir_context,
  584. compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
  585. : SpvOpULessThan,
  586. bool_type_id, loop_limiter_info.compare_id(),
  587. opt::Instruction::OperandList(
  588. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  589. {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
  590. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  591. new_instructions.push_back(MakeUnique<opt::Instruction>(
  592. ir_context,
  593. compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
  594. bool_type_id, loop_limiter_info.logical_op_id(),
  595. opt::Instruction::OperandList(
  596. {{SPV_OPERAND_TYPE_ID,
  597. {back_edge_block_terminator->GetSingleWordInOperand(0)}},
  598. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
  599. }
  600. // Add the new instructions at the end of the back edge block, before the
  601. // terminator and any loop merge instruction (as the back edge block can
  602. // be the loop header).
  603. if (back_edge_block->GetLoopMergeInst()) {
  604. back_edge_block->GetLoopMergeInst()->InsertBefore(
  605. std::move(new_instructions));
  606. } else {
  607. back_edge_block_terminator->InsertBefore(std::move(new_instructions));
  608. }
  609. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  610. back_edge_block_terminator->SetInOperand(
  611. 0, {loop_limiter_info.logical_op_id()});
  612. } else {
  613. assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
  614. "Back-edge terminator must be OpBranch or OpBranchConditional");
  615. // Check that, if the merge block starts with OpPhi instructions, suitable
  616. // ids have been provided to give these instructions a value corresponding
  617. // to the new incoming edge from the back edge block.
  618. auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
  619. if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
  620. merge_block,
  621. loop_limiter_info.phi_id())) {
  622. return false;
  623. }
  624. // Augment OpPhi instructions at the loop merge with the given ids.
  625. uint32_t phi_index = 0;
  626. for (auto& inst : *merge_block) {
  627. if (inst.opcode() != SpvOpPhi) {
  628. break;
  629. }
  630. assert(phi_index <
  631. static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
  632. "There should be at least one phi id per OpPhi instruction.");
  633. inst.AddOperand(
  634. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
  635. inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
  636. phi_index++;
  637. }
  638. // Add the new edge, by changing OpBranch to OpBranchConditional.
  639. back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
  640. back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
  641. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
  642. {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
  643. {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
  644. }
  645. // Update the module's id bound with respect to the various ids that
  646. // have been used for loop limiter manipulation.
  647. fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
  648. fuzzerutil::UpdateModuleIdBound(ir_context,
  649. loop_limiter_info.increment_id());
  650. fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
  651. fuzzerutil::UpdateModuleIdBound(ir_context,
  652. loop_limiter_info.logical_op_id());
  653. }
  654. return true;
  655. }
  656. bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
  657. opt::IRContext* ir_context, opt::Function* added_function,
  658. opt::Instruction* kill_or_unreachable_inst) const {
  659. assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
  660. kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
  661. "Precondition: instruction must be OpKill or OpUnreachable.");
  662. // Get the function's return type.
  663. auto function_return_type_inst =
  664. ir_context->get_def_use_mgr()->GetDef(added_function->type_id());
  665. if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
  666. // The function has void return type, so change this instruction to
  667. // OpReturn.
  668. kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
  669. } else {
  670. // The function has non-void return type, so change this instruction
  671. // to OpReturnValue, using the value id provided with the
  672. // transformation.
  673. // We first check that the id, %id, provided with the transformation
  674. // specifically to turn OpKill and OpUnreachable instructions into
  675. // OpReturnValue %id has the same type as the function's return type.
  676. if (ir_context->get_def_use_mgr()
  677. ->GetDef(message_.kill_unreachable_return_value_id())
  678. ->type_id() != function_return_type_inst->result_id()) {
  679. return false;
  680. }
  681. kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
  682. kill_or_unreachable_inst->SetInOperands(
  683. {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
  684. }
  685. return true;
  686. }
  687. bool TransformationAddFunction::TryToClampAccessChainIndices(
  688. opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
  689. assert((access_chain_inst->opcode() == SpvOpAccessChain ||
  690. access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
  691. "Precondition: instruction must be OpAccessChain or "
  692. "OpInBoundsAccessChain.");
  693. // Find the AccessChainClampingInfo associated with this access chain.
  694. const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
  695. nullptr;
  696. for (auto& clamping_info : message_.access_chain_clamping_info()) {
  697. if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
  698. access_chain_clamping_info = &clamping_info;
  699. break;
  700. }
  701. }
  702. if (!access_chain_clamping_info) {
  703. // No access chain clamping information was found; the function cannot be
  704. // made livesafe.
  705. return false;
  706. }
  707. // Check that there is a (compare_id, select_id) pair for every
  708. // index associated with the instruction.
  709. if (static_cast<uint32_t>(
  710. access_chain_clamping_info->compare_and_select_ids().size()) !=
  711. access_chain_inst->NumInOperands() - 1) {
  712. return false;
  713. }
  714. // Walk the access chain, clamping each index to be within bounds if it is
  715. // not a constant.
  716. auto base_object = ir_context->get_def_use_mgr()->GetDef(
  717. access_chain_inst->GetSingleWordInOperand(0));
  718. assert(base_object && "The base object must exist.");
  719. auto pointer_type =
  720. ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
  721. assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
  722. "The base object must have pointer type.");
  723. auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
  724. pointer_type->GetSingleWordInOperand(1));
  725. // Consider each index input operand in turn (operand 0 is the base object).
  726. for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
  727. index++) {
  728. // We are going to turn:
  729. //
  730. // %result = OpAccessChain %type %object ... %index ...
  731. //
  732. // into:
  733. //
  734. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  735. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  736. // %result = OpAccessChain %type %object ... %t2 ...
  737. //
  738. // ... unless %index is already a constant.
  739. // Get the bound for the composite being indexed into; e.g. the number of
  740. // columns of matrix or the size of an array.
  741. uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
  742. *should_be_composite_type, ir_context);
  743. // Get the instruction associated with the index and figure out its integer
  744. // type.
  745. const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
  746. auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
  747. auto index_type_inst =
  748. ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
  749. assert(index_type_inst->opcode() == SpvOpTypeInt);
  750. assert(index_type_inst->GetSingleWordInOperand(0) == 32);
  751. opt::analysis::Integer* index_int_type =
  752. ir_context->get_type_mgr()
  753. ->GetType(index_type_inst->result_id())
  754. ->AsInteger();
  755. if (index_inst->opcode() != SpvOpConstant ||
  756. index_inst->GetSingleWordInOperand(0) >= bound) {
  757. // The index is either non-constant or an out-of-bounds constant, so we
  758. // need to clamp it.
  759. assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
  760. "Access chain indices into structures are required to be "
  761. "constants.");
  762. opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
  763. if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
  764. // We do not have an integer constant whose value is |bound| -1.
  765. return false;
  766. }
  767. opt::analysis::Bool bool_type;
  768. uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  769. if (!bool_type_id) {
  770. // Bool type is not declared; we cannot do a comparison.
  771. return false;
  772. }
  773. uint32_t bound_minus_one_id =
  774. ir_context->get_constant_mgr()
  775. ->GetDefiningInstruction(&bound_minus_one)
  776. ->result_id();
  777. uint32_t compare_id =
  778. access_chain_clamping_info->compare_and_select_ids(index - 1).first();
  779. uint32_t select_id =
  780. access_chain_clamping_info->compare_and_select_ids(index - 1)
  781. .second();
  782. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  783. // Compare the index with the bound via an instruction of the form:
  784. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  785. new_instructions.push_back(MakeUnique<opt::Instruction>(
  786. ir_context, SpvOpULessThanEqual, bool_type_id, compare_id,
  787. opt::Instruction::OperandList(
  788. {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  789. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  790. // Select the index if in-bounds, otherwise one less than the bound:
  791. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  792. new_instructions.push_back(MakeUnique<opt::Instruction>(
  793. ir_context, SpvOpSelect, index_type_inst->result_id(), select_id,
  794. opt::Instruction::OperandList(
  795. {{SPV_OPERAND_TYPE_ID, {compare_id}},
  796. {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  797. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  798. // Add the new instructions before the access chain
  799. access_chain_inst->InsertBefore(std::move(new_instructions));
  800. // Replace %index with %t2.
  801. access_chain_inst->SetInOperand(index, {select_id});
  802. fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
  803. fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
  804. }
  805. should_be_composite_type =
  806. FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
  807. }
  808. return true;
  809. }
  810. opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
  811. opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
  812. uint32_t index_id) {
  813. uint32_t sub_object_type_id;
  814. switch (composite_type_inst.opcode()) {
  815. case SpvOpTypeArray:
  816. case SpvOpTypeRuntimeArray:
  817. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  818. break;
  819. case SpvOpTypeMatrix:
  820. case SpvOpTypeVector:
  821. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  822. break;
  823. case SpvOpTypeStruct: {
  824. auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
  825. assert(index_inst->opcode() == SpvOpConstant);
  826. assert(ir_context->get_def_use_mgr()
  827. ->GetDef(index_inst->type_id())
  828. ->opcode() == SpvOpTypeInt);
  829. assert(ir_context->get_def_use_mgr()
  830. ->GetDef(index_inst->type_id())
  831. ->GetSingleWordInOperand(0) == 32);
  832. uint32_t index_value = index_inst->GetSingleWordInOperand(0);
  833. sub_object_type_id =
  834. composite_type_inst.GetSingleWordInOperand(index_value);
  835. break;
  836. }
  837. default:
  838. assert(false && "Unknown composite type.");
  839. sub_object_type_id = 0;
  840. break;
  841. }
  842. assert(sub_object_type_id && "No sub-object found.");
  843. return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id);
  844. }
  845. std::unordered_set<uint32_t> TransformationAddFunction::GetFreshIds() const {
  846. std::unordered_set<uint32_t> result;
  847. for (auto& instruction : message_.instruction()) {
  848. result.insert(instruction.result_id());
  849. }
  850. if (message_.is_livesafe()) {
  851. result.insert(message_.loop_limiter_variable_id());
  852. for (auto& loop_limiter_info : message_.loop_limiter_info()) {
  853. result.insert(loop_limiter_info.load_id());
  854. result.insert(loop_limiter_info.increment_id());
  855. result.insert(loop_limiter_info.compare_id());
  856. result.insert(loop_limiter_info.logical_op_id());
  857. }
  858. for (auto& access_chain_clamping_info :
  859. message_.access_chain_clamping_info()) {
  860. for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
  861. result.insert(pair.first());
  862. result.insert(pair.second());
  863. }
  864. }
  865. }
  866. return result;
  867. }
  868. } // namespace fuzz
  869. } // namespace spvtools