| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- // Copyright (c) 2021 Google LLC.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- #include "source/opt/dataflow.h"
- #include <map>
- #include <set>
- #include "gtest/gtest.h"
- #include "opt/function_utils.h"
- #include "source/opt/build_module.h"
- namespace spvtools {
- namespace opt {
- namespace {
- using DataFlowTest = ::testing::Test;
- // Simple analyses for testing:
- // Stores the result IDs of visited instructions in visit order.
- struct VisitOrder : public ForwardDataFlowAnalysis {
- std::vector<uint32_t> visited_result_ids;
- VisitOrder(IRContext& context, LabelPosition label_position)
- : ForwardDataFlowAnalysis(context, label_position) {}
- VisitResult Visit(Instruction* inst) override {
- if (inst->HasResultId()) {
- visited_result_ids.push_back(inst->result_id());
- }
- return DataFlowAnalysis::VisitResult::kResultFixed;
- }
- };
- // For each block, stores the set of blocks it can be preceded by.
- // For example, with the following CFG:
- // V-----------.
- // -> 11 -> 12 -> 13 -> 15
- // \-> 14 ---^
- //
- // The answer is:
- // 11: 11, 12, 13
- // 12: 11, 12, 13
- // 13: 11, 12, 13
- // 14: 11, 12, 13
- // 15: 11, 12, 13, 14
- struct BackwardReachability : public ForwardDataFlowAnalysis {
- std::map<uint32_t, std::set<uint32_t>> reachable_from;
- BackwardReachability(IRContext& context)
- : ForwardDataFlowAnalysis(
- context, ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly) {}
- VisitResult Visit(Instruction* inst) override {
- // Conditional branches can be enqueued from labels, so skip them.
- if (inst->opcode() != spv::Op::OpLabel)
- return DataFlowAnalysis::VisitResult::kResultFixed;
- uint32_t id = inst->result_id();
- VisitResult ret = DataFlowAnalysis::VisitResult::kResultFixed;
- std::set<uint32_t>& precedents = reachable_from[id];
- for (uint32_t pred : context().cfg()->preds(id)) {
- bool pred_inserted = precedents.insert(pred).second;
- if (pred_inserted) {
- ret = DataFlowAnalysis::VisitResult::kResultChanged;
- }
- for (uint32_t block : reachable_from[pred]) {
- bool inserted = precedents.insert(block).second;
- if (inserted) {
- ret = DataFlowAnalysis::VisitResult::kResultChanged;
- }
- }
- }
- return ret;
- }
- void InitializeWorklist(Function* function,
- bool is_first_iteration) override {
- // Since successor function is exact, only need one pass.
- if (is_first_iteration) {
- ForwardDataFlowAnalysis::InitializeWorklist(function, true);
- }
- }
- };
- TEST_F(DataFlowTest, ReversePostOrder) {
- // Note: labels and IDs are intentionally out of order.
- //
- // CFG: (order of branches is from bottom to top)
- // V-----------.
- // -> 50 -> 40 -> 20 -> 60 -> 70
- // \-> 30 ---^
- // DFS tree with RPO numbering:
- // -> 50[0] -> 40[1] -> 20[2] 60[4] -> 70[5]
- // \-> 30[3] ---^
- const std::string text = R"(
- OpCapability Shader
- %1 = OpExtInstImport "GLSL.std.450"
- OpMemoryModel Logical GLSL450
- OpEntryPoint Fragment %2 "main"
- OpExecutionMode %2 OriginUpperLeft
- OpSource GLSL 430
- %3 = OpTypeVoid
- %4 = OpTypeFunction %3
- %6 = OpTypeBool
- %5 = OpConstantTrue %6
- %2 = OpFunction %3 None %4
- %50 = OpLabel
- %51 = OpUndef %6
- %52 = OpUndef %6
- OpBranch %40
- %70 = OpLabel
- %69 = OpUndef %6
- OpReturn
- %60 = OpLabel
- %61 = OpUndef %6
- OpBranchConditional %5 %70 %40
- %30 = OpLabel
- %29 = OpUndef %6
- OpBranch %60
- %20 = OpLabel
- %21 = OpUndef %6
- OpBranch %60
- %40 = OpLabel
- %39 = OpUndef %6
- OpBranchConditional %5 %30 %20
- OpFunctionEnd
- )";
- std::unique_ptr<IRContext> context =
- BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
- SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
- ASSERT_NE(context, nullptr);
- Function* function = spvtest::GetFunction(context->module(), 2);
- std::map<ForwardDataFlowAnalysis::LabelPosition, std::vector<uint32_t>>
- expected_order;
- expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly] = {
- 50, 40, 20, 30, 60, 70,
- };
- expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtBeginning] = {
- 50, 51, 52, 40, 39, 20, 21, 30, 29, 60, 61, 70, 69,
- };
- expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtEnd] = {
- 51, 52, 50, 39, 40, 21, 20, 29, 30, 61, 60, 69, 70,
- };
- expected_order[ForwardDataFlowAnalysis::LabelPosition::kNoLabels] = {
- 51, 52, 39, 21, 29, 61, 69,
- };
- for (const auto& test_case : expected_order) {
- VisitOrder analysis(*context, test_case.first);
- analysis.Run(function);
- EXPECT_EQ(test_case.second, analysis.visited_result_ids);
- }
- }
- TEST_F(DataFlowTest, BackwardReachability) {
- // CFG:
- // V-----------.
- // -> 11 -> 12 -> 13 -> 15
- // \-> 14 ---^
- const std::string text = R"(
- OpCapability Shader
- %1 = OpExtInstImport "GLSL.std.450"
- OpMemoryModel Logical GLSL450
- OpEntryPoint Fragment %2 "main"
- OpExecutionMode %2 OriginUpperLeft
- OpSource GLSL 430
- %3 = OpTypeVoid
- %4 = OpTypeFunction %3
- %6 = OpTypeBool
- %5 = OpConstantTrue %6
- %2 = OpFunction %3 None %4
- %11 = OpLabel
- OpBranch %12
- %12 = OpLabel
- OpBranchConditional %5 %14 %13
- %13 = OpLabel
- OpBranchConditional %5 %15 %11
- %14 = OpLabel
- OpBranch %15
- %15 = OpLabel
- OpReturn
- OpFunctionEnd
- )";
- std::unique_ptr<IRContext> context =
- BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
- SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
- ASSERT_NE(context, nullptr);
- Function* function = spvtest::GetFunction(context->module(), 2);
- BackwardReachability analysis(*context);
- analysis.Run(function);
- std::map<uint32_t, std::set<uint32_t>> expected_result;
- expected_result[11] = {11, 12, 13};
- expected_result[12] = {11, 12, 13};
- expected_result[13] = {11, 12, 13};
- expected_result[14] = {11, 12, 13};
- expected_result[15] = {11, 12, 13, 14};
- EXPECT_EQ(expected_result, analysis.reachable_from);
- }
- } // namespace
- } // namespace opt
- } // namespace spvtools
|