fix_func_call_arguments.cpp 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. // Copyright (c) 2022 Advanced Micro Devices, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "fix_func_call_arguments.h"
  15. #include "ir_builder.h"
  16. using namespace spvtools;
  17. using namespace opt;
  18. bool FixFuncCallArgumentsPass::ModuleHasASingleFunction() {
  19. auto funcsNum = get_module()->end() - get_module()->begin();
  20. return funcsNum == 1;
  21. }
  22. Pass::Status FixFuncCallArgumentsPass::Process() {
  23. bool modified = false;
  24. if (ModuleHasASingleFunction()) return Status::SuccessWithoutChange;
  25. for (auto& func : *get_module()) {
  26. func.ForEachInst([this, &modified](Instruction* inst) {
  27. if (inst->opcode() == SpvOpFunctionCall) {
  28. modified |= FixFuncCallArguments(inst);
  29. }
  30. });
  31. }
  32. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  33. }
  34. bool FixFuncCallArgumentsPass::FixFuncCallArguments(
  35. Instruction* func_call_inst) {
  36. bool modified = false;
  37. for (uint32_t i = 0; i < func_call_inst->NumInOperands(); ++i) {
  38. Operand& op = func_call_inst->GetInOperand(i);
  39. if (op.type != SPV_OPERAND_TYPE_ID) continue;
  40. Instruction* operand_inst = get_def_use_mgr()->GetDef(op.AsId());
  41. if (operand_inst->opcode() == SpvOpAccessChain) {
  42. uint32_t var_id =
  43. ReplaceAccessChainFuncCallArguments(func_call_inst, operand_inst);
  44. func_call_inst->SetInOperand(i, {var_id});
  45. modified = true;
  46. }
  47. }
  48. if (modified) {
  49. context()->UpdateDefUse(func_call_inst);
  50. }
  51. return modified;
  52. }
  53. uint32_t FixFuncCallArgumentsPass::ReplaceAccessChainFuncCallArguments(
  54. Instruction* func_call_inst, Instruction* operand_inst) {
  55. InstructionBuilder builder(
  56. context(), func_call_inst,
  57. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  58. Instruction* next_insert_point = func_call_inst->NextNode();
  59. // Get Variable insertion point
  60. Function* func = context()->get_instr_block(func_call_inst)->GetParent();
  61. Instruction* variable_insertion_point = &*(func->begin()->begin());
  62. Instruction* op_ptr_type = get_def_use_mgr()->GetDef(operand_inst->type_id());
  63. Instruction* op_type =
  64. get_def_use_mgr()->GetDef(op_ptr_type->GetSingleWordInOperand(1));
  65. uint32_t varType = context()->get_type_mgr()->FindPointerToType(
  66. op_type->result_id(), SpvStorageClassFunction);
  67. // Create new variable
  68. builder.SetInsertPoint(variable_insertion_point);
  69. Instruction* var = builder.AddVariable(varType, SpvStorageClassFunction);
  70. // Load access chain to the new variable before function call
  71. builder.SetInsertPoint(func_call_inst);
  72. uint32_t operand_id = operand_inst->result_id();
  73. Instruction* load = builder.AddLoad(op_type->result_id(), operand_id);
  74. builder.AddStore(var->result_id(), load->result_id());
  75. // Load return value to the acesschain after function call
  76. builder.SetInsertPoint(next_insert_point);
  77. load = builder.AddLoad(op_type->result_id(), var->result_id());
  78. builder.AddStore(operand_id, load->result_id());
  79. return var->result_id();
  80. }