divergence_analysis.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. // Copyright (c) 2021 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/lint/divergence_analysis.h"
  15. #include "source/opt/basic_block.h"
  16. #include "source/opt/control_dependence.h"
  17. #include "source/opt/dataflow.h"
  18. #include "source/opt/function.h"
  19. #include "source/opt/instruction.h"
  20. namespace spvtools {
  21. namespace lint {
  22. void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) {
  23. // Enqueue control dependents of block, if applicable.
  24. // There are two ways for a dependence source to be updated:
  25. // 1. control -> control: source block is marked divergent.
  26. // 2. data -> control: branch condition is marked divergent.
  27. uint32_t block_id;
  28. if (inst->IsBlockTerminator()) {
  29. block_id = context().get_instr_block(inst)->id();
  30. } else if (inst->opcode() == spv::Op::OpLabel) {
  31. block_id = inst->result_id();
  32. opt::BasicBlock* bb = context().cfg()->block(block_id);
  33. // Only enqueue phi instructions, as other uses don't affect divergence.
  34. bb->ForEachPhiInst([this](opt::Instruction* phi) { Enqueue(phi); });
  35. } else {
  36. opt::ForwardDataFlowAnalysis::EnqueueUsers(inst);
  37. return;
  38. }
  39. if (!cd_.HasBlock(block_id)) {
  40. return;
  41. }
  42. for (const spvtools::opt::ControlDependence& dep :
  43. cd_.GetDependenceTargets(block_id)) {
  44. opt::Instruction* target_inst =
  45. context().cfg()->block(dep.target_bb_id())->GetLabelInst();
  46. Enqueue(target_inst);
  47. }
  48. }
  49. opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit(
  50. opt::Instruction* inst) {
  51. if (inst->opcode() == spv::Op::OpLabel) {
  52. return VisitBlock(inst->result_id());
  53. } else {
  54. return VisitInstruction(inst);
  55. }
  56. }
  57. opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) {
  58. if (!cd_.HasBlock(id)) {
  59. return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  60. }
  61. DivergenceLevel& cur_level = divergence_[id];
  62. if (cur_level == DivergenceLevel::kDivergent) {
  63. return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  64. }
  65. DivergenceLevel orig = cur_level;
  66. for (const spvtools::opt::ControlDependence& dep :
  67. cd_.GetDependenceSources(id)) {
  68. if (divergence_[dep.source_bb_id()] > cur_level) {
  69. cur_level = divergence_[dep.source_bb_id()];
  70. divergence_source_[id] = dep.source_bb_id();
  71. } else if (dep.source_bb_id() != 0) {
  72. uint32_t condition_id = dep.GetConditionID(*context().cfg());
  73. DivergenceLevel dep_level = divergence_[condition_id];
  74. // Check if we are along the chain of unconditional branches starting from
  75. // the branch target.
  76. if (follow_unconditional_branches_[dep.branch_target_bb_id()] !=
  77. follow_unconditional_branches_[dep.target_bb_id()]) {
  78. // We must have reconverged in order to reach this block.
  79. // Promote partially uniform to divergent.
  80. if (dep_level == DivergenceLevel::kPartiallyUniform) {
  81. dep_level = DivergenceLevel::kDivergent;
  82. }
  83. }
  84. if (dep_level > cur_level) {
  85. cur_level = dep_level;
  86. divergence_source_[id] = condition_id;
  87. divergence_dependence_source_[id] = dep.source_bb_id();
  88. }
  89. }
  90. }
  91. return cur_level > orig ? VisitResult::kResultChanged
  92. : VisitResult::kResultFixed;
  93. }
  94. opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction(
  95. opt::Instruction* inst) {
  96. if (inst->IsBlockTerminator()) {
  97. // This is called only when the condition has changed, so return changed.
  98. return VisitResult::kResultChanged;
  99. }
  100. if (!inst->HasResultId()) {
  101. return VisitResult::kResultFixed;
  102. }
  103. uint32_t id = inst->result_id();
  104. DivergenceLevel& cur_level = divergence_[id];
  105. if (cur_level == DivergenceLevel::kDivergent) {
  106. return opt::DataFlowAnalysis::VisitResult::kResultFixed;
  107. }
  108. DivergenceLevel orig = cur_level;
  109. cur_level = ComputeInstructionDivergence(inst);
  110. return cur_level > orig ? VisitResult::kResultChanged
  111. : VisitResult::kResultFixed;
  112. }
  113. DivergenceAnalysis::DivergenceLevel
  114. DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) {
  115. // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId
  116. // and use that to short circuit other checks. Uniform is for subgroups which
  117. // would satisfy derivative groups too. UniformId takes a scope, so if it is
  118. // subgroup or greater it could satisfy derivative group and
  119. // Device/QueueFamily could satisfy fully uniform.
  120. uint32_t id = inst->result_id();
  121. // Handle divergence roots.
  122. if (inst->opcode() == spv::Op::OpFunctionParameter) {
  123. divergence_source_[id] = 0;
  124. return divergence_[id] = DivergenceLevel::kDivergent;
  125. } else if (inst->IsLoad()) {
  126. spvtools::opt::Instruction* var = inst->GetBaseAddress();
  127. if (var->opcode() != spv::Op::OpVariable) {
  128. // Assume divergent.
  129. divergence_source_[id] = 0;
  130. return DivergenceLevel::kDivergent;
  131. }
  132. DivergenceLevel ret = ComputeVariableDivergence(var);
  133. if (ret > DivergenceLevel::kUniform) {
  134. divergence_source_[inst->result_id()] = 0;
  135. }
  136. return divergence_[id] = ret;
  137. }
  138. // Get the maximum divergence of the operands.
  139. DivergenceLevel ret = DivergenceLevel::kUniform;
  140. inst->ForEachInId([this, inst, &ret](const uint32_t* op) {
  141. if (!op) return;
  142. if (divergence_[*op] > ret) {
  143. divergence_source_[inst->result_id()] = *op;
  144. ret = divergence_[*op];
  145. }
  146. });
  147. divergence_[inst->result_id()] = ret;
  148. return ret;
  149. }
  150. DivergenceAnalysis::DivergenceLevel
  151. DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) {
  152. uint32_t type_id = var->type_id();
  153. spvtools::opt::analysis::Pointer* type =
  154. context().get_type_mgr()->GetType(type_id)->AsPointer();
  155. assert(type != nullptr);
  156. uint32_t def_id = var->result_id();
  157. DivergenceLevel ret;
  158. switch (type->storage_class()) {
  159. case spv::StorageClass::Function:
  160. case spv::StorageClass::Generic:
  161. case spv::StorageClass::AtomicCounter:
  162. case spv::StorageClass::StorageBuffer:
  163. case spv::StorageClass::PhysicalStorageBuffer:
  164. case spv::StorageClass::Output:
  165. case spv::StorageClass::Workgroup:
  166. case spv::StorageClass::Image: // Image atomics probably aren't uniform.
  167. case spv::StorageClass::Private:
  168. ret = DivergenceLevel::kDivergent;
  169. break;
  170. case spv::StorageClass::Input:
  171. ret = DivergenceLevel::kDivergent;
  172. // If this variable has a Flat decoration, it is partially uniform.
  173. // TODO(kuhar): Track access chain indices and also consider Flat members
  174. // of a structure.
  175. context().get_decoration_mgr()->WhileEachDecoration(
  176. def_id, static_cast<uint32_t>(spv::Decoration::Flat),
  177. [&ret](const opt::Instruction&) {
  178. ret = DivergenceLevel::kPartiallyUniform;
  179. return false;
  180. });
  181. break;
  182. case spv::StorageClass::UniformConstant:
  183. // May be a storage image which is also written to; mark those as
  184. // divergent.
  185. if (!var->IsVulkanStorageImage() || var->IsReadOnlyPointer()) {
  186. ret = DivergenceLevel::kUniform;
  187. } else {
  188. ret = DivergenceLevel::kDivergent;
  189. }
  190. break;
  191. case spv::StorageClass::Uniform:
  192. case spv::StorageClass::PushConstant:
  193. case spv::StorageClass::CrossWorkgroup: // Not for shaders; default
  194. // uniform.
  195. default:
  196. ret = DivergenceLevel::kUniform;
  197. break;
  198. }
  199. return ret;
  200. }
  201. void DivergenceAnalysis::Setup(opt::Function* function) {
  202. // TODO(kuhar): Run functions called by |function| so we can detect
  203. // reconvergence caused by multiple returns.
  204. cd_.ComputeControlDependenceGraph(
  205. *context().cfg(), *context().GetPostDominatorAnalysis(function));
  206. context().cfg()->ForEachBlockInPostOrder(
  207. function->entry().get(), [this](const opt::BasicBlock* bb) {
  208. uint32_t id = bb->id();
  209. if (bb->terminator() == nullptr ||
  210. bb->terminator()->opcode() != spv::Op::OpBranch) {
  211. follow_unconditional_branches_[id] = id;
  212. } else {
  213. uint32_t target_id = bb->terminator()->GetSingleWordInOperand(0);
  214. // Target is guaranteed to have been visited before us in postorder.
  215. follow_unconditional_branches_[id] =
  216. follow_unconditional_branches_[target_id];
  217. }
  218. });
  219. }
  220. std::ostream& operator<<(std::ostream& os,
  221. DivergenceAnalysis::DivergenceLevel level) {
  222. switch (level) {
  223. case DivergenceAnalysis::DivergenceLevel::kUniform:
  224. return os << "uniform";
  225. case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform:
  226. return os << "partially uniform";
  227. case DivergenceAnalysis::DivergenceLevel::kDivergent:
  228. return os << "divergent";
  229. default:
  230. return os << "<invalid divergence level>";
  231. }
  232. }
  233. } // namespace lint
  234. } // namespace spvtools