resolve_binding_conflicts_pass.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. // Copyright (c) 2025 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/resolve_binding_conflicts_pass.h"
  15. #include <algorithm>
  16. #include <unordered_map>
  17. #include <unordered_set>
  18. #include <vector>
  19. #include "source/opt/decoration_manager.h"
  20. #include "source/opt/def_use_manager.h"
  21. #include "source/opt/instruction.h"
  22. #include "source/opt/ir_builder.h"
  23. #include "source/opt/ir_context.h"
  24. #include "spirv/unified1/spirv.h"
  25. namespace spvtools {
  26. namespace opt {
  27. // A VarBindingInfo contains the binding information for a single resource
  28. // variable.
  29. //
  30. // Exactly one such object is created per resource variable in the
  31. // module. In particular, when a resource variable is statically used by
  32. // more than one entry point, those entry points share the same VarBindingInfo
  33. // object for that variable.
  34. struct VarBindingInfo {
  35. const Instruction* const var;
  36. const uint32_t descriptor_set;
  37. Instruction* const binding_decoration;
  38. // Returns the binding number.
  39. uint32_t binding() const {
  40. return binding_decoration->GetSingleWordInOperand(2);
  41. }
  42. // Sets the binding number to 'b'.
  43. void updateBinding(uint32_t b) { binding_decoration->SetOperand(2, {b}); }
  44. };
  45. // The bindings in the same descriptor set that are used by an entry point.
  46. using BindingList = std::vector<VarBindingInfo*>;
  47. // A map from descriptor set number to the list of bindings in that descriptor
  48. // set, as used by a particular entry point.
  49. using DescriptorSets = std::unordered_map<uint32_t, BindingList>;
  50. IRContext::Analysis ResolveBindingConflictsPass::GetPreservedAnalyses() {
  51. // All analyses are kept up to date.
  52. // At most this modifies the Binding numbers on variables.
  53. return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping |
  54. IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
  55. IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis |
  56. IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
  57. IRContext::kAnalysisScalarEvolution |
  58. IRContext::kAnalysisRegisterPressure |
  59. IRContext::kAnalysisValueNumberTable |
  60. IRContext::kAnalysisStructuredCFG | IRContext::kAnalysisBuiltinVarId |
  61. IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisConstants |
  62. IRContext::kAnalysisTypes | IRContext::kAnalysisDebugInfo |
  63. IRContext::kAnalysisLiveness;
  64. }
  65. // Orders variable binding info objects.
  66. // * The binding number is most signficant;
  67. // * Then a sampler-like object compares greater than non-sampler like object.
  68. // * Otherwise compare based on variable ID.
  69. // This provides a total order among bindings in a descriptor set for a valid
  70. // Vulkan module.
  71. bool Less(const VarBindingInfo* const lhs, const VarBindingInfo* const rhs) {
  72. if (lhs->binding() < rhs->binding()) return true;
  73. if (lhs->binding() > rhs->binding()) return false;
  74. // Examine types.
  75. // In valid Vulkan the only conflict can occur between
  76. // images and samplers. We only care about a specific
  77. // comparison when one is a image-like thing and the other
  78. // is a sampler-like thing of the same shape. So unwrap
  79. // types until we hit one of those two.
  80. auto* def_use_mgr = lhs->var->context()->get_def_use_mgr();
  81. // Returns the type found by iteratively following pointer pointee type,
  82. // or array element type.
  83. auto unwrap = [&def_use_mgr](Instruction* ty) {
  84. bool keep_going = true;
  85. do {
  86. switch (ty->opcode()) {
  87. case spv::Op::OpTypePointer:
  88. ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(1));
  89. break;
  90. case spv::Op::OpTypeArray:
  91. case spv::Op::OpTypeRuntimeArray:
  92. ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(0));
  93. break;
  94. default:
  95. keep_going = false;
  96. break;
  97. }
  98. } while (keep_going);
  99. return ty;
  100. };
  101. auto* lhs_ty = unwrap(def_use_mgr->GetDef(lhs->var->type_id()));
  102. auto* rhs_ty = unwrap(def_use_mgr->GetDef(rhs->var->type_id()));
  103. if (lhs_ty->opcode() == rhs_ty->opcode()) {
  104. // Pick based on variable ID.
  105. return lhs->var->result_id() < rhs->var->result_id();
  106. }
  107. // A sampler is always greater than an image.
  108. if (lhs_ty->opcode() == spv::Op::OpTypeSampler) {
  109. return false;
  110. }
  111. if (rhs_ty->opcode() == spv::Op::OpTypeSampler) {
  112. return true;
  113. }
  114. // Pick based on variable ID.
  115. return lhs->var->result_id() < rhs->var->result_id();
  116. }
  117. // Summarizes the caller-callee relationships between functions in a module.
  118. class CallGraph {
  119. public:
  120. // Returns the list of all functions statically reachable from entry points,
  121. // where callees precede callers.
  122. const std::vector<uint32_t>& CalleesBeforeCallers() const {
  123. return visit_order_;
  124. }
  125. // Returns the list functions called from a given function.
  126. const std::unordered_set<uint32_t>& Callees(uint32_t caller) {
  127. return calls_[caller];
  128. }
  129. CallGraph(IRContext& context) {
  130. // Populate calls_.
  131. std::queue<uint32_t> callee_queue;
  132. for (const auto& fn : *context.module()) {
  133. auto& callees = calls_[fn.result_id()];
  134. context.AddCalls(&fn, &callee_queue);
  135. while (!callee_queue.empty()) {
  136. callees.insert(callee_queue.front());
  137. callee_queue.pop();
  138. }
  139. }
  140. // Perform depth-first search, starting from each entry point.
  141. // Populates visit_order_.
  142. for (const auto& ep : context.module()->entry_points()) {
  143. Visit(ep.GetSingleWordInOperand(1));
  144. }
  145. }
  146. private:
  147. // Visits a function, recursively visiting its callees. Adds this ID
  148. // to the visit_order after all callees have been visited.
  149. void Visit(uint32_t func_id) {
  150. if (visited_.count(func_id)) {
  151. return;
  152. }
  153. visited_.insert(func_id);
  154. for (auto callee_id : calls_[func_id]) {
  155. Visit(callee_id);
  156. }
  157. visit_order_.push_back(func_id);
  158. }
  159. // Maps the ID of a function to the IDs of functions it calls.
  160. std::unordered_map<uint32_t, std::unordered_set<uint32_t>> calls_;
  161. // IDs of visited functions;
  162. std::unordered_set<uint32_t> visited_;
  163. // IDs of functions, where callees precede callers.
  164. std::vector<uint32_t> visit_order_;
  165. };
  166. // Returns vector binding info for all resource variables in the module.
  167. auto GetVarBindings(IRContext& context) {
  168. std::vector<VarBindingInfo> vars;
  169. auto* deco_mgr = context.get_decoration_mgr();
  170. for (auto& inst : context.module()->types_values()) {
  171. if (inst.opcode() == spv::Op::OpVariable) {
  172. Instruction* descriptor_set_deco = nullptr;
  173. Instruction* binding_deco = nullptr;
  174. for (auto* deco : deco_mgr->GetDecorationsFor(inst.result_id(), false)) {
  175. switch (static_cast<spv::Decoration>(deco->GetSingleWordInOperand(1))) {
  176. case spv::Decoration::DescriptorSet:
  177. assert(!descriptor_set_deco);
  178. descriptor_set_deco = deco;
  179. break;
  180. case spv::Decoration::Binding:
  181. assert(!binding_deco);
  182. binding_deco = deco;
  183. break;
  184. default:
  185. break;
  186. }
  187. }
  188. if (descriptor_set_deco && binding_deco) {
  189. vars.push_back({&inst, descriptor_set_deco->GetSingleWordInOperand(2),
  190. binding_deco});
  191. }
  192. }
  193. }
  194. return vars;
  195. }
  196. // Merges the bindings from source into sink. Maintains order and uniqueness
  197. // within a list of bindings.
  198. void Merge(DescriptorSets& sink, const DescriptorSets& source) {
  199. for (auto index_and_bindings : source) {
  200. const uint32_t index = index_and_bindings.first;
  201. const BindingList& src1 = index_and_bindings.second;
  202. const BindingList& src2 = sink[index];
  203. BindingList merged;
  204. merged.resize(src1.size() + src2.size());
  205. auto merged_end = std::merge(src1.begin(), src1.end(), src2.begin(),
  206. src2.end(), merged.begin(), Less);
  207. auto unique_end = std::unique(merged.begin(), merged_end);
  208. merged.resize(unique_end - merged.begin());
  209. sink[index] = std::move(merged);
  210. }
  211. }
  212. // Resolves conflicts within this binding list, so the binding number on an
  213. // item is at least one more than the binding number on the previous item.
  214. // When this does not yet hold, increase the binding number on the second
  215. // item in the pair. Returns true if any changes were applied.
  216. bool ResolveConflicts(BindingList& bl) {
  217. bool changed = false;
  218. for (size_t i = 1; i < bl.size(); i++) {
  219. const auto prev_num = bl[i - 1]->binding();
  220. if (prev_num >= bl[i]->binding()) {
  221. bl[i]->updateBinding(prev_num + 1);
  222. changed = true;
  223. }
  224. }
  225. return changed;
  226. }
  227. Pass::Status ResolveBindingConflictsPass::Process() {
  228. // Assumes the descriptor set and binding decorations are not provided
  229. // via decoration groups. Decoration groups were deprecated in SPIR-V 1.3
  230. // Revision 6. I have not seen any compiler generate them. --dneto
  231. auto vars = GetVarBindings(*context());
  232. // Maps a function ID to the variables used directly or indirectly by the
  233. // function, organized into descriptor sets. Each descriptor set
  234. // consists of a BindingList of distinct variables.
  235. std::unordered_map<uint32_t, DescriptorSets> used_vars;
  236. // Determine variables directly used by functions.
  237. auto* def_use_mgr = context()->get_def_use_mgr();
  238. for (auto& var : vars) {
  239. std::unordered_set<uint32_t> visited_functions_for_var;
  240. def_use_mgr->ForEachUser(var.var, [&](Instruction* user) {
  241. if (auto* block = context()->get_instr_block(user)) {
  242. auto* fn = block->GetParent();
  243. assert(fn);
  244. const auto fn_id = fn->result_id();
  245. if (visited_functions_for_var.insert(fn_id).second) {
  246. used_vars[fn_id][var.descriptor_set].push_back(&var);
  247. }
  248. }
  249. });
  250. }
  251. // Sort within a descriptor set by binding number.
  252. for (auto& sets_for_fn : used_vars) {
  253. for (auto& ds : sets_for_fn.second) {
  254. BindingList& bindings = ds.second;
  255. std::stable_sort(bindings.begin(), bindings.end(), Less);
  256. }
  257. }
  258. // Propagate from callees to callers.
  259. CallGraph call_graph(*context());
  260. for (const uint32_t caller : call_graph.CalleesBeforeCallers()) {
  261. DescriptorSets& caller_ds = used_vars[caller];
  262. for (const uint32_t callee : call_graph.Callees(caller)) {
  263. Merge(caller_ds, used_vars[callee]);
  264. }
  265. }
  266. // At this point, the descriptor sets associated with each entry point
  267. // capture exactly the set of resource variables statically used
  268. // by the static call tree of that entry point.
  269. // Resolve conflicts.
  270. // VarBindingInfo objects may be shared between the bindings lists.
  271. // Updating a binding in one list can require updating another list later.
  272. // So repeat updates until settling.
  273. // The union of BindingLists across all entry points.
  274. std::vector<BindingList*> ep_bindings;
  275. for (auto& ep : context()->module()->entry_points()) {
  276. for (auto& ds : used_vars[ep.GetSingleWordInOperand(1)]) {
  277. BindingList& bindings = ds.second;
  278. ep_bindings.push_back(&bindings);
  279. }
  280. }
  281. bool modified = false;
  282. bool found_conflict;
  283. do {
  284. found_conflict = false;
  285. for (BindingList* bl : ep_bindings) {
  286. found_conflict |= ResolveConflicts(*bl);
  287. }
  288. modified |= found_conflict;
  289. } while (found_conflict);
  290. return modified ? Pass::Status::SuccessWithChange
  291. : Pass::Status::SuccessWithoutChange;
  292. }
  293. } // namespace opt
  294. } // namespace spvtools