private_to_local_pass.cpp 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "source/opt/private_to_local_pass.h"

#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/ir_context.h"
#include "source/spirv_constant.h"
  20. namespace spvtools {
  21. namespace opt {
  22. namespace {
  23. constexpr uint32_t kVariableStorageClassInIdx = 0;
  24. constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1;
  25. } // namespace
  26. Pass::Status PrivateToLocalPass::Process() {
  27. bool modified = false;
  28. // Private variables require the shader capability. If this is not a shader,
  29. // there is no work to do.
  30. if (context()->get_feature_mgr()->HasCapability(spv::Capability::Addresses))
  31. return Status::SuccessWithoutChange;
  32. std::vector<std::pair<Instruction*, Function*>> variables_to_move;
  33. std::unordered_set<uint32_t> localized_variables;
  34. for (auto& inst : context()->types_values()) {
  35. if (inst.opcode() != spv::Op::OpVariable) {
  36. continue;
  37. }
  38. if (spv::StorageClass(inst.GetSingleWordInOperand(
  39. kVariableStorageClassInIdx)) != spv::StorageClass::Private) {
  40. continue;
  41. }
  42. Function* target_function = FindLocalFunction(inst);
  43. if (target_function != nullptr) {
  44. variables_to_move.push_back({&inst, target_function});
  45. }
  46. }
  47. modified = !variables_to_move.empty();
  48. for (auto p : variables_to_move) {
  49. if (!MoveVariable(p.first, p.second)) {
  50. return Status::Failure;
  51. }
  52. localized_variables.insert(p.first->result_id());
  53. }
  54. if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
  55. // In SPIR-V 1.4 and later entry points must list private storage class
  56. // variables that are statically used by the entry point. Go through the
  57. // entry points and remove any references to variables that were localized.
  58. for (auto& entry : get_module()->entry_points()) {
  59. std::vector<Operand> new_operands;
  60. for (uint32_t i = 0; i < entry.NumInOperands(); ++i) {
  61. // Execution model, function id and name are always kept.
  62. if (i < 3 ||
  63. !localized_variables.count(entry.GetSingleWordInOperand(i))) {
  64. new_operands.push_back(entry.GetInOperand(i));
  65. }
  66. }
  67. if (new_operands.size() != entry.NumInOperands()) {
  68. entry.SetInOperands(std::move(new_operands));
  69. context()->AnalyzeUses(&entry);
  70. }
  71. }
  72. }
  73. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  74. }
  75. Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const {
  76. bool found_first_use = false;
  77. Function* target_function = nullptr;
  78. context()->get_def_use_mgr()->ForEachUser(
  79. inst.result_id(),
  80. [&target_function, &found_first_use, this](Instruction* use) {
  81. BasicBlock* current_block = context()->get_instr_block(use);
  82. if (current_block == nullptr) {
  83. return;
  84. }
  85. if (!IsValidUse(use)) {
  86. found_first_use = true;
  87. target_function = nullptr;
  88. return;
  89. }
  90. Function* current_function = current_block->GetParent();
  91. if (!found_first_use) {
  92. found_first_use = true;
  93. target_function = current_function;
  94. } else if (target_function != current_function) {
  95. target_function = nullptr;
  96. }
  97. });
  98. return target_function;
  99. } // namespace opt
  100. bool PrivateToLocalPass::MoveVariable(Instruction* variable,
  101. Function* function) {
  102. // The variable needs to be removed from the global section, and placed in the
  103. // header of the function. First step remove from the global list.
  104. variable->RemoveFromList();
  105. std::unique_ptr<Instruction> var(variable); // Take ownership.
  106. context()->ForgetUses(variable);
  107. // Update the storage class of the variable.
  108. variable->SetInOperand(kVariableStorageClassInIdx,
  109. {uint32_t(spv::StorageClass::Function)});
  110. // Update the type as well.
  111. uint32_t new_type_id = GetNewType(variable->type_id());
  112. if (new_type_id == 0) {
  113. return false;
  114. }
  115. variable->SetResultType(new_type_id);
  116. // Place the variable at the start of the first basic block.
  117. context()->AnalyzeUses(variable);
  118. context()->set_instr_block(variable, &*function->begin());
  119. function->begin()->begin()->InsertBefore(std::move(var));
  120. // Update uses where the type may have changed.
  121. return UpdateUses(variable);
  122. }
  123. uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) {
  124. auto type_mgr = context()->get_type_mgr();
  125. Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id);
  126. uint32_t pointee_type_id =
  127. old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
  128. uint32_t new_type_id =
  129. type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::Function);
  130. if (new_type_id != 0) {
  131. context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
  132. }
  133. return new_type_id;
  134. }
  135. bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const {
  136. // The cases in this switch have to match the cases in |UpdateUse|.
  137. // If we don't know how to update it, it is not valid.
  138. if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) {
  139. return true;
  140. }
  141. switch (inst->opcode()) {
  142. case spv::Op::OpLoad:
  143. case spv::Op::OpStore:
  144. case spv::Op::OpImageTexelPointer: // Treat like a load
  145. return true;
  146. case spv::Op::OpAccessChain:
  147. return context()->get_def_use_mgr()->WhileEachUser(
  148. inst, [this](const Instruction* user) {
  149. if (!IsValidUse(user)) return false;
  150. return true;
  151. });
  152. case spv::Op::OpName:
  153. return true;
  154. default:
  155. return spvOpcodeIsDecoration(inst->opcode());
  156. }
  157. }
  158. bool PrivateToLocalPass::UpdateUse(Instruction* inst, Instruction* user) {
  159. // The cases in this switch have to match the cases in |IsValidUse|. If we
  160. // don't think it is valid, the optimization will not view the variable as a
  161. // candidate, and therefore the use will not be updated.
  162. if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) {
  163. context()->get_debug_info_mgr()->ConvertDebugGlobalToLocalVariable(inst,
  164. user);
  165. return true;
  166. }
  167. switch (inst->opcode()) {
  168. case spv::Op::OpLoad:
  169. case spv::Op::OpStore:
  170. case spv::Op::OpImageTexelPointer: // Treat like a load
  171. // The type is fine because it is the type pointed to, and that does not
  172. // change.
  173. break;
  174. case spv::Op::OpAccessChain: {
  175. context()->ForgetUses(inst);
  176. uint32_t new_type_id = GetNewType(inst->type_id());
  177. if (new_type_id == 0) {
  178. return false;
  179. }
  180. inst->SetResultType(new_type_id);
  181. context()->AnalyzeUses(inst);
  182. // Update uses where the type may have changed.
  183. if (!UpdateUses(inst)) {
  184. return false;
  185. }
  186. } break;
  187. case spv::Op::OpName:
  188. case spv::Op::OpEntryPoint: // entry points will be updated separately.
  189. break;
  190. default:
  191. assert(spvOpcodeIsDecoration(inst->opcode()) &&
  192. "Do not know how to update the type for this instruction.");
  193. break;
  194. }
  195. return true;
  196. }
  197. bool PrivateToLocalPass::UpdateUses(Instruction* inst) {
  198. uint32_t id = inst->result_id();
  199. std::vector<Instruction*> uses;
  200. context()->get_def_use_mgr()->ForEachUser(
  201. id, [&uses](Instruction* use) { uses.push_back(use); });
  202. for (Instruction* use : uses) {
  203. if (!UpdateUse(use, inst)) {
  204. return false;
  205. }
  206. }
  207. return true;
  208. }
  209. } // namespace opt
  210. } // namespace spvtools