| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- // Copyright (c) 2021 Google LLC.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- #include "source/lint/divergence_analysis.h"
- #include "source/opt/basic_block.h"
- #include "source/opt/control_dependence.h"
- #include "source/opt/dataflow.h"
- #include "source/opt/function.h"
- #include "source/opt/instruction.h"
- namespace spvtools {
- namespace lint {
- void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) {
- // Enqueue control dependents of block, if applicable.
- // There are two ways for a dependence source to be updated:
- // 1. control -> control: source block is marked divergent.
- // 2. data -> control: branch condition is marked divergent.
- uint32_t block_id;
- if (inst->IsBlockTerminator()) {
- block_id = context().get_instr_block(inst)->id();
- } else if (inst->opcode() == spv::Op::OpLabel) {
- block_id = inst->result_id();
- opt::BasicBlock* bb = context().cfg()->block(block_id);
- // Only enqueue phi instructions, as other uses don't affect divergence.
- bb->ForEachPhiInst([this](opt::Instruction* phi) { Enqueue(phi); });
- } else {
- opt::ForwardDataFlowAnalysis::EnqueueUsers(inst);
- return;
- }
- if (!cd_.HasBlock(block_id)) {
- return;
- }
- for (const spvtools::opt::ControlDependence& dep :
- cd_.GetDependenceTargets(block_id)) {
- opt::Instruction* target_inst =
- context().cfg()->block(dep.target_bb_id())->GetLabelInst();
- Enqueue(target_inst);
- }
- }
- opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit(
- opt::Instruction* inst) {
- if (inst->opcode() == spv::Op::OpLabel) {
- return VisitBlock(inst->result_id());
- } else {
- return VisitInstruction(inst);
- }
- }
- opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) {
- if (!cd_.HasBlock(id)) {
- return opt::DataFlowAnalysis::VisitResult::kResultFixed;
- }
- DivergenceLevel& cur_level = divergence_[id];
- if (cur_level == DivergenceLevel::kDivergent) {
- return opt::DataFlowAnalysis::VisitResult::kResultFixed;
- }
- DivergenceLevel orig = cur_level;
- for (const spvtools::opt::ControlDependence& dep :
- cd_.GetDependenceSources(id)) {
- if (divergence_[dep.source_bb_id()] > cur_level) {
- cur_level = divergence_[dep.source_bb_id()];
- divergence_source_[id] = dep.source_bb_id();
- } else if (dep.source_bb_id() != 0) {
- uint32_t condition_id = dep.GetConditionID(*context().cfg());
- DivergenceLevel dep_level = divergence_[condition_id];
- // Check if we are along the chain of unconditional branches starting from
- // the branch target.
- if (follow_unconditional_branches_[dep.branch_target_bb_id()] !=
- follow_unconditional_branches_[dep.target_bb_id()]) {
- // We must have reconverged in order to reach this block.
- // Promote partially uniform to divergent.
- if (dep_level == DivergenceLevel::kPartiallyUniform) {
- dep_level = DivergenceLevel::kDivergent;
- }
- }
- if (dep_level > cur_level) {
- cur_level = dep_level;
- divergence_source_[id] = condition_id;
- divergence_dependence_source_[id] = dep.source_bb_id();
- }
- }
- }
- return cur_level > orig ? VisitResult::kResultChanged
- : VisitResult::kResultFixed;
- }
- opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction(
- opt::Instruction* inst) {
- if (inst->IsBlockTerminator()) {
- // This is called only when the condition has changed, so return changed.
- return VisitResult::kResultChanged;
- }
- if (!inst->HasResultId()) {
- return VisitResult::kResultFixed;
- }
- uint32_t id = inst->result_id();
- DivergenceLevel& cur_level = divergence_[id];
- if (cur_level == DivergenceLevel::kDivergent) {
- return opt::DataFlowAnalysis::VisitResult::kResultFixed;
- }
- DivergenceLevel orig = cur_level;
- cur_level = ComputeInstructionDivergence(inst);
- return cur_level > orig ? VisitResult::kResultChanged
- : VisitResult::kResultFixed;
- }
- DivergenceAnalysis::DivergenceLevel
- DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) {
- // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId
- // and use that to short circuit other checks. Uniform is for subgroups which
- // would satisfy derivative groups too. UniformId takes a scope, so if it is
- // subgroup or greater it could satisfy derivative group and
- // Device/QueueFamily could satisfy fully uniform.
- uint32_t id = inst->result_id();
- // Handle divergence roots.
- if (inst->opcode() == spv::Op::OpFunctionParameter) {
- divergence_source_[id] = 0;
- return divergence_[id] = DivergenceLevel::kDivergent;
- } else if (inst->IsLoad()) {
- spvtools::opt::Instruction* var = inst->GetBaseAddress();
- if (var->opcode() != spv::Op::OpVariable) {
- // Assume divergent.
- divergence_source_[id] = 0;
- return DivergenceLevel::kDivergent;
- }
- DivergenceLevel ret = ComputeVariableDivergence(var);
- if (ret > DivergenceLevel::kUniform) {
- divergence_source_[inst->result_id()] = 0;
- }
- return divergence_[id] = ret;
- }
- // Get the maximum divergence of the operands.
- DivergenceLevel ret = DivergenceLevel::kUniform;
- inst->ForEachInId([this, inst, &ret](const uint32_t* op) {
- if (!op) return;
- if (divergence_[*op] > ret) {
- divergence_source_[inst->result_id()] = *op;
- ret = divergence_[*op];
- }
- });
- divergence_[inst->result_id()] = ret;
- return ret;
- }
- DivergenceAnalysis::DivergenceLevel
- DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) {
- uint32_t type_id = var->type_id();
- spvtools::opt::analysis::Pointer* type =
- context().get_type_mgr()->GetType(type_id)->AsPointer();
- assert(type != nullptr);
- uint32_t def_id = var->result_id();
- DivergenceLevel ret;
- switch (type->storage_class()) {
- case spv::StorageClass::Function:
- case spv::StorageClass::Generic:
- case spv::StorageClass::AtomicCounter:
- case spv::StorageClass::StorageBuffer:
- case spv::StorageClass::PhysicalStorageBuffer:
- case spv::StorageClass::Output:
- case spv::StorageClass::Workgroup:
- case spv::StorageClass::Image: // Image atomics probably aren't uniform.
- case spv::StorageClass::Private:
- ret = DivergenceLevel::kDivergent;
- break;
- case spv::StorageClass::Input:
- ret = DivergenceLevel::kDivergent;
- // If this variable has a Flat decoration, it is partially uniform.
- // TODO(kuhar): Track access chain indices and also consider Flat members
- // of a structure.
- context().get_decoration_mgr()->WhileEachDecoration(
- def_id, static_cast<uint32_t>(spv::Decoration::Flat),
- [&ret](const opt::Instruction&) {
- ret = DivergenceLevel::kPartiallyUniform;
- return false;
- });
- break;
- case spv::StorageClass::UniformConstant:
- // May be a storage image which is also written to; mark those as
- // divergent.
- if (!var->IsVulkanStorageImage() || var->IsReadOnlyPointer()) {
- ret = DivergenceLevel::kUniform;
- } else {
- ret = DivergenceLevel::kDivergent;
- }
- break;
- case spv::StorageClass::Uniform:
- case spv::StorageClass::PushConstant:
- case spv::StorageClass::CrossWorkgroup: // Not for shaders; default
- // uniform.
- default:
- ret = DivergenceLevel::kUniform;
- break;
- }
- return ret;
- }
- void DivergenceAnalysis::Setup(opt::Function* function) {
- // TODO(kuhar): Run functions called by |function| so we can detect
- // reconvergence caused by multiple returns.
- cd_.ComputeControlDependenceGraph(
- *context().cfg(), *context().GetPostDominatorAnalysis(function));
- context().cfg()->ForEachBlockInPostOrder(
- function->entry().get(), [this](const opt::BasicBlock* bb) {
- uint32_t id = bb->id();
- if (bb->terminator() == nullptr ||
- bb->terminator()->opcode() != spv::Op::OpBranch) {
- follow_unconditional_branches_[id] = id;
- } else {
- uint32_t target_id = bb->terminator()->GetSingleWordInOperand(0);
- // Target is guaranteed to have been visited before us in postorder.
- follow_unconditional_branches_[id] =
- follow_unconditional_branches_[target_id];
- }
- });
- }
- std::ostream& operator<<(std::ostream& os,
- DivergenceAnalysis::DivergenceLevel level) {
- switch (level) {
- case DivergenceAnalysis::DivergenceLevel::kUniform:
- return os << "uniform";
- case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform:
- return os << "partially uniform";
- case DivergenceAnalysis::DivergenceLevel::kDivergent:
- return os << "divergent";
- default:
- return os << "<invalid divergence level>";
- }
- }
- } // namespace lint
- } // namespace spvtools
|