transformation_add_function.cpp 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_add_function.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_message.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationAddFunction::TransformationAddFunction(
  20. const spvtools::fuzz::protobufs::TransformationAddFunction& message)
  21. : message_(message) {}
  22. TransformationAddFunction::TransformationAddFunction(
  23. const std::vector<protobufs::Instruction>& instructions) {
  24. for (auto& instruction : instructions) {
  25. *message_.add_instruction() = instruction;
  26. }
  27. message_.set_is_livesafe(false);
  28. }
  29. TransformationAddFunction::TransformationAddFunction(
  30. const std::vector<protobufs::Instruction>& instructions,
  31. uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
  32. const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
  33. uint32_t kill_unreachable_return_value_id,
  34. const std::vector<protobufs::AccessChainClampingInfo>&
  35. access_chain_clampers) {
  36. for (auto& instruction : instructions) {
  37. *message_.add_instruction() = instruction;
  38. }
  39. message_.set_is_livesafe(true);
  40. message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
  41. message_.set_loop_limit_constant_id(loop_limit_constant_id);
  42. for (auto& loop_limiter : loop_limiters) {
  43. *message_.add_loop_limiter_info() = loop_limiter;
  44. }
  45. message_.set_kill_unreachable_return_value_id(
  46. kill_unreachable_return_value_id);
  47. for (auto& access_clamper : access_chain_clampers) {
  48. *message_.add_access_chain_clamping_info() = access_clamper;
  49. }
  50. }
  51. bool TransformationAddFunction::IsApplicable(
  52. opt::IRContext* ir_context,
  53. const TransformationContext& transformation_context) const {
  54. // This transformation may use a lot of ids, all of which need to be fresh
  55. // and distinct. This set tracks them.
  56. std::set<uint32_t> ids_used_by_this_transformation;
  57. // Ensure that all result ids in the new function are fresh and distinct.
  58. for (auto& instruction : message_.instruction()) {
  59. if (instruction.result_id()) {
  60. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  61. instruction.result_id(), ir_context,
  62. &ids_used_by_this_transformation)) {
  63. return false;
  64. }
  65. }
  66. }
  67. if (message_.is_livesafe()) {
  68. // Ensure that all ids provided for making the function livesafe are fresh
  69. // and distinct.
  70. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  71. message_.loop_limiter_variable_id(), ir_context,
  72. &ids_used_by_this_transformation)) {
  73. return false;
  74. }
  75. for (auto& loop_limiter_info : message_.loop_limiter_info()) {
  76. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  77. loop_limiter_info.load_id(), ir_context,
  78. &ids_used_by_this_transformation)) {
  79. return false;
  80. }
  81. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  82. loop_limiter_info.increment_id(), ir_context,
  83. &ids_used_by_this_transformation)) {
  84. return false;
  85. }
  86. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  87. loop_limiter_info.compare_id(), ir_context,
  88. &ids_used_by_this_transformation)) {
  89. return false;
  90. }
  91. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  92. loop_limiter_info.logical_op_id(), ir_context,
  93. &ids_used_by_this_transformation)) {
  94. return false;
  95. }
  96. }
  97. for (auto& access_chain_clamping_info :
  98. message_.access_chain_clamping_info()) {
  99. for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
  100. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  101. pair.first(), ir_context, &ids_used_by_this_transformation)) {
  102. return false;
  103. }
  104. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  105. pair.second(), ir_context, &ids_used_by_this_transformation)) {
  106. return false;
  107. }
  108. }
  109. }
  110. }
  111. // Because checking all the conditions for a function to be valid is a big
  112. // job that the SPIR-V validator can already do, a "try it and see" approach
  113. // is taken here.
  114. // We first clone the current module, so that we can try adding the new
  115. // function without risking wrecking |ir_context|.
  116. auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
  117. // We try to add a function to the cloned module, which may fail if
  118. // |message_.instruction| is not sufficiently well-formed.
  119. if (!TryToAddFunction(cloned_module.get())) {
  120. return false;
  121. }
  122. // Check whether the cloned module is still valid after adding the function.
  123. // If it is not, the transformation is not applicable.
  124. if (!fuzzerutil::IsValid(cloned_module.get(),
  125. transformation_context.GetValidatorOptions())) {
  126. return false;
  127. }
  128. if (message_.is_livesafe()) {
  129. if (!TryToMakeFunctionLivesafe(cloned_module.get(),
  130. transformation_context)) {
  131. return false;
  132. }
  133. // After making the function livesafe, we check validity of the module
  134. // again. This is because the turning of OpKill, OpUnreachable and OpReturn
  135. // instructions into branches changes control flow graph reachability, which
  136. // has the potential to make the module invalid when it was otherwise valid.
  137. // It is simpler to rely on the validator to guard against this than to
  138. // consider all scenarios when making a function livesafe.
  139. if (!fuzzerutil::IsValid(cloned_module.get(),
  140. transformation_context.GetValidatorOptions())) {
  141. return false;
  142. }
  143. }
  144. return true;
  145. }
  146. void TransformationAddFunction::Apply(
  147. opt::IRContext* ir_context,
  148. TransformationContext* transformation_context) const {
  149. // Add the function to the module. As the transformation is applicable, this
  150. // should succeed.
  151. bool success = TryToAddFunction(ir_context);
  152. assert(success && "The function should be successfully added.");
  153. (void)(success); // Keep release builds happy (otherwise they may complain
  154. // that |success| is not used).
  155. // Record the fact that all pointer parameters and variables declared in the
  156. // function should be regarded as having irrelevant values. This allows other
  157. // passes to store arbitrarily to such variables, and to pass them freely as
  158. // parameters to other functions knowing that it is OK if they get
  159. // over-written.
  160. for (auto& instruction : message_.instruction()) {
  161. switch (instruction.opcode()) {
  162. case SpvOpFunctionParameter:
  163. if (ir_context->get_def_use_mgr()
  164. ->GetDef(instruction.result_type_id())
  165. ->opcode() == SpvOpTypePointer) {
  166. transformation_context->GetFactManager()
  167. ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
  168. }
  169. break;
  170. case SpvOpVariable:
  171. transformation_context->GetFactManager()
  172. ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
  173. break;
  174. default:
  175. break;
  176. }
  177. }
  178. if (message_.is_livesafe()) {
  179. // Make the function livesafe, which also should succeed.
  180. success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
  181. assert(success && "It should be possible to make the function livesafe.");
  182. (void)(success); // Keep release builds happy.
  183. // Inform the fact manager that the function is livesafe.
  184. assert(message_.instruction(0).opcode() == SpvOpFunction &&
  185. "The first instruction of an 'add function' transformation must be "
  186. "OpFunction.");
  187. transformation_context->GetFactManager()->AddFactFunctionIsLivesafe(
  188. message_.instruction(0).result_id());
  189. } else {
  190. // Inform the fact manager that all blocks in the function are dead.
  191. for (auto& inst : message_.instruction()) {
  192. if (inst.opcode() == SpvOpLabel) {
  193. transformation_context->GetFactManager()->AddFactBlockIsDead(
  194. inst.result_id());
  195. }
  196. }
  197. }
  198. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  199. }
  200. protobufs::Transformation TransformationAddFunction::ToMessage() const {
  201. protobufs::Transformation result;
  202. *result.mutable_add_function() = message_;
  203. return result;
  204. }
  205. bool TransformationAddFunction::TryToAddFunction(
  206. opt::IRContext* ir_context) const {
  207. // This function returns false if |message_.instruction| was not well-formed
  208. // enough to actually create a function and add it to |ir_context|.
  209. // A function must have at least some instructions.
  210. if (message_.instruction().empty()) {
  211. return false;
  212. }
  213. // A function must start with OpFunction.
  214. auto function_begin = message_.instruction(0);
  215. if (function_begin.opcode() != SpvOpFunction) {
  216. return false;
  217. }
  218. // Make a function, headed by the OpFunction instruction.
  219. std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
  220. InstructionFromMessage(ir_context, function_begin));
  221. // Keeps track of which instruction protobuf message we are currently
  222. // considering.
  223. uint32_t instruction_index = 1;
  224. const auto num_instructions =
  225. static_cast<uint32_t>(message_.instruction().size());
  226. // Iterate through all function parameter instructions, adding parameters to
  227. // the new function.
  228. while (instruction_index < num_instructions &&
  229. message_.instruction(instruction_index).opcode() ==
  230. SpvOpFunctionParameter) {
  231. new_function->AddParameter(InstructionFromMessage(
  232. ir_context, message_.instruction(instruction_index)));
  233. instruction_index++;
  234. }
  235. // After the parameters, there needs to be a label.
  236. if (instruction_index == num_instructions ||
  237. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  238. return false;
  239. }
  240. // Iterate through the instructions block by block until the end of the
  241. // function is reached.
  242. while (instruction_index < num_instructions &&
  243. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  244. // Invariant: we should always be at a label instruction at this point.
  245. assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
  246. // Make a basic block using the label instruction, with the new function
  247. // as its parent.
  248. std::unique_ptr<opt::BasicBlock> block =
  249. MakeUnique<opt::BasicBlock>(InstructionFromMessage(
  250. ir_context, message_.instruction(instruction_index)));
  251. block->SetParent(new_function.get());
  252. // Consider successive instructions until we hit another label or the end
  253. // of the function, adding each such instruction to the block.
  254. instruction_index++;
  255. while (instruction_index < num_instructions &&
  256. message_.instruction(instruction_index).opcode() !=
  257. SpvOpFunctionEnd &&
  258. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  259. block->AddInstruction(InstructionFromMessage(
  260. ir_context, message_.instruction(instruction_index)));
  261. instruction_index++;
  262. }
  263. // Add the block to the new function.
  264. new_function->AddBasicBlock(std::move(block));
  265. }
  266. // Having considered all the blocks, we should be at the last instruction and
  267. // it needs to be OpFunctionEnd.
  268. if (instruction_index != num_instructions - 1 ||
  269. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  270. return false;
  271. }
  272. // Set the function's final instruction, add the function to the module and
  273. // report success.
  274. new_function->SetFunctionEnd(InstructionFromMessage(
  275. ir_context, message_.instruction(instruction_index)));
  276. ir_context->AddFunction(std::move(new_function));
  277. ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  278. return true;
  279. }
  280. bool TransformationAddFunction::TryToMakeFunctionLivesafe(
  281. opt::IRContext* ir_context,
  282. const TransformationContext& transformation_context) const {
  283. assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
  284. // Get a pointer to the added function.
  285. opt::Function* added_function = nullptr;
  286. for (auto& function : *ir_context->module()) {
  287. if (function.result_id() == message_.instruction(0).result_id()) {
  288. added_function = &function;
  289. break;
  290. }
  291. }
  292. assert(added_function && "The added function should have been found.");
  293. if (!TryToAddLoopLimiters(ir_context, added_function)) {
  294. // Adding loop limiters did not work; bail out.
  295. return false;
  296. }
  297. // Consider all the instructions in the function, and:
  298. // - attempt to replace OpKill and OpUnreachable with return instructions
  299. // - attempt to clamp access chains to be within bounds
  300. // - check that OpFunctionCall instructions are only to livesafe functions
  301. for (auto& block : *added_function) {
  302. for (auto& inst : block) {
  303. switch (inst.opcode()) {
  304. case SpvOpKill:
  305. case SpvOpUnreachable:
  306. if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
  307. &inst)) {
  308. return false;
  309. }
  310. break;
  311. case SpvOpAccessChain:
  312. case SpvOpInBoundsAccessChain:
  313. if (!TryToClampAccessChainIndices(ir_context, &inst)) {
  314. return false;
  315. }
  316. break;
  317. case SpvOpFunctionCall:
  318. // A livesafe function my only call other livesafe functions.
  319. if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
  320. inst.GetSingleWordInOperand(0))) {
  321. return false;
  322. }
  323. default:
  324. break;
  325. }
  326. }
  327. }
  328. return true;
  329. }
  330. bool TransformationAddFunction::TryToAddLoopLimiters(
  331. opt::IRContext* ir_context, opt::Function* added_function) const {
  332. // Collect up all the loop headers so that we can subsequently add loop
  333. // limiting logic.
  334. std::vector<opt::BasicBlock*> loop_headers;
  335. for (auto& block : *added_function) {
  336. if (block.IsLoopHeader()) {
  337. loop_headers.push_back(&block);
  338. }
  339. }
  340. if (loop_headers.empty()) {
  341. // There are no loops, so no need to add any loop limiters.
  342. return true;
  343. }
  344. // Check that the module contains appropriate ingredients for declaring and
  345. // manipulating a loop limiter.
  346. auto loop_limit_constant_id_instr =
  347. ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  348. if (!loop_limit_constant_id_instr ||
  349. loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
  350. // The loop limit constant id instruction must exist and have an
  351. // appropriate opcode.
  352. return false;
  353. }
  354. auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
  355. loop_limit_constant_id_instr->type_id());
  356. if (loop_limit_type->opcode() != SpvOpTypeInt ||
  357. loop_limit_type->GetSingleWordInOperand(0) != 32) {
  358. // The type of the loop limit constant must be 32-bit integer. It
  359. // doesn't actually matter whether the integer is signed or not.
  360. return false;
  361. }
  362. // Find the id of the "unsigned int" type.
  363. opt::analysis::Integer unsigned_int_type(32, false);
  364. uint32_t unsigned_int_type_id =
  365. ir_context->get_type_mgr()->GetId(&unsigned_int_type);
  366. if (!unsigned_int_type_id) {
  367. // Unsigned int is not available; we need this type in order to add loop
  368. // limiters.
  369. return false;
  370. }
  371. auto registered_unsigned_int_type =
  372. ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
  373. // Look for 0 of type unsigned int.
  374. opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
  375. {0});
  376. auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
  377. if (!registered_zero) {
  378. // We need 0 in order to be able to initialize loop limiters.
  379. return false;
  380. }
  381. uint32_t zero_id = ir_context->get_constant_mgr()
  382. ->GetDefiningInstruction(registered_zero)
  383. ->result_id();
  384. // Look for 1 of type unsigned int.
  385. opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
  386. {1});
  387. auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
  388. if (!registered_one) {
  389. // We need 1 in order to be able to increment loop limiters.
  390. return false;
  391. }
  392. uint32_t one_id = ir_context->get_constant_mgr()
  393. ->GetDefiningInstruction(registered_one)
  394. ->result_id();
  395. // Look for pointer-to-unsigned int type.
  396. opt::analysis::Pointer pointer_to_unsigned_int_type(
  397. registered_unsigned_int_type, SpvStorageClassFunction);
  398. uint32_t pointer_to_unsigned_int_type_id =
  399. ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  400. if (!pointer_to_unsigned_int_type_id) {
  401. // We need pointer-to-unsigned int in order to declare the loop limiter
  402. // variable.
  403. return false;
  404. }
  405. // Look for bool type.
  406. opt::analysis::Bool bool_type;
  407. uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  408. if (!bool_type_id) {
  409. // We need bool in order to compare the loop limiter's value with the loop
  410. // limit constant.
  411. return false;
  412. }
  413. // Declare the loop limiter variable at the start of the function's entry
  414. // block, via an instruction of the form:
  415. // %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
  416. added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
  417. ir_context, SpvOpVariable, pointer_to_unsigned_int_type_id,
  418. message_.loop_limiter_variable_id(),
  419. opt::Instruction::OperandList(
  420. {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  421. {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  422. // Update the module's id bound since we have added the loop limiter
  423. // variable id.
  424. fuzzerutil::UpdateModuleIdBound(ir_context,
  425. message_.loop_limiter_variable_id());
  426. // Consider each loop in turn.
  427. for (auto loop_header : loop_headers) {
  428. // Look for the loop's back-edge block. This is a predecessor of the loop
  429. // header that is dominated by the loop header.
  430. uint32_t back_edge_block_id = 0;
  431. for (auto pred : ir_context->cfg()->preds(loop_header->id())) {
  432. if (ir_context->GetDominatorAnalysis(added_function)
  433. ->Dominates(loop_header->id(), pred)) {
  434. back_edge_block_id = pred;
  435. break;
  436. }
  437. }
  438. if (!back_edge_block_id) {
  439. // The loop's back-edge block must be unreachable. This means that the
  440. // loop cannot iterate, so there is no need to make it lifesafe; we can
  441. // move on from this loop.
  442. continue;
  443. }
  444. auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
  445. // Go through the sequence of loop limiter infos and find the one
  446. // corresponding to this loop.
  447. bool found = false;
  448. protobufs::LoopLimiterInfo loop_limiter_info;
  449. for (auto& info : message_.loop_limiter_info()) {
  450. if (info.loop_header_id() == loop_header->id()) {
  451. loop_limiter_info = info;
  452. found = true;
  453. break;
  454. }
  455. }
  456. if (!found) {
  457. // We don't have loop limiter info for this loop header.
  458. return false;
  459. }
  460. // The back-edge block either has the form:
  461. //
  462. // (1)
  463. //
  464. // %l = OpLabel
  465. // ... instructions ...
  466. // OpBranch %loop_header
  467. //
  468. // (2)
  469. //
  470. // %l = OpLabel
  471. // ... instructions ...
  472. // OpBranchConditional %c %loop_header %loop_merge
  473. //
  474. // (3)
  475. //
  476. // %l = OpLabel
  477. // ... instructions ...
  478. // OpBranchConditional %c %loop_merge %loop_header
  479. //
  480. // We turn these into the following:
  481. //
  482. // (1)
  483. //
  484. // %l = OpLabel
  485. // ... instructions ...
  486. // %t1 = OpLoad %uint32 %loop_limiter
  487. // %t2 = OpIAdd %uint32 %t1 %one
  488. // OpStore %loop_limiter %t2
  489. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  490. // OpBranchConditional %t3 %loop_merge %loop_header
  491. //
  492. // (2)
  493. //
  494. // %l = OpLabel
  495. // ... instructions ...
  496. // %t1 = OpLoad %uint32 %loop_limiter
  497. // %t2 = OpIAdd %uint32 %t1 %one
  498. // OpStore %loop_limiter %t2
  499. // %t3 = OpULessThan %bool %t1 %loop_limit
  500. // %t4 = OpLogicalAnd %bool %c %t3
  501. // OpBranchConditional %t4 %loop_header %loop_merge
  502. //
  503. // (3)
  504. //
  505. // %l = OpLabel
  506. // ... instructions ...
  507. // %t1 = OpLoad %uint32 %loop_limiter
  508. // %t2 = OpIAdd %uint32 %t1 %one
  509. // OpStore %loop_limiter %t2
  510. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  511. // %t4 = OpLogicalOr %bool %c %t3
  512. // OpBranchConditional %t4 %loop_merge %loop_header
  513. auto back_edge_block_terminator = back_edge_block->terminator();
  514. bool compare_using_greater_than_equal;
  515. if (back_edge_block_terminator->opcode() == SpvOpBranch) {
  516. compare_using_greater_than_equal = true;
  517. } else {
  518. assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
  519. assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
  520. loop_header->id() &&
  521. back_edge_block_terminator->GetSingleWordInOperand(2) ==
  522. loop_header->MergeBlockId()) ||
  523. (back_edge_block_terminator->GetSingleWordInOperand(2) ==
  524. loop_header->id() &&
  525. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  526. loop_header->MergeBlockId())) &&
  527. "A back edge edge block must branch to"
  528. " either the loop header or merge");
  529. compare_using_greater_than_equal =
  530. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  531. loop_header->MergeBlockId();
  532. }
  533. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  534. // Add a load from the loop limiter variable, of the form:
  535. // %t1 = OpLoad %uint32 %loop_limiter
  536. new_instructions.push_back(MakeUnique<opt::Instruction>(
  537. ir_context, SpvOpLoad, unsigned_int_type_id,
  538. loop_limiter_info.load_id(),
  539. opt::Instruction::OperandList(
  540. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
  541. // Increment the loaded value:
  542. // %t2 = OpIAdd %uint32 %t1 %one
  543. new_instructions.push_back(MakeUnique<opt::Instruction>(
  544. ir_context, SpvOpIAdd, unsigned_int_type_id,
  545. loop_limiter_info.increment_id(),
  546. opt::Instruction::OperandList(
  547. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  548. {SPV_OPERAND_TYPE_ID, {one_id}}})));
  549. // Store the incremented value back to the loop limiter variable:
  550. // OpStore %loop_limiter %t2
  551. new_instructions.push_back(MakeUnique<opt::Instruction>(
  552. ir_context, SpvOpStore, 0, 0,
  553. opt::Instruction::OperandList(
  554. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
  555. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
  556. // Compare the loaded value with the loop limit; either:
  557. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  558. // or
  559. // %t3 = OpULessThan %bool %t1 %loop_limit
  560. new_instructions.push_back(MakeUnique<opt::Instruction>(
  561. ir_context,
  562. compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
  563. : SpvOpULessThan,
  564. bool_type_id, loop_limiter_info.compare_id(),
  565. opt::Instruction::OperandList(
  566. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  567. {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
  568. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  569. new_instructions.push_back(MakeUnique<opt::Instruction>(
  570. ir_context,
  571. compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
  572. bool_type_id, loop_limiter_info.logical_op_id(),
  573. opt::Instruction::OperandList(
  574. {{SPV_OPERAND_TYPE_ID,
  575. {back_edge_block_terminator->GetSingleWordInOperand(0)}},
  576. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
  577. }
  578. // Add the new instructions at the end of the back edge block, before the
  579. // terminator and any loop merge instruction (as the back edge block can
  580. // be the loop header).
  581. if (back_edge_block->GetLoopMergeInst()) {
  582. back_edge_block->GetLoopMergeInst()->InsertBefore(
  583. std::move(new_instructions));
  584. } else {
  585. back_edge_block_terminator->InsertBefore(std::move(new_instructions));
  586. }
  587. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  588. back_edge_block_terminator->SetInOperand(
  589. 0, {loop_limiter_info.logical_op_id()});
  590. } else {
  591. assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
  592. "Back-edge terminator must be OpBranch or OpBranchConditional");
  593. // Check that, if the merge block starts with OpPhi instructions, suitable
  594. // ids have been provided to give these instructions a value corresponding
  595. // to the new incoming edge from the back edge block.
  596. auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
  597. if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
  598. merge_block,
  599. loop_limiter_info.phi_id())) {
  600. return false;
  601. }
  602. // Augment OpPhi instructions at the loop merge with the given ids.
  603. uint32_t phi_index = 0;
  604. for (auto& inst : *merge_block) {
  605. if (inst.opcode() != SpvOpPhi) {
  606. break;
  607. }
  608. assert(phi_index <
  609. static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
  610. "There should be at least one phi id per OpPhi instruction.");
  611. inst.AddOperand(
  612. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
  613. inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
  614. phi_index++;
  615. }
  616. // Add the new edge, by changing OpBranch to OpBranchConditional.
  617. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3162): This
  618. // could be a problem if the merge block was originally unreachable: it
  619. // might now be dominated by other blocks that it appears earlier than in
  620. // the module.
  621. back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
  622. back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
  623. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
  624. {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}
  625. },
  626. {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
  627. }
  628. // Update the module's id bound with respect to the various ids that
  629. // have been used for loop limiter manipulation.
  630. fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
  631. fuzzerutil::UpdateModuleIdBound(ir_context,
  632. loop_limiter_info.increment_id());
  633. fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
  634. fuzzerutil::UpdateModuleIdBound(ir_context,
  635. loop_limiter_info.logical_op_id());
  636. }
  637. return true;
  638. }
  639. bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
  640. opt::IRContext* ir_context, opt::Function* added_function,
  641. opt::Instruction* kill_or_unreachable_inst) const {
  642. assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
  643. kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
  644. "Precondition: instruction must be OpKill or OpUnreachable.");
  645. // Get the function's return type.
  646. auto function_return_type_inst =
  647. ir_context->get_def_use_mgr()->GetDef(added_function->type_id());
  648. if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
  649. // The function has void return type, so change this instruction to
  650. // OpReturn.
  651. kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
  652. } else {
  653. // The function has non-void return type, so change this instruction
  654. // to OpReturnValue, using the value id provided with the
  655. // transformation.
  656. // We first check that the id, %id, provided with the transformation
  657. // specifically to turn OpKill and OpUnreachable instructions into
  658. // OpReturnValue %id has the same type as the function's return type.
  659. if (ir_context->get_def_use_mgr()
  660. ->GetDef(message_.kill_unreachable_return_value_id())
  661. ->type_id() != function_return_type_inst->result_id()) {
  662. return false;
  663. }
  664. kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
  665. kill_or_unreachable_inst->SetInOperands(
  666. {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
  667. }
  668. return true;
  669. }
  670. bool TransformationAddFunction::TryToClampAccessChainIndices(
  671. opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
  672. assert((access_chain_inst->opcode() == SpvOpAccessChain ||
  673. access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
  674. "Precondition: instruction must be OpAccessChain or "
  675. "OpInBoundsAccessChain.");
  676. // Find the AccessChainClampingInfo associated with this access chain.
  677. const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
  678. nullptr;
  679. for (auto& clamping_info : message_.access_chain_clamping_info()) {
  680. if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
  681. access_chain_clamping_info = &clamping_info;
  682. break;
  683. }
  684. }
  685. if (!access_chain_clamping_info) {
  686. // No access chain clamping information was found; the function cannot be
  687. // made livesafe.
  688. return false;
  689. }
  690. // Check that there is a (compare_id, select_id) pair for every
  691. // index associated with the instruction.
  692. if (static_cast<uint32_t>(
  693. access_chain_clamping_info->compare_and_select_ids().size()) !=
  694. access_chain_inst->NumInOperands() - 1) {
  695. return false;
  696. }
  697. // Walk the access chain, clamping each index to be within bounds if it is
  698. // not a constant.
  699. auto base_object = ir_context->get_def_use_mgr()->GetDef(
  700. access_chain_inst->GetSingleWordInOperand(0));
  701. assert(base_object && "The base object must exist.");
  702. auto pointer_type =
  703. ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
  704. assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
  705. "The base object must have pointer type.");
  706. auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
  707. pointer_type->GetSingleWordInOperand(1));
  708. // Consider each index input operand in turn (operand 0 is the base object).
  709. for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
  710. index++) {
  711. // We are going to turn:
  712. //
  713. // %result = OpAccessChain %type %object ... %index ...
  714. //
  715. // into:
  716. //
  717. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  718. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  719. // %result = OpAccessChain %type %object ... %t2 ...
  720. //
  721. // ... unless %index is already a constant.
  722. // Get the bound for the composite being indexed into; e.g. the number of
  723. // columns of matrix or the size of an array.
  724. uint32_t bound =
  725. GetBoundForCompositeIndex(ir_context, *should_be_composite_type);
  726. // Get the instruction associated with the index and figure out its integer
  727. // type.
  728. const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
  729. auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
  730. auto index_type_inst =
  731. ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
  732. assert(index_type_inst->opcode() == SpvOpTypeInt);
  733. assert(index_type_inst->GetSingleWordInOperand(0) == 32);
  734. opt::analysis::Integer* index_int_type =
  735. ir_context->get_type_mgr()
  736. ->GetType(index_type_inst->result_id())
  737. ->AsInteger();
  738. if (index_inst->opcode() != SpvOpConstant ||
  739. index_inst->GetSingleWordInOperand(0) >= bound) {
  740. // The index is either non-constant or an out-of-bounds constant, so we
  741. // need to clamp it.
  742. assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
  743. "Access chain indices into structures are required to be "
  744. "constants.");
  745. opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
  746. if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
  747. // We do not have an integer constant whose value is |bound| -1.
  748. return false;
  749. }
  750. opt::analysis::Bool bool_type;
  751. uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  752. if (!bool_type_id) {
  753. // Bool type is not declared; we cannot do a comparison.
  754. return false;
  755. }
  756. uint32_t bound_minus_one_id =
  757. ir_context->get_constant_mgr()
  758. ->GetDefiningInstruction(&bound_minus_one)
  759. ->result_id();
  760. uint32_t compare_id =
  761. access_chain_clamping_info->compare_and_select_ids(index - 1).first();
  762. uint32_t select_id =
  763. access_chain_clamping_info->compare_and_select_ids(index - 1)
  764. .second();
  765. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  766. // Compare the index with the bound via an instruction of the form:
  767. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  768. new_instructions.push_back(MakeUnique<opt::Instruction>(
  769. ir_context, SpvOpULessThanEqual, bool_type_id, compare_id,
  770. opt::Instruction::OperandList(
  771. {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  772. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  773. // Select the index if in-bounds, otherwise one less than the bound:
  774. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  775. new_instructions.push_back(MakeUnique<opt::Instruction>(
  776. ir_context, SpvOpSelect, index_type_inst->result_id(), select_id,
  777. opt::Instruction::OperandList(
  778. {{SPV_OPERAND_TYPE_ID, {compare_id}},
  779. {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  780. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  781. // Add the new instructions before the access chain
  782. access_chain_inst->InsertBefore(std::move(new_instructions));
  783. // Replace %index with %t2.
  784. access_chain_inst->SetInOperand(index, {select_id});
  785. fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
  786. fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
  787. }
  788. should_be_composite_type =
  789. FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
  790. }
  791. return true;
  792. }
  793. uint32_t TransformationAddFunction::GetBoundForCompositeIndex(
  794. opt::IRContext* ir_context, const opt::Instruction& composite_type_inst) {
  795. switch (composite_type_inst.opcode()) {
  796. case SpvOpTypeArray:
  797. return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
  798. case SpvOpTypeMatrix:
  799. case SpvOpTypeVector:
  800. return composite_type_inst.GetSingleWordInOperand(1);
  801. case SpvOpTypeStruct: {
  802. return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
  803. }
  804. case SpvOpTypeRuntimeArray:
  805. assert(false &&
  806. "GetBoundForCompositeIndex should not be invoked with an "
  807. "OpTypeRuntimeArray, which does not have a static bound.");
  808. return 0;
  809. default:
  810. assert(false && "Unknown composite type.");
  811. return 0;
  812. }
  813. }
  814. opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
  815. opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
  816. uint32_t index_id) {
  817. uint32_t sub_object_type_id;
  818. switch (composite_type_inst.opcode()) {
  819. case SpvOpTypeArray:
  820. case SpvOpTypeRuntimeArray:
  821. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  822. break;
  823. case SpvOpTypeMatrix:
  824. case SpvOpTypeVector:
  825. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  826. break;
  827. case SpvOpTypeStruct: {
  828. auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
  829. assert(index_inst->opcode() == SpvOpConstant);
  830. assert(ir_context->get_def_use_mgr()
  831. ->GetDef(index_inst->type_id())
  832. ->opcode() == SpvOpTypeInt);
  833. assert(ir_context->get_def_use_mgr()
  834. ->GetDef(index_inst->type_id())
  835. ->GetSingleWordInOperand(0) == 32);
  836. uint32_t index_value = index_inst->GetSingleWordInOperand(0);
  837. sub_object_type_id =
  838. composite_type_inst.GetSingleWordInOperand(index_value);
  839. break;
  840. }
  841. default:
  842. assert(false && "Unknown composite type.");
  843. sub_object_type_id = 0;
  844. break;
  845. }
  846. assert(sub_object_type_id && "No sub-object found.");
  847. return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id);
  848. }
  849. } // namespace fuzz
  850. } // namespace spvtools