transformation_add_function.cpp 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/fuzz/transformation_add_function.h"
  15. #include "source/fuzz/fuzzer_util.h"
  16. #include "source/fuzz/instruction_message.h"
  17. namespace spvtools {
  18. namespace fuzz {
  19. TransformationAddFunction::TransformationAddFunction(
  20. const spvtools::fuzz::protobufs::TransformationAddFunction& message)
  21. : message_(message) {}
  22. TransformationAddFunction::TransformationAddFunction(
  23. const std::vector<protobufs::Instruction>& instructions) {
  24. for (auto& instruction : instructions) {
  25. *message_.add_instruction() = instruction;
  26. }
  27. message_.set_is_livesafe(false);
  28. }
  29. TransformationAddFunction::TransformationAddFunction(
  30. const std::vector<protobufs::Instruction>& instructions,
  31. uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
  32. const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
  33. uint32_t kill_unreachable_return_value_id,
  34. const std::vector<protobufs::AccessChainClampingInfo>&
  35. access_chain_clampers) {
  36. for (auto& instruction : instructions) {
  37. *message_.add_instruction() = instruction;
  38. }
  39. message_.set_is_livesafe(true);
  40. message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
  41. message_.set_loop_limit_constant_id(loop_limit_constant_id);
  42. for (auto& loop_limiter : loop_limiters) {
  43. *message_.add_loop_limiter_info() = loop_limiter;
  44. }
  45. message_.set_kill_unreachable_return_value_id(
  46. kill_unreachable_return_value_id);
  47. for (auto& access_clamper : access_chain_clampers) {
  48. *message_.add_access_chain_clamping_info() = access_clamper;
  49. }
  50. }
  51. bool TransformationAddFunction::IsApplicable(
  52. opt::IRContext* context,
  53. const spvtools::fuzz::FactManager& fact_manager) const {
  54. // This transformation may use a lot of ids, all of which need to be fresh
  55. // and distinct. This set tracks them.
  56. std::set<uint32_t> ids_used_by_this_transformation;
  57. // Ensure that all result ids in the new function are fresh and distinct.
  58. for (auto& instruction : message_.instruction()) {
  59. if (instruction.result_id()) {
  60. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  61. instruction.result_id(), context,
  62. &ids_used_by_this_transformation)) {
  63. return false;
  64. }
  65. }
  66. }
  67. if (message_.is_livesafe()) {
  68. // Ensure that all ids provided for making the function livesafe are fresh
  69. // and distinct.
  70. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  71. message_.loop_limiter_variable_id(), context,
  72. &ids_used_by_this_transformation)) {
  73. return false;
  74. }
  75. for (auto& loop_limiter_info : message_.loop_limiter_info()) {
  76. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  77. loop_limiter_info.load_id(), context,
  78. &ids_used_by_this_transformation)) {
  79. return false;
  80. }
  81. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  82. loop_limiter_info.increment_id(), context,
  83. &ids_used_by_this_transformation)) {
  84. return false;
  85. }
  86. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  87. loop_limiter_info.compare_id(), context,
  88. &ids_used_by_this_transformation)) {
  89. return false;
  90. }
  91. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  92. loop_limiter_info.logical_op_id(), context,
  93. &ids_used_by_this_transformation)) {
  94. return false;
  95. }
  96. }
  97. for (auto& access_chain_clamping_info :
  98. message_.access_chain_clamping_info()) {
  99. for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
  100. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  101. pair.first(), context, &ids_used_by_this_transformation)) {
  102. return false;
  103. }
  104. if (!CheckIdIsFreshAndNotUsedByThisTransformation(
  105. pair.second(), context, &ids_used_by_this_transformation)) {
  106. return false;
  107. }
  108. }
  109. }
  110. }
  111. // Because checking all the conditions for a function to be valid is a big
  112. // job that the SPIR-V validator can already do, a "try it and see" approach
  113. // is taken here.
  114. // We first clone the current module, so that we can try adding the new
  115. // function without risking wrecking |context|.
  116. auto cloned_module = fuzzerutil::CloneIRContext(context);
  117. // We try to add a function to the cloned module, which may fail if
  118. // |message_.instruction| is not sufficiently well-formed.
  119. if (!TryToAddFunction(cloned_module.get())) {
  120. return false;
  121. }
  122. if (message_.is_livesafe()) {
  123. // We make the cloned module livesafe.
  124. if (!TryToMakeFunctionLivesafe(cloned_module.get(), fact_manager)) {
  125. return false;
  126. }
  127. }
  128. // Having managed to add the new function to the cloned module, and
  129. // potentially also made it livesafe, we ascertain whether the cloned module
  130. // is still valid. If it is, the transformation is applicable.
  131. return fuzzerutil::IsValid(cloned_module.get());
  132. }
  133. void TransformationAddFunction::Apply(
  134. opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const {
  135. // Add the function to the module. As the transformation is applicable, this
  136. // should succeed.
  137. bool success = TryToAddFunction(context);
  138. assert(success && "The function should be successfully added.");
  139. (void)(success); // Keep release builds happy (otherwise they may complain
  140. // that |success| is not used).
  141. // Record the fact that all pointer parameters and variables declared in the
  142. // function should be regarded as having arbitrary values. This allows other
  143. // passes to store arbitrarily to such variables, and to pass them freely as
  144. // parameters to other functions knowing that it is OK if they get
  145. // over-written.
  146. for (auto& instruction : message_.instruction()) {
  147. switch (instruction.opcode()) {
  148. case SpvOpFunctionParameter:
  149. if (context->get_def_use_mgr()
  150. ->GetDef(instruction.result_type_id())
  151. ->opcode() == SpvOpTypePointer) {
  152. fact_manager->AddFactValueOfVariableIsArbitrary(
  153. instruction.result_id());
  154. }
  155. break;
  156. case SpvOpVariable:
  157. fact_manager->AddFactValueOfVariableIsArbitrary(
  158. instruction.result_id());
  159. break;
  160. default:
  161. break;
  162. }
  163. }
  164. if (message_.is_livesafe()) {
  165. // Make the function livesafe, which also should succeed.
  166. success = TryToMakeFunctionLivesafe(context, *fact_manager);
  167. assert(success && "It should be possible to make the function livesafe.");
  168. (void)(success); // Keep release builds happy.
  169. // Inform the fact manager that the function is livesafe.
  170. assert(message_.instruction(0).opcode() == SpvOpFunction &&
  171. "The first instruction of an 'add function' transformation must be "
  172. "OpFunction.");
  173. fact_manager->AddFactFunctionIsLivesafe(
  174. message_.instruction(0).result_id());
  175. } else {
  176. // Inform the fact manager that all blocks in the function are dead.
  177. for (auto& inst : message_.instruction()) {
  178. if (inst.opcode() == SpvOpLabel) {
  179. fact_manager->AddFactBlockIsDead(inst.result_id());
  180. }
  181. }
  182. }
  183. context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  184. }
  185. protobufs::Transformation TransformationAddFunction::ToMessage() const {
  186. protobufs::Transformation result;
  187. *result.mutable_add_function() = message_;
  188. return result;
  189. }
  190. bool TransformationAddFunction::TryToAddFunction(
  191. opt::IRContext* context) const {
  192. // This function returns false if |message_.instruction| was not well-formed
  193. // enough to actually create a function and add it to |context|.
  194. // A function must have at least some instructions.
  195. if (message_.instruction().empty()) {
  196. return false;
  197. }
  198. // A function must start with OpFunction.
  199. auto function_begin = message_.instruction(0);
  200. if (function_begin.opcode() != SpvOpFunction) {
  201. return false;
  202. }
  203. // Make a function, headed by the OpFunction instruction.
  204. std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
  205. InstructionFromMessage(context, function_begin));
  206. // Keeps track of which instruction protobuf message we are currently
  207. // considering.
  208. uint32_t instruction_index = 1;
  209. const auto num_instructions =
  210. static_cast<uint32_t>(message_.instruction().size());
  211. // Iterate through all function parameter instructions, adding parameters to
  212. // the new function.
  213. while (instruction_index < num_instructions &&
  214. message_.instruction(instruction_index).opcode() ==
  215. SpvOpFunctionParameter) {
  216. new_function->AddParameter(InstructionFromMessage(
  217. context, message_.instruction(instruction_index)));
  218. instruction_index++;
  219. }
  220. // After the parameters, there needs to be a label.
  221. if (instruction_index == num_instructions ||
  222. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  223. return false;
  224. }
  225. // Iterate through the instructions block by block until the end of the
  226. // function is reached.
  227. while (instruction_index < num_instructions &&
  228. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  229. // Invariant: we should always be at a label instruction at this point.
  230. assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
  231. // Make a basic block using the label instruction, with the new function
  232. // as its parent.
  233. std::unique_ptr<opt::BasicBlock> block =
  234. MakeUnique<opt::BasicBlock>(InstructionFromMessage(
  235. context, message_.instruction(instruction_index)));
  236. block->SetParent(new_function.get());
  237. // Consider successive instructions until we hit another label or the end
  238. // of the function, adding each such instruction to the block.
  239. instruction_index++;
  240. while (instruction_index < num_instructions &&
  241. message_.instruction(instruction_index).opcode() !=
  242. SpvOpFunctionEnd &&
  243. message_.instruction(instruction_index).opcode() != SpvOpLabel) {
  244. block->AddInstruction(InstructionFromMessage(
  245. context, message_.instruction(instruction_index)));
  246. instruction_index++;
  247. }
  248. // Add the block to the new function.
  249. new_function->AddBasicBlock(std::move(block));
  250. }
  251. // Having considered all the blocks, we should be at the last instruction and
  252. // it needs to be OpFunctionEnd.
  253. if (instruction_index != num_instructions - 1 ||
  254. message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
  255. return false;
  256. }
  257. // Set the function's final instruction, add the function to the module and
  258. // report success.
  259. new_function->SetFunctionEnd(
  260. InstructionFromMessage(context, message_.instruction(instruction_index)));
  261. context->AddFunction(std::move(new_function));
  262. context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
  263. return true;
  264. }
  265. bool TransformationAddFunction::TryToMakeFunctionLivesafe(
  266. opt::IRContext* context, const FactManager& fact_manager) const {
  267. assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
  268. // Get a pointer to the added function.
  269. opt::Function* added_function = nullptr;
  270. for (auto& function : *context->module()) {
  271. if (function.result_id() == message_.instruction(0).result_id()) {
  272. added_function = &function;
  273. break;
  274. }
  275. }
  276. assert(added_function && "The added function should have been found.");
  277. if (!TryToAddLoopLimiters(context, added_function)) {
  278. // Adding loop limiters did not work; bail out.
  279. return false;
  280. }
  281. // Consider all the instructions in the function, and:
  282. // - attempt to replace OpKill and OpUnreachable with return instructions
  283. // - attempt to clamp access chains to be within bounds
  284. // - check that OpFunctionCall instructions are only to livesafe functions
  285. for (auto& block : *added_function) {
  286. for (auto& inst : block) {
  287. switch (inst.opcode()) {
  288. case SpvOpKill:
  289. case SpvOpUnreachable:
  290. if (!TryToTurnKillOrUnreachableIntoReturn(context, added_function,
  291. &inst)) {
  292. return false;
  293. }
  294. break;
  295. case SpvOpAccessChain:
  296. case SpvOpInBoundsAccessChain:
  297. if (!TryToClampAccessChainIndices(context, &inst)) {
  298. return false;
  299. }
  300. break;
  301. case SpvOpFunctionCall:
  302. // A livesafe function my only call other livesafe functions.
  303. if (!fact_manager.FunctionIsLivesafe(
  304. inst.GetSingleWordInOperand(0))) {
  305. return false;
  306. }
  307. default:
  308. break;
  309. }
  310. }
  311. }
  312. return true;
  313. }
  314. bool TransformationAddFunction::TryToAddLoopLimiters(
  315. opt::IRContext* context, opt::Function* added_function) const {
  316. // Collect up all the loop headers so that we can subsequently add loop
  317. // limiting logic.
  318. std::vector<opt::BasicBlock*> loop_headers;
  319. for (auto& block : *added_function) {
  320. if (block.IsLoopHeader()) {
  321. loop_headers.push_back(&block);
  322. }
  323. }
  324. if (loop_headers.empty()) {
  325. // There are no loops, so no need to add any loop limiters.
  326. return true;
  327. }
  328. // Check that the module contains appropriate ingredients for declaring and
  329. // manipulating a loop limiter.
  330. auto loop_limit_constant_id_instr =
  331. context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  332. if (!loop_limit_constant_id_instr ||
  333. loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
  334. // The loop limit constant id instruction must exist and have an
  335. // appropriate opcode.
  336. return false;
  337. }
  338. auto loop_limit_type = context->get_def_use_mgr()->GetDef(
  339. loop_limit_constant_id_instr->type_id());
  340. if (loop_limit_type->opcode() != SpvOpTypeInt ||
  341. loop_limit_type->GetSingleWordInOperand(0) != 32) {
  342. // The type of the loop limit constant must be 32-bit integer. It
  343. // doesn't actually matter whether the integer is signed or not.
  344. return false;
  345. }
  346. // Find the id of the "unsigned int" type.
  347. opt::analysis::Integer unsigned_int_type(32, false);
  348. uint32_t unsigned_int_type_id =
  349. context->get_type_mgr()->GetId(&unsigned_int_type);
  350. if (!unsigned_int_type_id) {
  351. // Unsigned int is not available; we need this type in order to add loop
  352. // limiters.
  353. return false;
  354. }
  355. auto registered_unsigned_int_type =
  356. context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
  357. // Look for 0 of type unsigned int.
  358. opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
  359. {0});
  360. auto registered_zero = context->get_constant_mgr()->FindConstant(&zero);
  361. if (!registered_zero) {
  362. // We need 0 in order to be able to initialize loop limiters.
  363. return false;
  364. }
  365. uint32_t zero_id = context->get_constant_mgr()
  366. ->GetDefiningInstruction(registered_zero)
  367. ->result_id();
  368. // Look for 1 of type unsigned int.
  369. opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
  370. {1});
  371. auto registered_one = context->get_constant_mgr()->FindConstant(&one);
  372. if (!registered_one) {
  373. // We need 1 in order to be able to increment loop limiters.
  374. return false;
  375. }
  376. uint32_t one_id = context->get_constant_mgr()
  377. ->GetDefiningInstruction(registered_one)
  378. ->result_id();
  379. // Look for pointer-to-unsigned int type.
  380. opt::analysis::Pointer pointer_to_unsigned_int_type(
  381. registered_unsigned_int_type, SpvStorageClassFunction);
  382. uint32_t pointer_to_unsigned_int_type_id =
  383. context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  384. if (!pointer_to_unsigned_int_type_id) {
  385. // We need pointer-to-unsigned int in order to declare the loop limiter
  386. // variable.
  387. return false;
  388. }
  389. // Look for bool type.
  390. opt::analysis::Bool bool_type;
  391. uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
  392. if (!bool_type_id) {
  393. // We need bool in order to compare the loop limiter's value with the loop
  394. // limit constant.
  395. return false;
  396. }
  397. // Declare the loop limiter variable at the start of the function's entry
  398. // block, via an instruction of the form:
  399. // %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
  400. added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
  401. context, SpvOpVariable, pointer_to_unsigned_int_type_id,
  402. message_.loop_limiter_variable_id(),
  403. opt::Instruction::OperandList(
  404. {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  405. {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  406. // Update the module's id bound since we have added the loop limiter
  407. // variable id.
  408. fuzzerutil::UpdateModuleIdBound(context, message_.loop_limiter_variable_id());
  409. // Consider each loop in turn.
  410. for (auto loop_header : loop_headers) {
  411. // Look for the loop's back-edge block. This is a predecessor of the loop
  412. // header that is dominated by the loop header.
  413. uint32_t back_edge_block_id = 0;
  414. for (auto pred : context->cfg()->preds(loop_header->id())) {
  415. if (context->GetDominatorAnalysis(added_function)
  416. ->Dominates(loop_header->id(), pred)) {
  417. back_edge_block_id = pred;
  418. break;
  419. }
  420. }
  421. if (!back_edge_block_id) {
  422. // The loop's back-edge block must be unreachable. This means that the
  423. // loop cannot iterate, so there is no need to make it lifesafe; we can
  424. // move on from this loop.
  425. continue;
  426. }
  427. auto back_edge_block = context->cfg()->block(back_edge_block_id);
  428. // Go through the sequence of loop limiter infos and find the one
  429. // corresponding to this loop.
  430. bool found = false;
  431. protobufs::LoopLimiterInfo loop_limiter_info;
  432. for (auto& info : message_.loop_limiter_info()) {
  433. if (info.loop_header_id() == loop_header->id()) {
  434. loop_limiter_info = info;
  435. found = true;
  436. break;
  437. }
  438. }
  439. if (!found) {
  440. // We don't have loop limiter info for this loop header.
  441. return false;
  442. }
  443. // The back-edge block either has the form:
  444. //
  445. // (1)
  446. //
  447. // %l = OpLabel
  448. // ... instructions ...
  449. // OpBranch %loop_header
  450. //
  451. // (2)
  452. //
  453. // %l = OpLabel
  454. // ... instructions ...
  455. // OpBranchConditional %c %loop_header %loop_merge
  456. //
  457. // (3)
  458. //
  459. // %l = OpLabel
  460. // ... instructions ...
  461. // OpBranchConditional %c %loop_merge %loop_header
  462. //
  463. // We turn these into the following:
  464. //
  465. // (1)
  466. //
  467. // %l = OpLabel
  468. // ... instructions ...
  469. // %t1 = OpLoad %uint32 %loop_limiter
  470. // %t2 = OpIAdd %uint32 %t1 %one
  471. // OpStore %loop_limiter %t2
  472. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  473. // OpBranchConditional %t3 %loop_merge %loop_header
  474. //
  475. // (2)
  476. //
  477. // %l = OpLabel
  478. // ... instructions ...
  479. // %t1 = OpLoad %uint32 %loop_limiter
  480. // %t2 = OpIAdd %uint32 %t1 %one
  481. // OpStore %loop_limiter %t2
  482. // %t3 = OpULessThan %bool %t1 %loop_limit
  483. // %t4 = OpLogicalAnd %bool %c %t3
  484. // OpBranchConditional %t4 %loop_header %loop_merge
  485. //
  486. // (3)
  487. //
  488. // %l = OpLabel
  489. // ... instructions ...
  490. // %t1 = OpLoad %uint32 %loop_limiter
  491. // %t2 = OpIAdd %uint32 %t1 %one
  492. // OpStore %loop_limiter %t2
  493. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  494. // %t4 = OpLogicalOr %bool %c %t3
  495. // OpBranchConditional %t4 %loop_merge %loop_header
  496. auto back_edge_block_terminator = back_edge_block->terminator();
  497. bool compare_using_greater_than_equal;
  498. if (back_edge_block_terminator->opcode() == SpvOpBranch) {
  499. compare_using_greater_than_equal = true;
  500. } else {
  501. assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
  502. assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
  503. loop_header->id() &&
  504. back_edge_block_terminator->GetSingleWordInOperand(2) ==
  505. loop_header->MergeBlockId()) ||
  506. (back_edge_block_terminator->GetSingleWordInOperand(2) ==
  507. loop_header->id() &&
  508. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  509. loop_header->MergeBlockId())) &&
  510. "A back edge edge block must branch to"
  511. " either the loop header or merge");
  512. compare_using_greater_than_equal =
  513. back_edge_block_terminator->GetSingleWordInOperand(1) ==
  514. loop_header->MergeBlockId();
  515. }
  516. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  517. // Add a load from the loop limiter variable, of the form:
  518. // %t1 = OpLoad %uint32 %loop_limiter
  519. new_instructions.push_back(MakeUnique<opt::Instruction>(
  520. context, SpvOpLoad, unsigned_int_type_id, loop_limiter_info.load_id(),
  521. opt::Instruction::OperandList(
  522. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
  523. // Increment the loaded value:
  524. // %t2 = OpIAdd %uint32 %t1 %one
  525. new_instructions.push_back(MakeUnique<opt::Instruction>(
  526. context, SpvOpIAdd, unsigned_int_type_id,
  527. loop_limiter_info.increment_id(),
  528. opt::Instruction::OperandList(
  529. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  530. {SPV_OPERAND_TYPE_ID, {one_id}}})));
  531. // Store the incremented value back to the loop limiter variable:
  532. // OpStore %loop_limiter %t2
  533. new_instructions.push_back(MakeUnique<opt::Instruction>(
  534. context, SpvOpStore, 0, 0,
  535. opt::Instruction::OperandList(
  536. {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
  537. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
  538. // Compare the loaded value with the loop limit; either:
  539. // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
  540. // or
  541. // %t3 = OpULessThan %bool %t1 %loop_limit
  542. new_instructions.push_back(MakeUnique<opt::Instruction>(
  543. context,
  544. compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
  545. : SpvOpULessThan,
  546. bool_type_id, loop_limiter_info.compare_id(),
  547. opt::Instruction::OperandList(
  548. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
  549. {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
  550. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  551. new_instructions.push_back(MakeUnique<opt::Instruction>(
  552. context,
  553. compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
  554. bool_type_id, loop_limiter_info.logical_op_id(),
  555. opt::Instruction::OperandList(
  556. {{SPV_OPERAND_TYPE_ID,
  557. {back_edge_block_terminator->GetSingleWordInOperand(0)}},
  558. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
  559. }
  560. // Add the new instructions at the end of the back edge block, before the
  561. // terminator and any loop merge instruction (as the back edge block can
  562. // be the loop header).
  563. if (back_edge_block->GetLoopMergeInst()) {
  564. back_edge_block->GetLoopMergeInst()->InsertBefore(
  565. std::move(new_instructions));
  566. } else {
  567. back_edge_block_terminator->InsertBefore(std::move(new_instructions));
  568. }
  569. if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
  570. back_edge_block_terminator->SetInOperand(
  571. 0, {loop_limiter_info.logical_op_id()});
  572. } else {
  573. assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
  574. "Back-edge terminator must be OpBranch or OpBranchConditional");
  575. // Check that, if the merge block starts with OpPhi instructions, suitable
  576. // ids have been provided to give these instructions a value corresponding
  577. // to the new incoming edge from the back edge block.
  578. auto merge_block = context->cfg()->block(loop_header->MergeBlockId());
  579. if (!fuzzerutil::PhiIdsOkForNewEdge(context, back_edge_block, merge_block,
  580. loop_limiter_info.phi_id())) {
  581. return false;
  582. }
  583. // Augment OpPhi instructions at the loop merge with the given ids.
  584. uint32_t phi_index = 0;
  585. for (auto& inst : *merge_block) {
  586. if (inst.opcode() != SpvOpPhi) {
  587. break;
  588. }
  589. assert(phi_index <
  590. static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
  591. "There should be at least one phi id per OpPhi instruction.");
  592. inst.AddOperand(
  593. {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
  594. inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
  595. phi_index++;
  596. }
  597. // Add the new edge, by changing OpBranch to OpBranchConditional.
  598. // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3162): This
  599. // could be a problem if the merge block was originally unreachable: it
  600. // might now be dominated by other blocks that it appears earlier than in
  601. // the module.
  602. back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
  603. back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
  604. {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
  605. {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}
  606. },
  607. {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
  608. }
  609. // Update the module's id bound with respect to the various ids that
  610. // have been used for loop limiter manipulation.
  611. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.load_id());
  612. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.increment_id());
  613. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.compare_id());
  614. fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.logical_op_id());
  615. }
  616. return true;
  617. }
  618. bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
  619. opt::IRContext* context, opt::Function* added_function,
  620. opt::Instruction* kill_or_unreachable_inst) const {
  621. assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
  622. kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
  623. "Precondition: instruction must be OpKill or OpUnreachable.");
  624. // Get the function's return type.
  625. auto function_return_type_inst =
  626. context->get_def_use_mgr()->GetDef(added_function->type_id());
  627. if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
  628. // The function has void return type, so change this instruction to
  629. // OpReturn.
  630. kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
  631. } else {
  632. // The function has non-void return type, so change this instruction
  633. // to OpReturnValue, using the value id provided with the
  634. // transformation.
  635. // We first check that the id, %id, provided with the transformation
  636. // specifically to turn OpKill and OpUnreachable instructions into
  637. // OpReturnValue %id has the same type as the function's return type.
  638. if (context->get_def_use_mgr()
  639. ->GetDef(message_.kill_unreachable_return_value_id())
  640. ->type_id() != function_return_type_inst->result_id()) {
  641. return false;
  642. }
  643. kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
  644. kill_or_unreachable_inst->SetInOperands(
  645. {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
  646. }
  647. return true;
  648. }
  649. bool TransformationAddFunction::TryToClampAccessChainIndices(
  650. opt::IRContext* context, opt::Instruction* access_chain_inst) const {
  651. assert((access_chain_inst->opcode() == SpvOpAccessChain ||
  652. access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
  653. "Precondition: instruction must be OpAccessChain or "
  654. "OpInBoundsAccessChain.");
  655. // Find the AccessChainClampingInfo associated with this access chain.
  656. const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
  657. nullptr;
  658. for (auto& clamping_info : message_.access_chain_clamping_info()) {
  659. if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
  660. access_chain_clamping_info = &clamping_info;
  661. break;
  662. }
  663. }
  664. if (!access_chain_clamping_info) {
  665. // No access chain clamping information was found; the function cannot be
  666. // made livesafe.
  667. return false;
  668. }
  669. // Check that there is a (compare_id, select_id) pair for every
  670. // index associated with the instruction.
  671. if (static_cast<uint32_t>(
  672. access_chain_clamping_info->compare_and_select_ids().size()) !=
  673. access_chain_inst->NumInOperands() - 1) {
  674. return false;
  675. }
  676. // Walk the access chain, clamping each index to be within bounds if it is
  677. // not a constant.
  678. auto base_object = context->get_def_use_mgr()->GetDef(
  679. access_chain_inst->GetSingleWordInOperand(0));
  680. assert(base_object && "The base object must exist.");
  681. auto pointer_type =
  682. context->get_def_use_mgr()->GetDef(base_object->type_id());
  683. assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
  684. "The base object must have pointer type.");
  685. auto should_be_composite_type = context->get_def_use_mgr()->GetDef(
  686. pointer_type->GetSingleWordInOperand(1));
  687. // Consider each index input operand in turn (operand 0 is the base object).
  688. for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
  689. index++) {
  690. // We are going to turn:
  691. //
  692. // %result = OpAccessChain %type %object ... %index ...
  693. //
  694. // into:
  695. //
  696. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  697. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  698. // %result = OpAccessChain %type %object ... %t2 ...
  699. //
  700. // ... unless %index is already a constant.
  701. // Get the bound for the composite being indexed into; e.g. the number of
  702. // columns of matrix or the size of an array.
  703. uint32_t bound =
  704. GetBoundForCompositeIndex(context, *should_be_composite_type);
  705. // Get the instruction associated with the index and figure out its integer
  706. // type.
  707. const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
  708. auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
  709. auto index_type_inst =
  710. context->get_def_use_mgr()->GetDef(index_inst->type_id());
  711. assert(index_type_inst->opcode() == SpvOpTypeInt);
  712. assert(index_type_inst->GetSingleWordInOperand(0) == 32);
  713. opt::analysis::Integer* index_int_type =
  714. context->get_type_mgr()
  715. ->GetType(index_type_inst->result_id())
  716. ->AsInteger();
  717. if (index_inst->opcode() != SpvOpConstant) {
  718. // The index is non-constant so we need to clamp it.
  719. assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
  720. "Access chain indices into structures are required to be "
  721. "constants.");
  722. opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
  723. if (!context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
  724. // We do not have an integer constant whose value is |bound| -1.
  725. return false;
  726. }
  727. opt::analysis::Bool bool_type;
  728. uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
  729. if (!bool_type_id) {
  730. // Bool type is not declared; we cannot do a comparison.
  731. return false;
  732. }
  733. uint32_t bound_minus_one_id =
  734. context->get_constant_mgr()
  735. ->GetDefiningInstruction(&bound_minus_one)
  736. ->result_id();
  737. uint32_t compare_id =
  738. access_chain_clamping_info->compare_and_select_ids(index - 1).first();
  739. uint32_t select_id =
  740. access_chain_clamping_info->compare_and_select_ids(index - 1)
  741. .second();
  742. std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
  743. // Compare the index with the bound via an instruction of the form:
  744. // %t1 = OpULessThanEqual %bool %index %bound_minus_one
  745. new_instructions.push_back(MakeUnique<opt::Instruction>(
  746. context, SpvOpULessThanEqual, bool_type_id, compare_id,
  747. opt::Instruction::OperandList(
  748. {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  749. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  750. // Select the index if in-bounds, otherwise one less than the bound:
  751. // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
  752. new_instructions.push_back(MakeUnique<opt::Instruction>(
  753. context, SpvOpSelect, index_type_inst->result_id(), select_id,
  754. opt::Instruction::OperandList(
  755. {{SPV_OPERAND_TYPE_ID, {compare_id}},
  756. {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
  757. {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
  758. // Add the new instructions before the access chain
  759. access_chain_inst->InsertBefore(std::move(new_instructions));
  760. // Replace %index with %t2.
  761. access_chain_inst->SetInOperand(index, {select_id});
  762. fuzzerutil::UpdateModuleIdBound(context, compare_id);
  763. fuzzerutil::UpdateModuleIdBound(context, select_id);
  764. } else {
  765. // TODO(afd): At present the SPIR-V spec is not clear on whether
  766. // statically out-of-bounds indices mean that a module is invalid (so
  767. // that it should be rejected by the validator), or that such accesses
  768. // yield undefined results. Via the following assertion, we assume that
  769. // functions added to the module do not feature statically out-of-bounds
  770. // accesses.
  771. // Assert that the index is smaller (unsigned) than this value.
  772. // Return false if it is not (to keep compilers happy).
  773. if (index_inst->GetSingleWordInOperand(0) >= bound) {
  774. assert(false &&
  775. "The function has a statically out-of-bounds access; "
  776. "this should not occur.");
  777. return false;
  778. }
  779. }
  780. should_be_composite_type =
  781. FollowCompositeIndex(context, *should_be_composite_type, index_id);
  782. }
  783. return true;
  784. }
  785. uint32_t TransformationAddFunction::GetBoundForCompositeIndex(
  786. opt::IRContext* context, const opt::Instruction& composite_type_inst) {
  787. switch (composite_type_inst.opcode()) {
  788. case SpvOpTypeArray:
  789. return fuzzerutil::GetArraySize(composite_type_inst, context);
  790. case SpvOpTypeMatrix:
  791. case SpvOpTypeVector:
  792. return composite_type_inst.GetSingleWordInOperand(1);
  793. case SpvOpTypeStruct: {
  794. return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
  795. }
  796. default:
  797. assert(false && "Unknown composite type.");
  798. return 0;
  799. }
  800. }
  801. opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
  802. opt::IRContext* context, const opt::Instruction& composite_type_inst,
  803. uint32_t index_id) {
  804. uint32_t sub_object_type_id;
  805. switch (composite_type_inst.opcode()) {
  806. case SpvOpTypeArray:
  807. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  808. break;
  809. case SpvOpTypeMatrix:
  810. case SpvOpTypeVector:
  811. sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
  812. break;
  813. case SpvOpTypeStruct: {
  814. auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
  815. assert(index_inst->opcode() == SpvOpConstant);
  816. assert(
  817. context->get_def_use_mgr()->GetDef(index_inst->type_id())->opcode() ==
  818. SpvOpTypeInt);
  819. assert(context->get_def_use_mgr()
  820. ->GetDef(index_inst->type_id())
  821. ->GetSingleWordInOperand(0) == 32);
  822. uint32_t index_value = index_inst->GetSingleWordInOperand(0);
  823. sub_object_type_id =
  824. composite_type_inst.GetSingleWordInOperand(index_value);
  825. break;
  826. }
  827. default:
  828. assert(false && "Unknown composite type.");
  829. sub_object_type_id = 0;
  830. break;
  831. }
  832. assert(sub_object_type_id && "No sub-object found.");
  833. return context->get_def_use_mgr()->GetDef(sub_object_type_id);
  834. }
  835. } // namespace fuzz
  836. } // namespace spvtools