control_dependence.cpp 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. // Copyright (c) 2021 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "source/opt/control_dependence.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
#include "source/opt/build_module.h"
#include "source/opt/cfg.h"
#include "test/opt/function_utils.h"
  22. namespace spvtools {
  23. namespace opt {
  24. namespace {
  25. void GatherEdges(const ControlDependenceAnalysis& cdg,
  26. std::vector<ControlDependence>& ret) {
  27. cdg.ForEachBlockLabel([&](uint32_t label) {
  28. ret.reserve(ret.size() + cdg.GetDependenceTargets(label).size());
  29. ret.insert(ret.end(), cdg.GetDependenceTargets(label).begin(),
  30. cdg.GetDependenceTargets(label).end());
  31. });
  32. std::sort(ret.begin(), ret.end());
  33. // Verify that reverse graph is the same.
  34. std::vector<ControlDependence> reverse_edges;
  35. reverse_edges.reserve(ret.size());
  36. cdg.ForEachBlockLabel([&](uint32_t label) {
  37. reverse_edges.insert(reverse_edges.end(),
  38. cdg.GetDependenceSources(label).begin(),
  39. cdg.GetDependenceSources(label).end());
  40. });
  41. std::sort(reverse_edges.begin(), reverse_edges.end());
  42. ASSERT_THAT(reverse_edges, testing::ElementsAreArray(ret));
  43. }
  44. using ControlDependenceTest = ::testing::Test;
  45. TEST(ControlDependenceTest, DependenceSimpleCFG) {
  46. const std::string text = R"(
  47. OpCapability Addresses
  48. OpCapability Kernel
  49. OpMemoryModel Physical64 OpenCL
  50. OpEntryPoint Kernel %1 "main"
  51. %2 = OpTypeVoid
  52. %3 = OpTypeFunction %2
  53. %4 = OpTypeBool
  54. %5 = OpTypeInt 32 0
  55. %6 = OpConstant %5 0
  56. %7 = OpConstantFalse %4
  57. %8 = OpConstantTrue %4
  58. %9 = OpConstant %5 1
  59. %1 = OpFunction %2 None %3
  60. %10 = OpLabel
  61. OpBranch %11
  62. %11 = OpLabel
  63. OpSwitch %6 %12 1 %13
  64. %12 = OpLabel
  65. OpBranch %14
  66. %13 = OpLabel
  67. OpBranch %14
  68. %14 = OpLabel
  69. OpBranchConditional %8 %15 %16
  70. %15 = OpLabel
  71. OpBranch %19
  72. %16 = OpLabel
  73. OpBranchConditional %8 %17 %18
  74. %17 = OpLabel
  75. OpBranch %18
  76. %18 = OpLabel
  77. OpBranch %19
  78. %19 = OpLabel
  79. OpReturn
  80. OpFunctionEnd
  81. )";
  82. // CFG: (all edges pointing downward)
  83. // %10
  84. // |
  85. // %11
  86. // / \ (R: %6 == 1, L: default)
  87. // %12 %13
  88. // \ /
  89. // %14
  90. // T/ \F
  91. // %15 %16
  92. // | T/ |F
  93. // | %17|
  94. // | \ |
  95. // | %18
  96. // | /
  97. // %19
  98. std::unique_ptr<IRContext> context =
  99. BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
  100. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  101. Module* module = context->module();
  102. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  103. << text << std::endl;
  104. const Function* fn = spvtest::GetFunction(module, 1);
  105. const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10);
  106. EXPECT_EQ(entry, fn->entry().get())
  107. << "The entry node is not the expected one";
  108. {
  109. PostDominatorAnalysis pdom;
  110. const CFG& cfg = *context->cfg();
  111. pdom.InitializeTree(cfg, fn);
  112. ControlDependenceAnalysis cdg;
  113. cdg.ComputeControlDependenceGraph(cfg, pdom);
  114. // Test HasBlock.
  115. for (uint32_t id = 10; id <= 19; id++) {
  116. EXPECT_TRUE(cdg.HasBlock(id));
  117. }
  118. EXPECT_TRUE(cdg.HasBlock(ControlDependenceAnalysis::kPseudoEntryBlock));
  119. // Check blocks before/after valid range.
  120. EXPECT_FALSE(cdg.HasBlock(5));
  121. EXPECT_FALSE(cdg.HasBlock(25));
  122. EXPECT_FALSE(cdg.HasBlock(UINT32_MAX));
  123. // Test ForEachBlockLabel.
  124. std::set<uint32_t> block_labels;
  125. cdg.ForEachBlockLabel([&block_labels](uint32_t id) {
  126. bool inserted = block_labels.insert(id).second;
  127. EXPECT_TRUE(inserted); // Should have no duplicates.
  128. });
  129. EXPECT_THAT(block_labels, testing::ElementsAre(0, 10, 11, 12, 13, 14, 15,
  130. 16, 17, 18, 19));
  131. {
  132. // Test WhileEachBlockLabel.
  133. uint32_t iters = 0;
  134. EXPECT_TRUE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
  135. ++iters;
  136. return true;
  137. }));
  138. EXPECT_EQ((uint32_t)block_labels.size(), iters);
  139. iters = 0;
  140. EXPECT_FALSE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
  141. ++iters;
  142. return false;
  143. }));
  144. EXPECT_EQ(1, iters);
  145. }
  146. // Test IsDependent.
  147. EXPECT_TRUE(cdg.IsDependent(12, 11));
  148. EXPECT_TRUE(cdg.IsDependent(13, 11));
  149. EXPECT_TRUE(cdg.IsDependent(15, 14));
  150. EXPECT_TRUE(cdg.IsDependent(16, 14));
  151. EXPECT_TRUE(cdg.IsDependent(18, 14));
  152. EXPECT_TRUE(cdg.IsDependent(17, 16));
  153. EXPECT_TRUE(cdg.IsDependent(10, 0));
  154. EXPECT_TRUE(cdg.IsDependent(11, 0));
  155. EXPECT_TRUE(cdg.IsDependent(14, 0));
  156. EXPECT_TRUE(cdg.IsDependent(19, 0));
  157. EXPECT_FALSE(cdg.IsDependent(14, 11));
  158. EXPECT_FALSE(cdg.IsDependent(17, 14));
  159. EXPECT_FALSE(cdg.IsDependent(19, 14));
  160. EXPECT_FALSE(cdg.IsDependent(12, 0));
  161. // Test GetDependenceSources/Targets.
  162. std::vector<ControlDependence> edges;
  163. GatherEdges(cdg, edges);
  164. EXPECT_THAT(edges,
  165. testing::ElementsAre(
  166. ControlDependence(0, 10), ControlDependence(0, 11, 10),
  167. ControlDependence(0, 14, 10), ControlDependence(0, 19, 10),
  168. ControlDependence(11, 12), ControlDependence(11, 13),
  169. ControlDependence(14, 15), ControlDependence(14, 16),
  170. ControlDependence(14, 18, 16), ControlDependence(16, 17)));
  171. const uint32_t expected_condition_ids[] = {
  172. 0, 0, 0, 0, 6, 6, 8, 8, 8, 8,
  173. };
  174. for (uint32_t i = 0; i < edges.size(); i++) {
  175. EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg));
  176. }
  177. }
  178. }
  179. TEST(ControlDependenceTest, DependencePaperCFG) {
  180. const std::string text = R"(
  181. OpCapability Addresses
  182. OpCapability Kernel
  183. OpMemoryModel Physical64 OpenCL
  184. OpEntryPoint Kernel %101 "main"
  185. %102 = OpTypeVoid
  186. %103 = OpTypeFunction %102
  187. %104 = OpTypeBool
  188. %108 = OpConstantTrue %104
  189. %101 = OpFunction %102 None %103
  190. %1 = OpLabel
  191. OpBranch %2
  192. %2 = OpLabel
  193. OpBranchConditional %108 %3 %7
  194. %3 = OpLabel
  195. OpBranchConditional %108 %4 %5
  196. %4 = OpLabel
  197. OpBranch %6
  198. %5 = OpLabel
  199. OpBranch %6
  200. %6 = OpLabel
  201. OpBranch %8
  202. %7 = OpLabel
  203. OpBranch %8
  204. %8 = OpLabel
  205. OpBranch %9
  206. %9 = OpLabel
  207. OpBranchConditional %108 %10 %11
  208. %10 = OpLabel
  209. OpBranch %11
  210. %11 = OpLabel
  211. OpBranchConditional %108 %12 %9
  212. %12 = OpLabel
  213. OpBranchConditional %108 %13 %2
  214. %13 = OpLabel
  215. OpReturn
  216. OpFunctionEnd
  217. )";
  218. // CFG: (edges pointing downward if no arrow)
  219. // %1
  220. // |
  221. // %2 <----+
  222. // T/ \F |
  223. // %3 \ |
  224. // T/ \F \ |
  225. // %4 %5 %7 |
  226. // \ / / |
  227. // %6 / |
  228. // \ / |
  229. // %8 |
  230. // | |
  231. // %9 <-+ |
  232. // T/ | | |
  233. // %10 | | |
  234. // \ | | |
  235. // %11-F+ |
  236. // T| |
  237. // %12-F---+
  238. // T|
  239. // %13
  240. std::unique_ptr<IRContext> context =
  241. BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
  242. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  243. Module* module = context->module();
  244. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  245. << text << std::endl;
  246. const Function* fn = spvtest::GetFunction(module, 101);
  247. const BasicBlock* entry = spvtest::GetBasicBlock(fn, 1);
  248. EXPECT_EQ(entry, fn->entry().get())
  249. << "The entry node is not the expected one";
  250. {
  251. PostDominatorAnalysis pdom;
  252. const CFG& cfg = *context->cfg();
  253. pdom.InitializeTree(cfg, fn);
  254. ControlDependenceAnalysis cdg;
  255. cdg.ComputeControlDependenceGraph(cfg, pdom);
  256. std::vector<ControlDependence> edges;
  257. GatherEdges(cdg, edges);
  258. EXPECT_THAT(
  259. edges, testing::ElementsAre(
  260. ControlDependence(0, 1), ControlDependence(0, 2, 1),
  261. ControlDependence(0, 8, 1), ControlDependence(0, 9, 1),
  262. ControlDependence(0, 11, 1), ControlDependence(0, 12, 1),
  263. ControlDependence(0, 13, 1), ControlDependence(2, 3),
  264. ControlDependence(2, 6, 3), ControlDependence(2, 7),
  265. ControlDependence(3, 4), ControlDependence(3, 5),
  266. ControlDependence(9, 10), ControlDependence(11, 9),
  267. ControlDependence(11, 11, 9), ControlDependence(12, 2),
  268. ControlDependence(12, 8, 2), ControlDependence(12, 9, 2),
  269. ControlDependence(12, 11, 2), ControlDependence(12, 12, 2)));
  270. const uint32_t expected_condition_ids[] = {
  271. 0, 0, 0, 0, 0, 0, 0, 108, 108, 108,
  272. 108, 108, 108, 108, 108, 108, 108, 108, 108, 108,
  273. };
  274. for (uint32_t i = 0; i < edges.size(); i++) {
  275. EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg));
  276. }
  277. }
  278. }
  279. } // namespace
  280. } // namespace opt
  281. } // namespace spvtools