transformation_add_function.cpp 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_add_function.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_message.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationAddFunction::TransformationAddFunction(
  20. const spvtools::fuzz::protobufs::TransformationAddFunction& message)
  21. : message_(message) {}
  22. TransformationAddFunction::TransformationAddFunction(
  23. const std::vector<protobufs::Instruction>& instructions) {
  24. for (auto& instruction : instructions) {
  25. *message_.add_instruction() = instruction;
  26. }
  27. message_.set_is_livesafe(false);
  28. }
  29. TransformationAddFunction::TransformationAddFunction(
  30. const std::vector<protobufs::Instruction>& instructions,
  31. uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
  32. const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
  33. uint32_t kill_unreachable_return_value_id,
  34. const std::vector<protobufs::AccessChainClampingInfo>&
  35. access_chain_clampers) {
  36. for (auto& instruction : instructions) {
  37. *message_.add_instruction() = instruction;
  38. }
  39. message_.set_is_livesafe(true);
  40. message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
  41. message_.set_loop_limit_constant_id(loop_limit_constant_id);
  42. for (auto& loop_limiter : loop_limiters) {
  43. *message_.add_loop_limiter_info() = loop_limiter;
  44. }
  45. message_.set_kill_unreachable_return_value_id(
  46. kill_unreachable_return_value_id);
  47. for (auto& access_clamper : access_chain_clampers) {
  48. *message_.add_access_chain_clamping_info() = access_clamper;
  49. }
  50. }
  51. bool TransformationAddFunction::IsApplicable(
  52. opt::IRContext* context,
  53. const spvtools::fuzz::FactManager& fact_manager) const {
  54. // This transformation may use a lot of ids, all of which need to be fresh
  55. // and distinct. This set tracks them.
  56. std::set<uint32_t> ids_used_by_this_transformation;
  57. // Ensure that all result ids in the new function are fresh and distinct.
  58. for (auto& instruction : message_.instruction()) {
  59. if (instruction.result_id()) {
  60. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  61. instruction.result_id(), context,
  62. &ids_used_by_this_transformation)) {
  63. return false;
  64. }
  65. }
  66. }
  67. if (message_.is_livesafe()) {
  68. // Ensure that all ids provided for making the function livesafe are fresh
  69. // and distinct.
  70. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  71. message_.loop_limiter_variable_id(), context,
  72. &ids_used_by_this_transformation)) {
  73. return false;
  74. }
  75. for (auto& loop_limiter_info : message_.loop_limiter_info()) {
  76. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  77. loop_limiter_info.load_id(), context,
  78. &ids_used_by_this_transformation)) {
  79. return false;
  80. }
  81. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  82. loop_limiter_info.increment_id(), context,
  83. &ids_used_by_this_transformation)) {
  84. return false;
  85. }
  86. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  87. loop_limiter_info.compare_id(), context,
  88. &ids_used_by_this_transformation)) {
  89. return false;
  90. }
  91. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  92. loop_limiter_info.logical_op_id(), context,
  93. &ids_used_by_this_transformation)) {
  94. return false;
  95. }
  96. }
  97. for (auto& access_chain_clamping_info :
  98. message_.access_chain_clamping_info()) {
  99. for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
  100. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  101. pair.first(), context, &ids_used_by_this_transformation)) {
  102. return false;
  103. }
  104. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  105. pair.second(), context, &ids_used_by_this_transformation)) {
  106. return false;
  107. }
  108. }
  109. }
  110. }
  111. // Because checking all the conditions for a function to be valid is a big
  112. // job that the SPIR-V validator can already do, a "try it and see" approach
  113. // is taken here.
  114. // We first clone the current module, so that we can try adding the new
  115. // function without risking wrecking |context|.
  116. auto cloned_module = fuzzerutil::CloneIRContext(context);
  117. // We try to add a function to the cloned module, which may fail if
  118. // |message_.instruction| is not sufficiently well-formed.
  119. if (!TryToAddFunction(cloned_module.get())) {
  120. return false;
  121. }
  122. // Check whether the cloned module is still valid after adding the function.
  123. // If it is not, the transformation is not applicable.
  124. if (!fuzzerutil::IsValid(cloned_module.get())) {
  125. return false;
  126. }
  127. if (message_.is_livesafe()) {
  128. if (!TryToMakeFunctionLivesafe(cloned_module.get(), fact_manager)) {
  129. return false;
  130. }
  131. // After making the function livesafe, we check validity of the module
  132. // again. This is because the turning of OpKill, OpUnreachable and OpReturn
  133. // instructions into branches changes control flow graph reachability, which
  134. // has the potential to make the module invalid when it was otherwise valid.
  135. // It is simpler to rely on the validator to guard against this than to
  136. // consider all scenarios when making a function livesafe.
  137. if (!fuzzerutil::IsValid(cloned_module.get())) {
  138. return false;
  139. }
  140. }
  141. return true;
  142. }
  143. void TransformationAddFunction::Apply(
  144. opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const {
  145. // Add the function to the module. As the transformation is applicable, this
  146. // should succeed.
  147. bool success = TryToAddFunction(context);
  148. assert(success && "The function should be successfully added.");
  149. (void)(success); // Keep release builds happy (otherwise they may complain
  150. // that |success| is not used).
  151. // Record the fact that all pointer parameters and variables declared in the
  152. // function should be regarded as having irrelevant values. This allows other
  153. // passes to store arbitrarily to such variables, and to pass them freely as
  154. // parameters to other functions knowing that it is OK if they get
  155. // over-written.
  156. for (auto& instruction : message_.instruction()) {
  157. switch (instruction.opcode()) {
  158. case SpvOpFunctionParameter:
  159. if (context->get_def_use_mgr()
  160. ->GetDef(instruction.result_type_id())
  161. ->opcode() == SpvOpTypePointer) {
  162. fact_manager->AddFactValueOfPointeeIsIrrelevant(
  163. instruction.result_id());
  164. }
  165. break;
  166. case SpvOpVariable:
  167. fact_manager->AddFactValueOfPointeeIsIrrelevant(
  168. instruction.result_id());
  169. break;
  170. default:
  171. break;
  172. }
  173. }
  174. if (message_.is_livesafe()) {
  175. // Make the function livesafe, which also should succeed.
  176. success = TryToMakeFunctionLivesafe(context, *fact_manager);
  177. assert(success && "It should be possible to make the function livesafe.");
  178. (void)(success); // Keep release builds happy.
  179. // Inform the fact manager that the function is livesafe.
  180. assert(message_.instruction(0).opcode() == SpvOpFunction &&
  181. "The first instruction of an 'add function' transformation must be "
  182. "OpFunction.");
  183. fact_manager->AddFactFunctionIsLivesafe(
  184. message_.instruction(0).result_id());
  185. } else {
  186. // Inform the fact manager that all blocks in the function are dead.
  187. for (auto& inst : message_.instruction()) {
  188. if (inst.opcode() == SpvOpLabel) {
  189. fact_manager->AddFactBlockIsDead(inst.result_id());
  190. }
  191. }
  192. }
  193. context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  194. }
  195. protobufs::Transformation TransformationAddFunction::ToMessage() const {
  196. protobufs::Transformation result;
  197. *result.mutable_add_function() = message_;
  198. return result;
  199. }
  200. bool TransformationAddFunction::TryToAddFunction(
  201. opt::IRContext* context) const {
  202. // This function returns false if |message_.instruction| was not well-formed
  203. // enough to actually create a function and add it to |context|.
  204. // A function must have at least some instructions.
  205. if (message_.instruction().empty()) {
  206. return false;
  207. }
  208. // A function must start with OpFunction.
  209. auto function_begin = message_.instruction(0);
  210. if (function_begin.opcode() != SpvOpFunction) {
  211. return false;
  212. }
  213. // Make a function, headed by the OpFunction instruction.
  214. std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
  215. InstructionFromMessage(context, function_begin));
  216. // Keeps track of which instruction protobuf message we are currently
  217. // considering.
  218. uint32_t instruction_index = 1;
  219. const auto num_instructions =
  220. static_cast<uint32_t>(message_.instruction().size());
  221. // Iterate through all function parameter instructions, adding parameters to
  222. // the new function.
  223. while (instruction_index < num_instructions &&
  224. message_.instruction(instruction_index).opcode() ==
  225. SpvOpFunctionParameter) {
  226. new_function->AddParameter(InstructionFromMessage(
  227. context, message_.instruction(instruction_index)));
  228. instruction_index++;
  229. }
  230. // After the parameters, there needs to be a label.
  231. if (instruction_index == num_instructions ||
  232. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  233. return false;
  234. }
  235. // Iterate through the instructions block by block until the end of the
  236. // function is reached.
  237. while (instruction_index < num_instructions &&
  238. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  239. // Invariant: we should always be at a label instruction at this point.
  240. assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
  241. // Make a basic block using the label instruction, with the new function
  242. // as its parent.
  243. std::unique_ptr<opt::BasicBlock> block =
  244. MakeUnique<opt::BasicBlock>(InstructionFromMessage(
  245. context, message_.instruction(instruction_index)));
  246. block->SetParent(new_function.get());
  247. // Consider successive instructions until we hit another label or the end
  248. // of the function, adding each such instruction to the block.
  249. instruction_index++;
  250. while (instruction_index < num_instructions &&
  251. message_.instruction(instruction_index).opcode() !=
  252. SpvOpFunctionEnd &&
  253. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  254. block->AddInstruction(InstructionFromMessage(
  255. context, message_.instruction(instruction_index)));
  256. instruction_index++;
  257. }
  258. // Add the block to the new function.
  259. new_function->AddBasicBlock(std::move(block));
  260. }
  261. // Having considered all the blocks, we should be at the last instruction and
  262. // it needs to be OpFunctionEnd.
  263. if (instruction_index != num_instructions - 1 ||
  264. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  265. return false;
  266. }
  267. // Set the function's final instruction, add the function to the module and
  268. // report success.
  269. new_function->SetFunctionEnd(
  270. InstructionFromMessage(context, message_.instruction(instruction_index)));
  271. context->AddFunction(std::move(new_function));
  272. context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  273. return true;
  274. }
  275. bool TransformationAddFunction::TryToMakeFunctionLivesafe(
  276. opt::IRContext* context, const FactManager& fact_manager) const {
  277. assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
  278. // Get a pointer to the added function.
  279. opt::Function* added_function = nullptr;
  280. for (auto& function : *context->module()) {
  281. if (function.result_id() == message_.instruction(0).result_id()) {
  282. added_function = &function;
  283. break;
  284. }
  285. }
  286. assert(added_function && "The added function should have been found.");
  287. if (!TryToAddLoopLimiters(context, added_function)) {
  288. // Adding loop limiters did not work; bail out.
  289. return false;
  290. }
  291. // Consider all the instructions in the function, and:
  292. // - attempt to replace OpKill and OpUnreachable with return instructions
  293. // - attempt to clamp access chains to be within bounds
  294. // - check that OpFunctionCall instructions are only to livesafe functions
  295. for (auto& block : *added_function) {
  296. for (auto& inst : block) {
  297. switch (inst.opcode()) {
  298. case SpvOpKill:
  299. case SpvOpUnreachable:
  300. if (!TryToTurnKillOrUnreachableIntoReturn(context, added_function,
  301. &inst)) {
  302. return false;
  303. }
  304. break;
  305. case SpvOpAccessChain:
  306. case SpvOpInBoundsAccessChain:
  307. if (!TryToClampAccessChainIndices(context, &inst)) {
  308. return false;
  309. }
  310. break;
  311. case SpvOpFunctionCall:
  312. // A livesafe function my only call other livesafe functions.
  313. if (!fact_manager.FunctionIsLivesafe(
  314. inst.GetSingleWordInOperand(0))) {
  315. return false;
  316. }
  317. default:
  318. break;
  319. }
  320. }
  321. }
  322. return true;
  323. }
  324. bool TransformationAddFunction::TryToAddLoopLimiters(
  325. opt::IRContext* context, opt::Function* added_function) const {
  326. // Collect up all the loop headers so that we can subsequently add loop
  327. // limiting logic.
  328. std::vector<opt::BasicBlock*> loop_headers;
  329. for (auto& block : *added_function) {
  330. if (block.IsLoopHeader()) {
  331. loop_headers.push_back(&block);
  332. }
  333. }
  334. if (loop_headers.empty()) {
  335. // There are no loops, so no need to add any loop limiters.
  336. return true;
  337. }
  338. // Check that the module contains appropriate ingredients for declaring and
  339. // manipulating a loop limiter.
  340. auto loop_limit_constant_id_instr =
  341. context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  342. if (!loop_limit_constant_id_instr ||
  343. loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
  344. // The loop limit constant id instruction must exist and have an
  345. // appropriate opcode.
  346. return false;
  347. }
  348. auto loop_limit_type = context->get_def_use_mgr()->GetDef(
  349. loop_limit_constant_id_instr->type_id());
  350. if (loop_limit_type->opcode() != SpvOpTypeInt ||
  351. loop_limit_type->GetSingleWordInOperand(0) != 32) {
  352. // The type of the loop limit constant must be 32-bit integer. It
  353. // doesn't actually matter whether the integer is signed or not.
  354. return false;
  355. }
  356. // Find the id of the "unsigned int" type.
  357. opt::analysis::Integer unsigned_int_type(32, false);
  358. uint32_t unsigned_int_type_id =
  359. context->get_type_mgr()->GetId(&unsigned_int_type);
  360. if (!unsigned_int_type_id) {
  361. // Unsigned int is not available; we need this type in order to add loop
  362. // limiters.
  363. return false;
  364. }
  365. auto registered_unsigned_int_type =
  366. context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
  367. // Look for 0 of type unsigned int.
  368. opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
  369. {0});
  370. auto registered_zero = context->get_constant_mgr()->FindConstant(&zero);
  371. if (!registered_zero) {
  372. // We need 0 in order to be able to initialize loop limiters.
  373. return false;
  374. }
  375. uint32_t zero_id = context->get_constant_mgr()
  376. ->GetDefiningInstruction(registered_zero)
  377. ->result_id();
  378. // Look for 1 of type unsigned int.
  379. opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
  380. {1});
  381. auto registered_one = context->get_constant_mgr()->FindConstant(&one);
  382. if (!registered_one) {
  383. // We need 1 in order to be able to increment loop limiters.
  384. return false;
  385. }
  386. uint32_t one_id = context->get_constant_mgr()
  387. ->GetDefiningInstruction(registered_one)
  388. ->result_id();
  389. // Look for pointer-to-unsigned int type.
  390. opt::analysis::Pointer pointer_to_unsigned_int_type(
  391. registered_unsigned_int_type, SpvStorageClassFunction);
  392. uint32_t pointer_to_unsigned_int_type_id =
  393. context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  394. if (!pointer_to_unsigned_int_type_id) {
  395. // We need pointer-to-unsigned int in order to declare the loop limiter
  396. // variable.
  397. return false;
  398. }
  399. // Look for bool type.
  400. opt::analysis::Bool bool_type;
  401. uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
  402. if (!bool_type_id) {
  403. // We need bool in order to compare the loop limiter's value with the loop
  404. // limit constant.
  405. return false;
  406. }
  407. // Declare the loop limiter variable at the start of the function's entry
  408. // block, via an instruction of the form:
  409. // %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
  410. added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
  411. context, SpvOpVariable, pointer_to_unsigned_int_type_id,
  412. message_.loop_limiter_variable_id(),
  413. opt::Instruction::OperandList(
  414. {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  415. {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  416. // Update the module's id bound since we have added the loop limiter
  417. // variable id.
  418. fuzzerutil::UpdateModuleIdBound(context, message_.loop_limiter_variable_id());
  419. // Consider each loop in turn.
  420. for (auto loop_header : loop_headers) {
  421. // Look for the loop's back-edge block. This is a predecessor of the loop
  422. // header that is dominated by the loop header.
  423. uint32_t back_edge_block_id = 0;
  424. for (auto pred : context->cfg()->preds(loop_header->id())) {
  425. if (context->GetDominatorAnalysis(added_function)
  426. ->Dominates(loop_header->id(), pred)) {
  427. back_edge_block_id = pred;
  428. break;
  429. }
  430. }
  431. if (!back_edge_block_id) {
  432. // The loop's back-edge block must be unreachable. This means that the
  433. // loop cannot iterate, so there is no need to make it lifesafe; we can
  434. // move on from this loop.
  435. continue;
  436. }
  437. auto back_edge_block = context->cfg()->block(back_edge_block_id);
  438. // Go through the sequence of loop limiter infos and find the one
  439. // corresponding to this loop.
  440. bool found = false;
  441. protobufs::LoopLimiterInfo loop_limiter_info;
  442. for (auto& info : message_.loop_limiter_info()) {
  443. if (info.loop_header_id() == loop_header->id()) {
  444. loop_limiter_info = info;
  445. found = true;
  446. break;
  447. }
  448. }
  449. if (!found) {
  450. // We don't have loop limiter info for this loop header.
  451. return false;
  452. }
  453. // The back-edge block either has the form:
  454. //
  455. // (1)
  456. //
  457. // %l = OpLabel
  458. // ... instructions ...
  459. // OpBranch %loop_header
  460. //
  461. // (2)
  462. //
  463. // %l = OpLabel
  464. // ... instructions ...
  465. // OpBranchConditional %c %loop_header %loop_merge
  466. //
  467. // (3)
  468. //
  469. // %l = OpLabel
  470. // ... instructions ...
  471. // OpBranchConditional %c %loop_merge %loop_header
  472. //
  473. // We turn these into the following:
  474. //
  475. // (1)
  476. //
  477. // %l = OpLabel
  478. // ... instructions ...
  479. // %t1 = OpLoad %uint32 %loop_limiter
  480. // %t2 = OpIAdd %uint32 %t1 %one
  481. // OpStore %loop_limiter %t2
  482. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  483. // OpBranchConditional %t3 %loop_merge %loop_header
  484. //
  485. // (2)
  486. //
  487. // %l = OpLabel
  488. // ... instructions ...
  489. // %t1 = OpLoad %uint32 %loop_limiter
  490. // %t2 = OpIAdd %uint32 %t1 %one
  491. // OpStore %loop_limiter %t2
  492. // %t3 = OpULessThan %bool %t1 %loop_limit
  493. // %t4 = OpLogicalAnd %bool %c %t3
  494. // OpBranchConditional %t4 %loop_header %loop_merge
  495. //
  496. // (3)
  497. //
  498. // %l = OpLabel
  499. // ... instructions ...
  500. // %t1 = OpLoad %uint32 %loop_limiter
  501. // %t2 = OpIAdd %uint32 %t1 %one
  502. // OpStore %loop_limiter %t2
  503. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  504. // %t4 = OpLogicalOr %bool %c %t3
  505. // OpBranchConditional %t4 %loop_merge %loop_header
  506. auto back_edge_block_terminator = back_edge_block->terminator();
  507. bool compare_using_greater_than_equal;
  508. if (back_edge_block_terminator->opcode() == SpvOpBranch) {
  509. compare_using_greater_than_equal = true;
  510. } else {
  511. assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
  512. assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
  513. loop_header->id() &&
  514. back_edge_block_terminator->GetSingleWordInOperand(2) ==
  515. loop_header->MergeBlockId()) ||
  516. (back_edge_block_terminator->GetSingleWordInOperand(2) ==
  517. loop_header->id() &&
  518. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  519. loop_header->MergeBlockId())) &&
  520. "A back edge edge block must branch to"
  521. " either the loop header or merge");
  522. compare_using_greater_than_equal =
  523. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  524. loop_header->MergeBlockId();
  525. }
  526. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  527. // Add a load from the loop limiter variable, of the form:
  528. // %t1 = OpLoad %uint32 %loop_limiter
  529. new_instructions.push_back(MakeUnique<opt::Instruction>(
  530. context, SpvOpLoad, unsigned_int_type_id, loop_limiter_info.load_id(),
  531. opt::Instruction::OperandList(
  532. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
  533. // Increment the loaded value:
  534. // %t2 = OpIAdd %uint32 %t1 %one
  535. new_instructions.push_back(MakeUnique<opt::Instruction>(
  536. context, SpvOpIAdd, unsigned_int_type_id,
  537. loop_limiter_info.increment_id(),
  538. opt::Instruction::OperandList(
  539. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  540. {SPV_OPERAND_TYPE_ID, {one_id}}})));
  541. // Store the incremented value back to the loop limiter variable:
  542. // OpStore %loop_limiter %t2
  543. new_instructions.push_back(MakeUnique<opt::Instruction>(
  544. context, SpvOpStore, 0, 0,
  545. opt::Instruction::OperandList(
  546. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
  547. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
  548. // Compare the loaded value with the loop limit; either:
  549. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  550. // or
  551. // %t3 = OpULessThan %bool %t1 %loop_limit
  552. new_instructions.push_back(MakeUnique<opt::Instruction>(
  553. context,
  554. compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
  555. : SpvOpULessThan,
  556. bool_type_id, loop_limiter_info.compare_id(),
  557. opt::Instruction::OperandList(
  558. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  559. {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
  560. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  561. new_instructions.push_back(MakeUnique<opt::Instruction>(
  562. context,
  563. compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
  564. bool_type_id, loop_limiter_info.logical_op_id(),
  565. opt::Instruction::OperandList(
  566. {{SPV_OPERAND_TYPE_ID,
  567. {back_edge_block_terminator->GetSingleWordInOperand(0)}},
  568. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
  569. }
  570. // Add the new instructions at the end of the back edge block, before the
  571. // terminator and any loop merge instruction (as the back edge block can
  572. // be the loop header).
  573. if (back_edge_block->GetLoopMergeInst()) {
  574. back_edge_block->GetLoopMergeInst()->InsertBefore(
  575. std::move(new_instructions));
  576. } else {
  577. back_edge_block_terminator->InsertBefore(std::move(new_instructions));
  578. }
  579. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  580. back_edge_block_terminator->SetInOperand(
  581. 0, {loop_limiter_info.logical_op_id()});
  582. } else {
  583. assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
  584. "Back-edge terminator must be OpBranch or OpBranchConditional");
  585. // Check that, if the merge block starts with OpPhi instructions, suitable
  586. // ids have been provided to give these instructions a value corresponding
  587. // to the new incoming edge from the back edge block.
  588. auto merge_block = context->cfg()->block(loop_header->MergeBlockId());
  589. if (!fuzzerutil::PhiIdsOkForNewEdge(context, back_edge_block, merge_block,
  590. loop_limiter_info.phi_id())) {
  591. return false;
  592. }
  593. // Augment OpPhi instructions at the loop merge with the given ids.
  594. uint32_t phi_index = 0;
  595. for (auto& inst : *merge_block) {
  596. if (inst.opcode() != SpvOpPhi) {
  597. break;
  598. }
  599. assert(phi_index <
  600. static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
  601. "There should be at least one phi id per OpPhi instruction.");
  602. inst.AddOperand(
  603. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
  604. inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
  605. phi_index++;
  606. }
  607. // Add the new edge, by changing OpBranch to OpBranchConditional.
  608. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3162): This
  609. // could be a problem if the merge block was originally unreachable: it
  610. // might now be dominated by other blocks that it appears earlier than in
  611. // the module.
  612. back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
  613. back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
  614. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
  615. {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}
  616. },
  617. {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
  618. }
  619. // Update the module's id bound with respect to the various ids that
  620. // have been used for loop limiter manipulation.
  621. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.load_id());
  622. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.increment_id());
  623. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.compare_id());
  624. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.logical_op_id());
  625. }
  626. return true;
  627. }
  628. bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
  629. opt::IRContext* context, opt::Function* added_function,
  630. opt::Instruction* kill_or_unreachable_inst) const {
  631. assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
  632. kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
  633. "Precondition: instruction must be OpKill or OpUnreachable.");
  634. // Get the function's return type.
  635. auto function_return_type_inst =
  636. context->get_def_use_mgr()->GetDef(added_function->type_id());
  637. if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
  638. // The function has void return type, so change this instruction to
  639. // OpReturn.
  640. kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
  641. } else {
  642. // The function has non-void return type, so change this instruction
  643. // to OpReturnValue, using the value id provided with the
  644. // transformation.
  645. // We first check that the id, %id, provided with the transformation
  646. // specifically to turn OpKill and OpUnreachable instructions into
  647. // OpReturnValue %id has the same type as the function's return type.
  648. if (context->get_def_use_mgr()
  649. ->GetDef(message_.kill_unreachable_return_value_id())
  650. ->type_id() != function_return_type_inst->result_id()) {
  651. return false;
  652. }
  653. kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
  654. kill_or_unreachable_inst->SetInOperands(
  655. {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
  656. }
  657. return true;
  658. }
  659. bool TransformationAddFunction::TryToClampAccessChainIndices(
  660. opt::IRContext* context, opt::Instruction* access_chain_inst) const {
  661. assert((access_chain_inst->opcode() == SpvOpAccessChain ||
  662. access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
  663. "Precondition: instruction must be OpAccessChain or "
  664. "OpInBoundsAccessChain.");
  665. // Find the AccessChainClampingInfo associated with this access chain.
  666. const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
  667. nullptr;
  668. for (auto& clamping_info : message_.access_chain_clamping_info()) {
  669. if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
  670. access_chain_clamping_info = &clamping_info;
  671. break;
  672. }
  673. }
  674. if (!access_chain_clamping_info) {
  675. // No access chain clamping information was found; the function cannot be
  676. // made livesafe.
  677. return false;
  678. }
  679. // Check that there is a (compare_id, select_id) pair for every
  680. // index associated with the instruction.
  681. if (static_cast<uint32_t>(
  682. access_chain_clamping_info->compare_and_select_ids().size()) !=
  683. access_chain_inst->NumInOperands() - 1) {
  684. return false;
  685. }
  686. // Walk the access chain, clamping each index to be within bounds if it is
  687. // not a constant.
  688. auto base_object = context->get_def_use_mgr()->GetDef(
  689. access_chain_inst->GetSingleWordInOperand(0));
  690. assert(base_object && "The base object must exist.");
  691. auto pointer_type =
  692. context->get_def_use_mgr()->GetDef(base_object->type_id());
  693. assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
  694. "The base object must have pointer type.");
  695. auto should_be_composite_type = context->get_def_use_mgr()->GetDef(
  696. pointer_type->GetSingleWordInOperand(1));
  697. // Consider each index input operand in turn (operand 0 is the base object).
  698. for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
  699. index++) {
  700. // We are going to turn:
  701. //
  702. // %result = OpAccessChain %type %object ... %index ...
  703. //
  704. // into:
  705. //
  706. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  707. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  708. // %result = OpAccessChain %type %object ... %t2 ...
  709. //
  710. // ... unless %index is already a constant.
  711. // Get the bound for the composite being indexed into; e.g. the number of
  712. // columns of matrix or the size of an array.
  713. uint32_t bound =
  714. GetBoundForCompositeIndex(context, *should_be_composite_type);
  715. // Get the instruction associated with the index and figure out its integer
  716. // type.
  717. const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
  718. auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
  719. auto index_type_inst =
  720. context->get_def_use_mgr()->GetDef(index_inst->type_id());
  721. assert(index_type_inst->opcode() == SpvOpTypeInt);
  722. assert(index_type_inst->GetSingleWordInOperand(0) == 32);
  723. opt::analysis::Integer* index_int_type =
  724. context->get_type_mgr()
  725. ->GetType(index_type_inst->result_id())
  726. ->AsInteger();
  727. if (index_inst->opcode() != SpvOpConstant) {
  728. // The index is non-constant so we need to clamp it.
  729. assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
  730. "Access chain indices into structures are required to be "
  731. "constants.");
  732. opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
  733. if (!context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
  734. // We do not have an integer constant whose value is |bound| -1.
  735. return false;
  736. }
  737. opt::analysis::Bool bool_type;
  738. uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
  739. if (!bool_type_id) {
  740. // Bool type is not declared; we cannot do a comparison.
  741. return false;
  742. }
  743. uint32_t bound_minus_one_id =
  744. context->get_constant_mgr()
  745. ->GetDefiningInstruction(&bound_minus_one)
  746. ->result_id();
  747. uint32_t compare_id =
  748. access_chain_clamping_info->compare_and_select_ids(index - 1).first();
  749. uint32_t select_id =
  750. access_chain_clamping_info->compare_and_select_ids(index - 1)
  751. .second();
  752. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  753. // Compare the index with the bound via an instruction of the form:
  754. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  755. new_instructions.push_back(MakeUnique<opt::Instruction>(
  756. context, SpvOpULessThanEqual, bool_type_id, compare_id,
  757. opt::Instruction::OperandList(
  758. {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  759. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  760. // Select the index if in-bounds, otherwise one less than the bound:
  761. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  762. new_instructions.push_back(MakeUnique<opt::Instruction>(
  763. context, SpvOpSelect, index_type_inst->result_id(), select_id,
  764. opt::Instruction::OperandList(
  765. {{SPV_OPERAND_TYPE_ID, {compare_id}},
  766. {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  767. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  768. // Add the new instructions before the access chain
  769. access_chain_inst->InsertBefore(std::move(new_instructions));
  770. // Replace %index with %t2.
  771. access_chain_inst->SetInOperand(index, {select_id});
  772. fuzzerutil::UpdateModuleIdBound(context, compare_id);
  773. fuzzerutil::UpdateModuleIdBound(context, select_id);
  774. } else {
  775. // TODO(afd): At present the SPIR-V spec is not clear on whether
  776. // statically out-of-bounds indices mean that a module is invalid (so
  777. // that it should be rejected by the validator), or that such accesses
  778. // yield undefined results. Via the following assertion, we assume that
  779. // functions added to the module do not feature statically out-of-bounds
  780. // accesses.
  781. // Assert that the index is smaller (unsigned) than this value.
  782. // Return false if it is not (to keep compilers happy).
  783. if (index_inst->GetSingleWordInOperand(0) >= bound) {
  784. assert(false &&
  785. "The function has a statically out-of-bounds access; "
  786. "this should not occur.");
  787. return false;
  788. }
  789. }
  790. should_be_composite_type =
  791. FollowCompositeIndex(context, *should_be_composite_type, index_id);
  792. }
  793. return true;
  794. }
  795. uint32_t TransformationAddFunction::GetBoundForCompositeIndex(
  796. opt::IRContext* context, const opt::Instruction& composite_type_inst) {
  797. switch (composite_type_inst.opcode()) {
  798. case SpvOpTypeArray:
  799. return fuzzerutil::GetArraySize(composite_type_inst, context);
  800. case SpvOpTypeMatrix:
  801. case SpvOpTypeVector:
  802. return composite_type_inst.GetSingleWordInOperand(1);
  803. case SpvOpTypeStruct: {
  804. return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
  805. }
  806. default:
  807. assert(false && "Unknown composite type.");
  808. return 0;
  809. }
  810. }
  811. opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
  812. opt::IRContext* context, const opt::Instruction& composite_type_inst,
  813. uint32_t index_id) {
  814. uint32_t sub_object_type_id;
  815. switch (composite_type_inst.opcode()) {
  816. case SpvOpTypeArray:
  817. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  818. break;
  819. case SpvOpTypeMatrix:
  820. case SpvOpTypeVector:
  821. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  822. break;
  823. case SpvOpTypeStruct: {
  824. auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
  825. assert(index_inst->opcode() == SpvOpConstant);
  826. assert(
  827. context->get_def_use_mgr()->GetDef(index_inst->type_id())->opcode() ==
  828. SpvOpTypeInt);
  829. assert(context->get_def_use_mgr()
  830. ->GetDef(index_inst->type_id())
  831. ->GetSingleWordInOperand(0) == 32);
  832. uint32_t index_value = index_inst->GetSingleWordInOperand(0);
  833. sub_object_type_id =
  834. composite_type_inst.GetSingleWordInOperand(index_value);
  835. break;
  836. }
  837. default:
  838. assert(false && "Unknown composite type.");
  839. sub_object_type_id = 0;
  840. break;
  841. }
  842. assert(sub_object_type_id && "No sub-object found.");
  843. return context->get_def_use_mgr()->GetDef(sub_object_type_id);
  844. }
  845. } // namespace fuzz
  846. } // namespace spvtools