lint_divergent_derivatives.cpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. // Copyright (c) 2021 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <cassert>
  15. #include <sstream>
  16. #include <string>
  17. #include "source/diagnostic.h"
  18. #include "source/lint/divergence_analysis.h"
  19. #include "source/lint/lints.h"
  20. #include "source/opt/basic_block.h"
  21. #include "source/opt/cfg.h"
  22. #include "source/opt/control_dependence.h"
  23. #include "source/opt/def_use_manager.h"
  24. #include "source/opt/dominator_analysis.h"
  25. #include "source/opt/instruction.h"
  26. #include "source/opt/ir_context.h"
  27. #include "spirv-tools/libspirv.h"
  28. namespace spvtools {
  29. namespace lint {
  30. namespace lints {
  31. namespace {
  32. // Returns the %name[id], where `name` is the first name associated with the
  33. // given id, or just %id if one is not found.
  34. std::string GetFriendlyName(opt::IRContext* context, uint32_t id) {
  35. auto names = context->GetNames(id);
  36. std::stringstream ss;
  37. ss << "%";
  38. if (names.empty()) {
  39. ss << id;
  40. } else {
  41. opt::Instruction* inst_name = names.begin()->second;
  42. if (inst_name->opcode() == spv::Op::OpName) {
  43. ss << names.begin()->second->GetInOperand(0).AsString();
  44. ss << "[" << id << "]";
  45. } else {
  46. ss << id;
  47. }
  48. }
  49. return ss.str();
  50. }
  51. bool InstructionHasDerivative(const opt::Instruction& inst) {
  52. static const spv::Op derivative_opcodes[] = {
  53. // Implicit derivatives.
  54. spv::Op::OpImageSampleImplicitLod,
  55. spv::Op::OpImageSampleDrefImplicitLod,
  56. spv::Op::OpImageSampleProjImplicitLod,
  57. spv::Op::OpImageSampleProjDrefImplicitLod,
  58. spv::Op::OpImageSparseSampleImplicitLod,
  59. spv::Op::OpImageSparseSampleDrefImplicitLod,
  60. spv::Op::OpImageSparseSampleProjImplicitLod,
  61. spv::Op::OpImageSparseSampleProjDrefImplicitLod,
  62. // Explicit derivatives.
  63. spv::Op::OpDPdx,
  64. spv::Op::OpDPdy,
  65. spv::Op::OpFwidth,
  66. spv::Op::OpDPdxFine,
  67. spv::Op::OpDPdyFine,
  68. spv::Op::OpFwidthFine,
  69. spv::Op::OpDPdxCoarse,
  70. spv::Op::OpDPdyCoarse,
  71. spv::Op::OpFwidthCoarse,
  72. };
  73. return std::find(std::begin(derivative_opcodes), std::end(derivative_opcodes),
  74. inst.opcode()) != std::end(derivative_opcodes);
  75. }
  76. spvtools::DiagnosticStream Warn(opt::IRContext* context,
  77. opt::Instruction* inst) {
  78. if (inst == nullptr) {
  79. return DiagnosticStream({0, 0, 0}, context->consumer(), "", SPV_WARNING);
  80. } else {
  81. // TODO(kuhar): Use line numbers based on debug info.
  82. return DiagnosticStream(
  83. {0, 0, 0}, context->consumer(),
  84. inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES),
  85. SPV_WARNING);
  86. }
  87. }
  88. void PrintDivergenceFlow(opt::IRContext* context, DivergenceAnalysis div,
  89. uint32_t id) {
  90. opt::analysis::DefUseManager* def_use = context->get_def_use_mgr();
  91. opt::CFG* cfg = context->cfg();
  92. while (id != 0) {
  93. bool is_block = def_use->GetDef(id)->opcode() == spv::Op::OpLabel;
  94. if (is_block) {
  95. Warn(context, nullptr)
  96. << "block " << GetFriendlyName(context, id) << " is divergent";
  97. uint32_t source = div.GetDivergenceSource(id);
  98. // Skip intermediate blocks.
  99. while (source != 0 &&
  100. def_use->GetDef(source)->opcode() == spv::Op::OpLabel) {
  101. id = source;
  102. source = div.GetDivergenceSource(id);
  103. }
  104. if (source == 0) break;
  105. spvtools::opt::Instruction* branch =
  106. cfg->block(div.GetDivergenceDependenceSource(id))->terminator();
  107. Warn(context, branch)
  108. << "because it depends on a conditional branch on divergent value "
  109. << GetFriendlyName(context, source) << "";
  110. id = source;
  111. } else {
  112. Warn(context, nullptr)
  113. << "value " << GetFriendlyName(context, id) << " is divergent";
  114. uint32_t source = div.GetDivergenceSource(id);
  115. opt::Instruction* def = def_use->GetDef(id);
  116. opt::Instruction* source_def =
  117. source == 0 ? nullptr : def_use->GetDef(source);
  118. // First print data -> data dependencies.
  119. while (source != 0 && source_def->opcode() != spv::Op::OpLabel) {
  120. Warn(context, def_use->GetDef(id))
  121. << "because " << GetFriendlyName(context, id) << " uses value "
  122. << GetFriendlyName(context, source)
  123. << "in its definition, which is divergent";
  124. id = source;
  125. def = source_def;
  126. source = div.GetDivergenceSource(id);
  127. source_def = def_use->GetDef(source);
  128. }
  129. if (source == 0) {
  130. Warn(context, def) << "because it has a divergent definition";
  131. break;
  132. }
  133. Warn(context, def) << "because it is conditionally set in block "
  134. << GetFriendlyName(context, source);
  135. id = source;
  136. }
  137. }
  138. }
  139. } // namespace
  140. bool CheckDivergentDerivatives(opt::IRContext* context) {
  141. DivergenceAnalysis div(*context);
  142. for (opt::Function& func : *context->module()) {
  143. div.Run(&func);
  144. for (const opt::BasicBlock& bb : func) {
  145. for (const opt::Instruction& inst : bb) {
  146. if (InstructionHasDerivative(inst) &&
  147. div.GetDivergenceLevel(bb.id()) >
  148. DivergenceAnalysis::DivergenceLevel::kPartiallyUniform) {
  149. Warn(context, nullptr)
  150. << "derivative with divergent control flow"
  151. << " located in block " << GetFriendlyName(context, bb.id());
  152. PrintDivergenceFlow(context, div, bb.id());
  153. }
  154. }
  155. }
  156. }
  157. return true;
  158. }
  159. } // namespace lints
  160. } // namespace lint
  161. } // namespace spvtools