opextinst_forward_ref_fixup_pass.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. // Copyright (c) 2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/opextinst_forward_ref_fixup_pass.h"
  15. #include <string>
  16. #include <unordered_set>
  17. #include "source/extensions.h"
  18. #include "source/opt/ir_context.h"
  19. #include "source/opt/module.h"
  20. #include "type_manager.h"
  21. namespace spvtools {
  22. namespace opt {
  23. namespace {
  24. // Returns true if the instruction |inst| has a forward reference to another
  25. // debug instruction.
  26. // |debug_ids| contains the list of IDs belonging to debug instructions.
  27. // |seen_ids| contains the list of IDs already seen.
  28. bool HasForwardReference(const Instruction& inst,
  29. const std::unordered_set<uint32_t>& debug_ids,
  30. const std::unordered_set<uint32_t>& seen_ids) {
  31. const uint32_t num_in_operands = inst.NumInOperands();
  32. for (uint32_t i = 0; i < num_in_operands; ++i) {
  33. const Operand& op = inst.GetInOperand(i);
  34. if (!spvIsIdType(op.type)) continue;
  35. if (debug_ids.count(op.AsId()) == 0) continue;
  36. if (seen_ids.count(op.AsId()) == 0) return true;
  37. }
  38. return false;
  39. }
  40. // Replace |inst| opcode with OpExtInstWithForwardRefsKHR or OpExtInst
  41. // if required to comply with forward references.
  42. bool ReplaceOpcodeIfRequired(Instruction& inst, bool hasForwardReferences) {
  43. if (hasForwardReferences &&
  44. inst.opcode() != spv::Op::OpExtInstWithForwardRefsKHR)
  45. inst.SetOpcode(spv::Op::OpExtInstWithForwardRefsKHR);
  46. else if (!hasForwardReferences && inst.opcode() != spv::Op::OpExtInst)
  47. inst.SetOpcode(spv::Op::OpExtInst);
  48. else
  49. return false;
  50. return true;
  51. }
  52. // Returns all the result IDs of the instructions in |range|.
  53. std::unordered_set<uint32_t> gatherResultIds(
  54. const IteratorRange<Module::inst_iterator>& range) {
  55. std::unordered_set<uint32_t> output;
  56. for (const auto& it : range) output.insert(it.result_id());
  57. return output;
  58. }
  59. } // namespace
  60. Pass::Status OpExtInstWithForwardReferenceFixupPass::Process() {
  61. std::unordered_set<uint32_t> seen_ids =
  62. gatherResultIds(get_module()->ext_inst_imports());
  63. std::unordered_set<uint32_t> debug_ids =
  64. gatherResultIds(get_module()->ext_inst_debuginfo());
  65. for (uint32_t id : seen_ids) debug_ids.insert(id);
  66. bool moduleChanged = false;
  67. bool hasAtLeastOneForwardReference = false;
  68. IRContext* ctx = context();
  69. for (Instruction& inst : get_module()->ext_inst_debuginfo()) {
  70. if (inst.opcode() != spv::Op::OpExtInst &&
  71. inst.opcode() != spv::Op::OpExtInstWithForwardRefsKHR)
  72. continue;
  73. seen_ids.insert(inst.result_id());
  74. bool hasForwardReferences = HasForwardReference(inst, debug_ids, seen_ids);
  75. hasAtLeastOneForwardReference |= hasForwardReferences;
  76. if (ReplaceOpcodeIfRequired(inst, hasForwardReferences)) {
  77. moduleChanged = true;
  78. ctx->AnalyzeUses(&inst);
  79. }
  80. }
  81. if (hasAtLeastOneForwardReference !=
  82. ctx->get_feature_mgr()->HasExtension(
  83. kSPV_KHR_relaxed_extended_instruction)) {
  84. if (hasAtLeastOneForwardReference)
  85. ctx->AddExtension("SPV_KHR_relaxed_extended_instruction");
  86. else
  87. ctx->RemoveExtension(Extension::kSPV_KHR_relaxed_extended_instruction);
  88. moduleChanged = true;
  89. }
  90. return moduleChanged ? Status::SuccessWithChange
  91. : Status::SuccessWithoutChange;
  92. }
  93. } // namespace opt
  94. } // namespace spvtools