pass.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. // Copyright (c) 2017 The Khronos Group Inc.
  2. // Copyright (c) 2017 Valve Corporation
  3. // Copyright (c) 2017 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "source/opt/pass.h"
  17. #include "source/opt/iterator.h"
  18. namespace spvtools {
  19. namespace opt {
  20. namespace {
  21. const uint32_t kEntryPointFunctionIdInIdx = 1;
  22. const uint32_t kTypePointerTypeIdInIdx = 1;
  23. } // namespace
  24. Pass::Pass() : consumer_(nullptr), context_(nullptr), already_run_(false) {}
  25. void Pass::AddCalls(Function* func, std::queue<uint32_t>* todo) {
  26. for (auto bi = func->begin(); bi != func->end(); ++bi)
  27. for (auto ii = bi->begin(); ii != bi->end(); ++ii)
  28. if (ii->opcode() == SpvOpFunctionCall)
  29. todo->push(ii->GetSingleWordInOperand(0));
  30. }
  31. bool Pass::ProcessEntryPointCallTree(ProcessFunction& pfn, Module* module) {
  32. // Map from function's result id to function
  33. std::unordered_map<uint32_t, Function*> id2function;
  34. for (auto& fn : *module) id2function[fn.result_id()] = &fn;
  35. // Collect all of the entry points as the roots.
  36. std::queue<uint32_t> roots;
  37. for (auto& e : module->entry_points())
  38. roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
  39. return ProcessCallTreeFromRoots(pfn, id2function, &roots);
  40. }
  41. bool Pass::ProcessReachableCallTree(ProcessFunction& pfn,
  42. IRContext* irContext) {
  43. // Map from function's result id to function
  44. std::unordered_map<uint32_t, Function*> id2function;
  45. for (auto& fn : *irContext->module()) id2function[fn.result_id()] = &fn;
  46. std::queue<uint32_t> roots;
  47. // Add all entry points since they can be reached from outside the module.
  48. for (auto& e : irContext->module()->entry_points())
  49. roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
  50. // Add all exported functions since they can be reached from outside the
  51. // module.
  52. for (auto& a : irContext->annotations()) {
  53. // TODO: Handle group decorations as well. Currently not generate by any
  54. // front-end, but could be coming.
  55. if (a.opcode() == SpvOp::SpvOpDecorate) {
  56. if (a.GetSingleWordOperand(1) ==
  57. SpvDecoration::SpvDecorationLinkageAttributes) {
  58. uint32_t lastOperand = a.NumOperands() - 1;
  59. if (a.GetSingleWordOperand(lastOperand) ==
  60. SpvLinkageType::SpvLinkageTypeExport) {
  61. uint32_t id = a.GetSingleWordOperand(0);
  62. if (id2function.count(id) != 0) roots.push(id);
  63. }
  64. }
  65. }
  66. }
  67. return ProcessCallTreeFromRoots(pfn, id2function, &roots);
  68. }
  69. bool Pass::ProcessCallTreeFromRoots(
  70. ProcessFunction& pfn,
  71. const std::unordered_map<uint32_t, Function*>& id2function,
  72. std::queue<uint32_t>* roots) {
  73. // Process call tree
  74. bool modified = false;
  75. std::unordered_set<uint32_t> done;
  76. while (!roots->empty()) {
  77. const uint32_t fi = roots->front();
  78. roots->pop();
  79. if (done.insert(fi).second) {
  80. Function* fn = id2function.at(fi);
  81. modified = pfn(fn) || modified;
  82. AddCalls(fn, roots);
  83. }
  84. }
  85. return modified;
  86. }
  87. Pass::Status Pass::Run(IRContext* ctx) {
  88. if (already_run_) {
  89. return Status::Failure;
  90. }
  91. already_run_ = true;
  92. context_ = ctx;
  93. Pass::Status status = Process();
  94. context_ = nullptr;
  95. if (status == Status::SuccessWithChange) {
  96. ctx->InvalidateAnalysesExceptFor(GetPreservedAnalyses());
  97. }
  98. assert(ctx->IsConsistent());
  99. return status;
  100. }
  101. uint32_t Pass::GetPointeeTypeId(const Instruction* ptrInst) const {
  102. const uint32_t ptrTypeId = ptrInst->type_id();
  103. const Instruction* ptrTypeInst = get_def_use_mgr()->GetDef(ptrTypeId);
  104. return ptrTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
  105. }
  106. } // namespace opt
  107. } // namespace spvtools