divergence_analysis.h 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. // Copyright (c) 2021 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef SOURCE_LINT_DIVERGENCE_ANALYSIS_H_
  15. #define SOURCE_LINT_DIVERGENCE_ANALYSIS_H_
  16. #include <cstdint>
  17. #include <ostream>
  18. #include <unordered_map>
  19. #include "source/opt/basic_block.h"
  20. #include "source/opt/control_dependence.h"
  21. #include "source/opt/dataflow.h"
  22. #include "source/opt/function.h"
  23. #include "source/opt/instruction.h"
  24. namespace spvtools {
  25. namespace lint {
  26. // Computes the static divergence level for blocks (control flow) and values.
  27. //
  28. // A value is uniform if all threads that execute it are guaranteed to have the
  29. // same value. Similarly, a value is partially uniform if this is true only
  30. // within each derivative group. If neither apply, it is divergent.
  31. //
  32. // Control flow through a block is uniform if for any possible execution and
  33. // point in time, all threads are executing it, or no threads are executing it.
  34. // In particular, it is never possible for some threads to be inside the block
  35. // and some threads not executing.
  36. // TODO(kuhar): Clarify the difference between uniform, divergent, and
  37. // partially-uniform execution in this analysis.
  38. //
  39. // Caveat:
  40. // As we use control dependence to determine how divergence is propagated, this
  41. // analysis can be overly permissive when the merge block for a conditional
  42. // branch or switch is later than (strictly postdominates) the expected merge
  43. // block, which is the immediate postdominator. However, this is not expected to
  44. // be a problem in practice, given that SPIR-V is generally output by compilers
  45. // and other automated tools, which would assign the earliest possible merge
  46. // block, rather than written by hand.
  47. // TODO(kuhar): Handle late merges.
  48. class DivergenceAnalysis : public opt::ForwardDataFlowAnalysis {
  49. public:
  50. // The tightest (most uniform) level of divergence that can be determined
  51. // statically for a value or control flow for a block.
  52. //
  53. // The values are ordered such that A > B means that A is potentially more
  54. // divergent than B.
  55. // TODO(kuhar): Rename |PartiallyUniform' to something less confusing. For
  56. // example, the enum could be based on scopes.
  57. enum class DivergenceLevel {
  58. // The value or control flow is uniform across the entire invocation group.
  59. kUniform = 0,
  60. // The value or control flow is uniform across the derivative group, but not
  61. // the invocation group.
  62. kPartiallyUniform = 1,
  63. // The value or control flow is not statically uniform.
  64. kDivergent = 2,
  65. };
  66. DivergenceAnalysis(opt::IRContext& context)
  67. : ForwardDataFlowAnalysis(context, LabelPosition::kLabelsAtEnd) {}
  68. // Returns the divergence level for the given value (non-label instructions),
  69. // or control flow for the given block.
  70. DivergenceLevel GetDivergenceLevel(uint32_t id) {
  71. auto it = divergence_.find(id);
  72. if (it == divergence_.end()) {
  73. return DivergenceLevel::kUniform;
  74. }
  75. return it->second;
  76. }
  77. // Returns the divergence source for the given id. The following types of
  78. // divergence flows from A to B are possible:
  79. //
  80. // data -> data: A is used as an operand in the definition of B.
  81. // data -> control: B is control-dependent on a branch with condition A.
  82. // control -> data: B is a OpPhi instruction in which A is a block operand.
  83. // control -> control: B is control-dependent on A.
  84. uint32_t GetDivergenceSource(uint32_t id) {
  85. auto it = divergence_source_.find(id);
  86. if (it == divergence_source_.end()) {
  87. return 0;
  88. }
  89. return it->second;
  90. }
  91. // Returns the dependence source for the control dependence for the given id.
  92. // This only exists for data -> control edges.
  93. //
  94. // In other words, if block 2 is dependent on block 1 due to value 3 (e.g.
  95. // block 1 terminates with OpBranchConditional %3 %2 %4):
  96. // * GetDivergenceSource(2) = 3
  97. // * GetDivergenceDependenceSource(2) = 1
  98. //
  99. // Returns 0 if not applicable.
  100. uint32_t GetDivergenceDependenceSource(uint32_t id) {
  101. auto it = divergence_dependence_source_.find(id);
  102. if (it == divergence_dependence_source_.end()) {
  103. return 0;
  104. }
  105. return it->second;
  106. }
  107. void InitializeWorklist(opt::Function* function,
  108. bool is_first_iteration) override {
  109. // Since |EnqueueSuccessors| is complete, we only need one pass.
  110. if (is_first_iteration) {
  111. Setup(function);
  112. opt::ForwardDataFlowAnalysis::InitializeWorklist(function, true);
  113. }
  114. }
  115. void EnqueueSuccessors(opt::Instruction* inst) override;
  116. VisitResult Visit(opt::Instruction* inst) override;
  117. private:
  118. VisitResult VisitBlock(uint32_t id);
  119. VisitResult VisitInstruction(opt::Instruction* inst);
  120. // Computes the divergence level for the result of the given instruction
  121. // based on the current state of the analysis. This is always an
  122. // underapproximation, which will be improved as the analysis proceeds.
  123. DivergenceLevel ComputeInstructionDivergence(opt::Instruction* inst);
  124. // Computes the divergence level for a variable, which is used for loads.
  125. DivergenceLevel ComputeVariableDivergence(opt::Instruction* var);
  126. // Initializes data structures for performing dataflow on the given function.
  127. void Setup(opt::Function* function);
  128. std::unordered_map<uint32_t, DivergenceLevel> divergence_;
  129. std::unordered_map<uint32_t, uint32_t> divergence_source_;
  130. std::unordered_map<uint32_t, uint32_t> divergence_dependence_source_;
  131. // Stores the result of following unconditional branches starting from the
  132. // given block. This is used to detect when reconvergence needs to be
  133. // accounted for.
  134. std::unordered_map<uint32_t, uint32_t> follow_unconditional_branches_;
  135. opt::ControlDependenceAnalysis cd_;
  136. };
  137. std::ostream& operator<<(std::ostream& os,
  138. DivergenceAnalysis::DivergenceLevel level);
  139. } // namespace lint
  140. } // namespace spvtools
  141. #endif // SOURCE_LINT_DIVERGENCE_ANALYSIS_H_