fix_func_call_arguments.cpp 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. // Copyright (c) 2022 Advanced Micro Devices, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "fix_func_call_arguments.h"
  15. #include "ir_builder.h"
  16. using namespace spvtools;
  17. using namespace opt;
  18. bool FixFuncCallArgumentsPass::ModuleHasASingleFunction() {
  19. auto funcsNum = get_module()->end() - get_module()->begin();
  20. return funcsNum == 1;
  21. }
  22. Pass::Status FixFuncCallArgumentsPass::Process() {
  23. bool modified = false;
  24. if (ModuleHasASingleFunction()) return Status::SuccessWithoutChange;
  25. for (auto& func : *get_module()) {
  26. func.ForEachInst([this, &modified](Instruction* inst) {
  27. if (inst->opcode() == spv::Op::OpFunctionCall) {
  28. modified |= FixFuncCallArguments(inst);
  29. }
  30. });
  31. }
  32. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  33. }
  34. bool FixFuncCallArgumentsPass::FixFuncCallArguments(
  35. Instruction* func_call_inst) {
  36. bool modified = false;
  37. for (uint32_t i = 0; i < func_call_inst->NumInOperands(); ++i) {
  38. Operand& op = func_call_inst->GetInOperand(i);
  39. if (op.type != SPV_OPERAND_TYPE_ID) continue;
  40. Instruction* operand_inst = get_def_use_mgr()->GetDef(op.AsId());
  41. if (operand_inst->opcode() == spv::Op::OpAccessChain) {
  42. uint32_t var_id =
  43. ReplaceAccessChainFuncCallArguments(func_call_inst, operand_inst);
  44. func_call_inst->SetInOperand(i, {var_id});
  45. modified = true;
  46. }
  47. }
  48. if (modified) {
  49. context()->UpdateDefUse(func_call_inst);
  50. }
  51. return modified;
  52. }
  53. uint32_t FixFuncCallArgumentsPass::ReplaceAccessChainFuncCallArguments(
  54. Instruction* func_call_inst, Instruction* operand_inst) {
  55. InstructionBuilder builder(
  56. context(), func_call_inst,
  57. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  58. Instruction* next_insert_point = func_call_inst->NextNode();
  59. // Get Variable insertion point
  60. Function* func = context()->get_instr_block(func_call_inst)->GetParent();
  61. Instruction* variable_insertion_point = &*(func->begin()->begin());
  62. Instruction* op_ptr_type = get_def_use_mgr()->GetDef(operand_inst->type_id());
  63. Instruction* op_type =
  64. get_def_use_mgr()->GetDef(op_ptr_type->GetSingleWordInOperand(1));
  65. uint32_t varType = context()->get_type_mgr()->FindPointerToType(
  66. op_type->result_id(), spv::StorageClass::Function);
  67. // Create new variable
  68. builder.SetInsertPoint(variable_insertion_point);
  69. Instruction* var =
  70. builder.AddVariable(varType, uint32_t(spv::StorageClass::Function));
  71. // Load access chain to the new variable before function call
  72. builder.SetInsertPoint(func_call_inst);
  73. uint32_t operand_id = operand_inst->result_id();
  74. Instruction* load = builder.AddLoad(op_type->result_id(), operand_id);
  75. builder.AddStore(var->result_id(), load->result_id());
  76. // Load return value to the acesschain after function call
  77. builder.SetInsertPoint(next_insert_point);
  78. load = builder.AddLoad(op_type->result_id(), var->result_id());
  79. builder.AddStore(operand_id, load->result_id());
  80. return var->result_id();
  81. }