scalar_analysis.cpp 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <memory>
  15. #include <string>
  16. #include <unordered_set>
  17. #include <vector>
  18. #include "gmock/gmock.h"
  19. #include "source/opt/iterator.h"
  20. #include "source/opt/loop_descriptor.h"
  21. #include "source/opt/pass.h"
  22. #include "source/opt/scalar_analysis.h"
  23. #include "source/opt/tree_iterator.h"
  24. #include "test/opt/assembly_builder.h"
  25. #include "test/opt/function_utils.h"
  26. #include "test/opt/pass_fixture.h"
  27. #include "test/opt/pass_utils.h"
  28. namespace spvtools {
  29. namespace opt {
  30. namespace {
  31. using ::testing::UnorderedElementsAre;
  32. using ScalarAnalysisTest = PassTest<::testing::Test>;
  33. /*
  34. Generated from the following GLSL + --eliminate-local-multi-store
  35. #version 410 core
  36. layout (location = 1) out float array[10];
  37. void main() {
  38. for (int i = 0; i < 10; ++i) {
  39. array[i] = array[i+1];
  40. }
  41. }
  42. */
  43. TEST_F(ScalarAnalysisTest, BasicEvolutionTest) {
  44. const std::string text = R"(
  45. OpCapability Shader
  46. %1 = OpExtInstImport "GLSL.std.450"
  47. OpMemoryModel Logical GLSL450
  48. OpEntryPoint Fragment %4 "main" %24
  49. OpExecutionMode %4 OriginUpperLeft
  50. OpSource GLSL 410
  51. OpName %4 "main"
  52. OpName %24 "array"
  53. OpDecorate %24 Location 1
  54. %2 = OpTypeVoid
  55. %3 = OpTypeFunction %2
  56. %6 = OpTypeInt 32 1
  57. %7 = OpTypePointer Function %6
  58. %9 = OpConstant %6 0
  59. %16 = OpConstant %6 10
  60. %17 = OpTypeBool
  61. %19 = OpTypeFloat 32
  62. %20 = OpTypeInt 32 0
  63. %21 = OpConstant %20 10
  64. %22 = OpTypeArray %19 %21
  65. %23 = OpTypePointer Output %22
  66. %24 = OpVariable %23 Output
  67. %27 = OpConstant %6 1
  68. %29 = OpTypePointer Output %19
  69. %4 = OpFunction %2 None %3
  70. %5 = OpLabel
  71. OpBranch %10
  72. %10 = OpLabel
  73. %35 = OpPhi %6 %9 %5 %34 %13
  74. OpLoopMerge %12 %13 None
  75. OpBranch %14
  76. %14 = OpLabel
  77. %18 = OpSLessThan %17 %35 %16
  78. OpBranchConditional %18 %11 %12
  79. %11 = OpLabel
  80. %28 = OpIAdd %6 %35 %27
  81. %30 = OpAccessChain %29 %24 %28
  82. %31 = OpLoad %19 %30
  83. %32 = OpAccessChain %29 %24 %35
  84. OpStore %32 %31
  85. OpBranch %13
  86. %13 = OpLabel
  87. %34 = OpIAdd %6 %35 %27
  88. OpBranch %10
  89. %12 = OpLabel
  90. OpReturn
  91. OpFunctionEnd
  92. )";
  93. // clang-format on
  94. std::unique_ptr<IRContext> context =
  95. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  96. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  97. Module* module = context->module();
  98. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  99. << text << std::endl;
  100. const Function* f = spvtest::GetFunction(module, 4);
  101. ScalarEvolutionAnalysis analysis{context.get()};
  102. const Instruction* store = nullptr;
  103. const Instruction* load = nullptr;
  104. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) {
  105. if (inst.opcode() == SpvOp::SpvOpStore) {
  106. store = &inst;
  107. }
  108. if (inst.opcode() == SpvOp::SpvOpLoad) {
  109. load = &inst;
  110. }
  111. }
  112. EXPECT_NE(load, nullptr);
  113. EXPECT_NE(store, nullptr);
  114. Instruction* access_chain =
  115. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  116. Instruction* child = context->get_def_use_mgr()->GetDef(
  117. access_chain->GetSingleWordInOperand(1));
  118. const SENode* node = analysis.AnalyzeInstruction(child);
  119. EXPECT_NE(node, nullptr);
  120. // Unsimplified node should have the form of ADD(REC(0,1), 1)
  121. EXPECT_EQ(node->GetType(), SENode::Add);
  122. const SENode* child_1 = node->GetChild(0);
  123. EXPECT_TRUE(child_1->GetType() == SENode::Constant ||
  124. child_1->GetType() == SENode::RecurrentAddExpr);
  125. const SENode* child_2 = node->GetChild(1);
  126. EXPECT_TRUE(child_2->GetType() == SENode::Constant ||
  127. child_2->GetType() == SENode::RecurrentAddExpr);
  128. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  129. // Simplified should be in the form of REC(1,1)
  130. EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);
  131. EXPECT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant);
  132. EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(),
  133. 1);
  134. EXPECT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant);
  135. EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(),
  136. 1);
  137. EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1));
  138. }
  139. /*
  140. Generated from the following GLSL + --eliminate-local-multi-store
  141. #version 410 core
  142. layout (location = 1) out float array[10];
  143. layout (location = 2) flat in int loop_invariant;
  144. void main() {
  145. for (int i = 0; i < 10; ++i) {
  146. array[i] = array[i+loop_invariant];
  147. }
  148. }
  149. */
  150. TEST_F(ScalarAnalysisTest, LoadTest) {
  151. const std::string text = R"(
  152. OpCapability Shader
  153. %1 = OpExtInstImport "GLSL.std.450"
  154. OpMemoryModel Logical GLSL450
  155. OpEntryPoint Fragment %2 "main" %3 %4
  156. OpExecutionMode %2 OriginUpperLeft
  157. OpSource GLSL 430
  158. OpName %2 "main"
  159. OpName %3 "array"
  160. OpName %4 "loop_invariant"
  161. OpDecorate %3 Location 1
  162. OpDecorate %4 Flat
  163. OpDecorate %4 Location 2
  164. %5 = OpTypeVoid
  165. %6 = OpTypeFunction %5
  166. %7 = OpTypeInt 32 1
  167. %8 = OpTypePointer Function %7
  168. %9 = OpConstant %7 0
  169. %10 = OpConstant %7 10
  170. %11 = OpTypeBool
  171. %12 = OpTypeFloat 32
  172. %13 = OpTypeInt 32 0
  173. %14 = OpConstant %13 10
  174. %15 = OpTypeArray %12 %14
  175. %16 = OpTypePointer Output %15
  176. %3 = OpVariable %16 Output
  177. %17 = OpTypePointer Input %7
  178. %4 = OpVariable %17 Input
  179. %18 = OpTypePointer Output %12
  180. %19 = OpConstant %7 1
  181. %2 = OpFunction %5 None %6
  182. %20 = OpLabel
  183. OpBranch %21
  184. %21 = OpLabel
  185. %22 = OpPhi %7 %9 %20 %23 %24
  186. OpLoopMerge %25 %24 None
  187. OpBranch %26
  188. %26 = OpLabel
  189. %27 = OpSLessThan %11 %22 %10
  190. OpBranchConditional %27 %28 %25
  191. %28 = OpLabel
  192. %29 = OpLoad %7 %4
  193. %30 = OpIAdd %7 %22 %29
  194. %31 = OpAccessChain %18 %3 %30
  195. %32 = OpLoad %12 %31
  196. %33 = OpAccessChain %18 %3 %22
  197. OpStore %33 %32
  198. OpBranch %24
  199. %24 = OpLabel
  200. %23 = OpIAdd %7 %22 %19
  201. OpBranch %21
  202. %25 = OpLabel
  203. OpReturn
  204. OpFunctionEnd
  205. )";
  206. // clang-format on
  207. std::unique_ptr<IRContext> context =
  208. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  209. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  210. Module* module = context->module();
  211. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  212. << text << std::endl;
  213. const Function* f = spvtest::GetFunction(module, 2);
  214. ScalarEvolutionAnalysis analysis{context.get()};
  215. const Instruction* load = nullptr;
  216. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 28)) {
  217. if (inst.opcode() == SpvOp::SpvOpLoad) {
  218. load = &inst;
  219. }
  220. }
  221. EXPECT_NE(load, nullptr);
  222. Instruction* access_chain =
  223. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  224. Instruction* child = context->get_def_use_mgr()->GetDef(
  225. access_chain->GetSingleWordInOperand(1));
  226. // const SENode* node =
  227. // analysis.GetNodeFromInstruction(child->unique_id());
  228. const SENode* node = analysis.AnalyzeInstruction(child);
  229. EXPECT_NE(node, nullptr);
  230. // Unsimplified node should have the form of ADD(REC(0,1), X)
  231. EXPECT_EQ(node->GetType(), SENode::Add);
  232. const SENode* child_1 = node->GetChild(0);
  233. EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown ||
  234. child_1->GetType() == SENode::RecurrentAddExpr);
  235. const SENode* child_2 = node->GetChild(1);
  236. EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown ||
  237. child_2->GetType() == SENode::RecurrentAddExpr);
  238. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  239. EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);
  240. const SERecurrentNode* rec = simplified->AsSERecurrentNode();
  241. EXPECT_NE(rec->GetChild(0), rec->GetChild(1));
  242. EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown);
  243. EXPECT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant);
  244. EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1u);
  245. }
  246. /*
  247. Generated from the following GLSL + --eliminate-local-multi-store
  248. #version 410 core
  249. layout (location = 1) out float array[10];
  250. layout (location = 2) flat in int loop_invariant;
  251. void main() {
  252. array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant -
  253. loop_invariant+ 16 * 3];
  254. }
  255. */
  256. TEST_F(ScalarAnalysisTest, SimplifySimple) {
  257. const std::string text = R"(
  258. OpCapability Shader
  259. %1 = OpExtInstImport "GLSL.std.450"
  260. OpMemoryModel Logical GLSL450
  261. OpEntryPoint Fragment %2 "main" %3 %4
  262. OpExecutionMode %2 OriginUpperLeft
  263. OpSource GLSL 430
  264. OpName %2 "main"
  265. OpName %3 "array"
  266. OpName %4 "loop_invariant"
  267. OpDecorate %3 Location 1
  268. OpDecorate %4 Flat
  269. OpDecorate %4 Location 2
  270. %5 = OpTypeVoid
  271. %6 = OpTypeFunction %5
  272. %7 = OpTypeFloat 32
  273. %8 = OpTypeInt 32 0
  274. %9 = OpConstant %8 10
  275. %10 = OpTypeArray %7 %9
  276. %11 = OpTypePointer Output %10
  277. %3 = OpVariable %11 Output
  278. %12 = OpTypeInt 32 1
  279. %13 = OpConstant %12 0
  280. %14 = OpTypePointer Input %12
  281. %4 = OpVariable %14 Input
  282. %15 = OpConstant %12 2
  283. %16 = OpConstant %12 4
  284. %17 = OpConstant %12 5
  285. %18 = OpConstant %12 24
  286. %19 = OpConstant %12 48
  287. %20 = OpTypePointer Output %7
  288. %2 = OpFunction %5 None %6
  289. %21 = OpLabel
  290. %22 = OpLoad %12 %4
  291. %23 = OpIMul %12 %22 %15
  292. %24 = OpIAdd %12 %23 %16
  293. %25 = OpIAdd %12 %24 %17
  294. %26 = OpISub %12 %25 %18
  295. %28 = OpISub %12 %26 %22
  296. %30 = OpISub %12 %28 %22
  297. %31 = OpIAdd %12 %30 %19
  298. %32 = OpAccessChain %20 %3 %31
  299. %33 = OpLoad %7 %32
  300. %34 = OpAccessChain %20 %3 %13
  301. OpStore %34 %33
  302. OpReturn
  303. OpFunctionEnd
  304. )";
  305. // clang-format on
  306. std::unique_ptr<IRContext> context =
  307. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  308. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  309. Module* module = context->module();
  310. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  311. << text << std::endl;
  312. const Function* f = spvtest::GetFunction(module, 2);
  313. ScalarEvolutionAnalysis analysis{context.get()};
  314. const Instruction* load = nullptr;
  315. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
  316. if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) {
  317. load = &inst;
  318. }
  319. }
  320. EXPECT_NE(load, nullptr);
  321. Instruction* access_chain =
  322. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  323. Instruction* child = context->get_def_use_mgr()->GetDef(
  324. access_chain->GetSingleWordInOperand(1));
  325. const SENode* node = analysis.AnalyzeInstruction(child);
  326. // Unsimplified is a very large graph with an add at the top.
  327. EXPECT_NE(node, nullptr);
  328. EXPECT_EQ(node->GetType(), SENode::Add);
  329. // Simplified node should resolve down to a constant expression as the loads
  330. // will eliminate themselves.
  331. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  332. EXPECT_EQ(simplified->GetType(), SENode::Constant);
  333. EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33u);
  334. }
  335. /*
  336. Generated from the following GLSL + --eliminate-local-multi-store
  337. #version 410 core
  338. layout(location = 0) in vec4 c;
  339. layout (location = 1) out float array[10];
  340. void main() {
  341. int N = int(c.x);
  342. for (int i = 0; i < 10; ++i) {
  343. array[i] = array[i];
  344. array[i] = array[i-1];
  345. array[i] = array[i+1];
  346. array[i+1] = array[i+1];
  347. array[i+N] = array[i+N];
  348. array[i] = array[i+N];
  349. }
  350. }
  351. */
  352. TEST_F(ScalarAnalysisTest, Simplify) {
  353. const std::string text = R"( OpCapability Shader
  354. %1 = OpExtInstImport "GLSL.std.450"
  355. OpMemoryModel Logical GLSL450
  356. OpEntryPoint Fragment %4 "main" %12 %33
  357. OpExecutionMode %4 OriginUpperLeft
  358. OpSource GLSL 410
  359. OpName %4 "main"
  360. OpName %8 "N"
  361. OpName %12 "c"
  362. OpName %19 "i"
  363. OpName %33 "array"
  364. OpDecorate %12 Location 0
  365. OpDecorate %33 Location 1
  366. %2 = OpTypeVoid
  367. %3 = OpTypeFunction %2
  368. %6 = OpTypeInt 32 1
  369. %7 = OpTypePointer Function %6
  370. %9 = OpTypeFloat 32
  371. %10 = OpTypeVector %9 4
  372. %11 = OpTypePointer Input %10
  373. %12 = OpVariable %11 Input
  374. %13 = OpTypeInt 32 0
  375. %14 = OpConstant %13 0
  376. %15 = OpTypePointer Input %9
  377. %20 = OpConstant %6 0
  378. %27 = OpConstant %6 10
  379. %28 = OpTypeBool
  380. %30 = OpConstant %13 10
  381. %31 = OpTypeArray %9 %30
  382. %32 = OpTypePointer Output %31
  383. %33 = OpVariable %32 Output
  384. %36 = OpTypePointer Output %9
  385. %42 = OpConstant %6 1
  386. %4 = OpFunction %2 None %3
  387. %5 = OpLabel
  388. %8 = OpVariable %7 Function
  389. %19 = OpVariable %7 Function
  390. %16 = OpAccessChain %15 %12 %14
  391. %17 = OpLoad %9 %16
  392. %18 = OpConvertFToS %6 %17
  393. OpStore %8 %18
  394. OpStore %19 %20
  395. OpBranch %21
  396. %21 = OpLabel
  397. %78 = OpPhi %6 %20 %5 %77 %24
  398. OpLoopMerge %23 %24 None
  399. OpBranch %25
  400. %25 = OpLabel
  401. %29 = OpSLessThan %28 %78 %27
  402. OpBranchConditional %29 %22 %23
  403. %22 = OpLabel
  404. %37 = OpAccessChain %36 %33 %78
  405. %38 = OpLoad %9 %37
  406. %39 = OpAccessChain %36 %33 %78
  407. OpStore %39 %38
  408. %43 = OpISub %6 %78 %42
  409. %44 = OpAccessChain %36 %33 %43
  410. %45 = OpLoad %9 %44
  411. %46 = OpAccessChain %36 %33 %78
  412. OpStore %46 %45
  413. %49 = OpIAdd %6 %78 %42
  414. %50 = OpAccessChain %36 %33 %49
  415. %51 = OpLoad %9 %50
  416. %52 = OpAccessChain %36 %33 %78
  417. OpStore %52 %51
  418. %54 = OpIAdd %6 %78 %42
  419. %56 = OpIAdd %6 %78 %42
  420. %57 = OpAccessChain %36 %33 %56
  421. %58 = OpLoad %9 %57
  422. %59 = OpAccessChain %36 %33 %54
  423. OpStore %59 %58
  424. %62 = OpIAdd %6 %78 %18
  425. %65 = OpIAdd %6 %78 %18
  426. %66 = OpAccessChain %36 %33 %65
  427. %67 = OpLoad %9 %66
  428. %68 = OpAccessChain %36 %33 %62
  429. OpStore %68 %67
  430. %72 = OpIAdd %6 %78 %18
  431. %73 = OpAccessChain %36 %33 %72
  432. %74 = OpLoad %9 %73
  433. %75 = OpAccessChain %36 %33 %78
  434. OpStore %75 %74
  435. OpBranch %24
  436. %24 = OpLabel
  437. %77 = OpIAdd %6 %78 %42
  438. OpStore %19 %77
  439. OpBranch %21
  440. %23 = OpLabel
  441. OpReturn
  442. OpFunctionEnd
  443. )";
  444. // clang-format on
  445. std::unique_ptr<IRContext> context =
  446. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  447. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  448. Module* module = context->module();
  449. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  450. << text << std::endl;
  451. const Function* f = spvtest::GetFunction(module, 4);
  452. ScalarEvolutionAnalysis analysis{context.get()};
  453. const Instruction* loads[6];
  454. const Instruction* stores[6];
  455. int load_count = 0;
  456. int store_count = 0;
  457. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) {
  458. if (inst.opcode() == SpvOp::SpvOpLoad) {
  459. loads[load_count] = &inst;
  460. ++load_count;
  461. }
  462. if (inst.opcode() == SpvOp::SpvOpStore) {
  463. stores[store_count] = &inst;
  464. ++store_count;
  465. }
  466. }
  467. EXPECT_EQ(load_count, 6);
  468. EXPECT_EQ(store_count, 6);
  469. Instruction* load_access_chain;
  470. Instruction* store_access_chain;
  471. Instruction* load_child;
  472. Instruction* store_child;
  473. SENode* load_node;
  474. SENode* store_node;
  475. SENode* subtract_node;
  476. SENode* simplified_node;
  477. // Testing [i] - [i] == 0
  478. load_access_chain =
  479. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  480. store_access_chain =
  481. context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
  482. load_child = context->get_def_use_mgr()->GetDef(
  483. load_access_chain->GetSingleWordInOperand(1));
  484. store_child = context->get_def_use_mgr()->GetDef(
  485. store_access_chain->GetSingleWordInOperand(1));
  486. load_node = analysis.AnalyzeInstruction(load_child);
  487. store_node = analysis.AnalyzeInstruction(store_child);
  488. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  489. simplified_node = analysis.SimplifyExpression(subtract_node);
  490. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  491. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  492. // Testing [i] - [i-1] == 1
  493. load_access_chain =
  494. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  495. store_access_chain =
  496. context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
  497. load_child = context->get_def_use_mgr()->GetDef(
  498. load_access_chain->GetSingleWordInOperand(1));
  499. store_child = context->get_def_use_mgr()->GetDef(
  500. store_access_chain->GetSingleWordInOperand(1));
  501. load_node = analysis.AnalyzeInstruction(load_child);
  502. store_node = analysis.AnalyzeInstruction(store_child);
  503. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  504. simplified_node = analysis.SimplifyExpression(subtract_node);
  505. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  506. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u);
  507. // Testing [i] - [i+1] == -1
  508. load_access_chain =
  509. context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
  510. store_access_chain =
  511. context->get_def_use_mgr()->GetDef(stores[2]->GetSingleWordInOperand(0));
  512. load_child = context->get_def_use_mgr()->GetDef(
  513. load_access_chain->GetSingleWordInOperand(1));
  514. store_child = context->get_def_use_mgr()->GetDef(
  515. store_access_chain->GetSingleWordInOperand(1));
  516. load_node = analysis.AnalyzeInstruction(load_child);
  517. store_node = analysis.AnalyzeInstruction(store_child);
  518. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  519. simplified_node = analysis.SimplifyExpression(subtract_node);
  520. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  521. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1);
  522. // Testing [i+1] - [i+1] == 0
  523. load_access_chain =
  524. context->get_def_use_mgr()->GetDef(loads[3]->GetSingleWordInOperand(0));
  525. store_access_chain =
  526. context->get_def_use_mgr()->GetDef(stores[3]->GetSingleWordInOperand(0));
  527. load_child = context->get_def_use_mgr()->GetDef(
  528. load_access_chain->GetSingleWordInOperand(1));
  529. store_child = context->get_def_use_mgr()->GetDef(
  530. store_access_chain->GetSingleWordInOperand(1));
  531. load_node = analysis.AnalyzeInstruction(load_child);
  532. store_node = analysis.AnalyzeInstruction(store_child);
  533. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  534. simplified_node = analysis.SimplifyExpression(subtract_node);
  535. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  536. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  537. // Testing [i+N] - [i+N] == 0
  538. load_access_chain =
  539. context->get_def_use_mgr()->GetDef(loads[4]->GetSingleWordInOperand(0));
  540. store_access_chain =
  541. context->get_def_use_mgr()->GetDef(stores[4]->GetSingleWordInOperand(0));
  542. load_child = context->get_def_use_mgr()->GetDef(
  543. load_access_chain->GetSingleWordInOperand(1));
  544. store_child = context->get_def_use_mgr()->GetDef(
  545. store_access_chain->GetSingleWordInOperand(1));
  546. load_node = analysis.AnalyzeInstruction(load_child);
  547. store_node = analysis.AnalyzeInstruction(store_child);
  548. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  549. simplified_node = analysis.SimplifyExpression(subtract_node);
  550. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  551. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  552. // Testing [i] - [i+N] == -N
  553. load_access_chain =
  554. context->get_def_use_mgr()->GetDef(loads[5]->GetSingleWordInOperand(0));
  555. store_access_chain =
  556. context->get_def_use_mgr()->GetDef(stores[5]->GetSingleWordInOperand(0));
  557. load_child = context->get_def_use_mgr()->GetDef(
  558. load_access_chain->GetSingleWordInOperand(1));
  559. store_child = context->get_def_use_mgr()->GetDef(
  560. store_access_chain->GetSingleWordInOperand(1));
  561. load_node = analysis.AnalyzeInstruction(load_child);
  562. store_node = analysis.AnalyzeInstruction(store_child);
  563. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  564. simplified_node = analysis.SimplifyExpression(subtract_node);
  565. EXPECT_EQ(simplified_node->GetType(), SENode::Negative);
  566. }
  567. /*
  568. Generated from the following GLSL + --eliminate-local-multi-store
  569. #version 430
  570. layout(location = 1) out float array[10];
  571. layout(location = 2) flat in int loop_invariant;
  572. void main(void) {
  573. for (int i = 0; i < 10; ++i) {
  574. array[i * 2 + i * 5] = array[i * i * 2];
  575. array[i * 2] = array[i * 5];
  576. }
  577. }
  578. */
  579. TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) {
  580. const std::string text = R"(
  581. OpCapability Shader
  582. %1 = OpExtInstImport "GLSL.std.450"
  583. OpMemoryModel Logical GLSL450
  584. OpEntryPoint Fragment %2 "main" %3 %4
  585. OpExecutionMode %2 OriginUpperLeft
  586. OpSource GLSL 430
  587. OpName %2 "main"
  588. OpName %5 "i"
  589. OpName %3 "array"
  590. OpName %4 "loop_invariant"
  591. OpDecorate %3 Location 1
  592. OpDecorate %4 Flat
  593. OpDecorate %4 Location 2
  594. %6 = OpTypeVoid
  595. %7 = OpTypeFunction %6
  596. %8 = OpTypeInt 32 1
  597. %9 = OpTypePointer Function %8
  598. %10 = OpConstant %8 0
  599. %11 = OpConstant %8 10
  600. %12 = OpTypeBool
  601. %13 = OpTypeFloat 32
  602. %14 = OpTypeInt 32 0
  603. %15 = OpConstant %14 10
  604. %16 = OpTypeArray %13 %15
  605. %17 = OpTypePointer Output %16
  606. %3 = OpVariable %17 Output
  607. %18 = OpConstant %8 2
  608. %19 = OpConstant %8 5
  609. %20 = OpTypePointer Output %13
  610. %21 = OpConstant %8 1
  611. %22 = OpTypePointer Input %8
  612. %4 = OpVariable %22 Input
  613. %2 = OpFunction %6 None %7
  614. %23 = OpLabel
  615. %5 = OpVariable %9 Function
  616. OpStore %5 %10
  617. OpBranch %24
  618. %24 = OpLabel
  619. %25 = OpPhi %8 %10 %23 %26 %27
  620. OpLoopMerge %28 %27 None
  621. OpBranch %29
  622. %29 = OpLabel
  623. %30 = OpSLessThan %12 %25 %11
  624. OpBranchConditional %30 %31 %28
  625. %31 = OpLabel
  626. %32 = OpIMul %8 %25 %18
  627. %33 = OpIMul %8 %25 %19
  628. %34 = OpIAdd %8 %32 %33
  629. %35 = OpIMul %8 %25 %25
  630. %36 = OpIMul %8 %35 %18
  631. %37 = OpAccessChain %20 %3 %36
  632. %38 = OpLoad %13 %37
  633. %39 = OpAccessChain %20 %3 %34
  634. OpStore %39 %38
  635. %40 = OpIMul %8 %25 %18
  636. %41 = OpIMul %8 %25 %19
  637. %42 = OpAccessChain %20 %3 %41
  638. %43 = OpLoad %13 %42
  639. %44 = OpAccessChain %20 %3 %40
  640. OpStore %44 %43
  641. OpBranch %27
  642. %27 = OpLabel
  643. %26 = OpIAdd %8 %25 %21
  644. OpStore %5 %26
  645. OpBranch %24
  646. %28 = OpLabel
  647. OpReturn
  648. OpFunctionEnd
  649. )";
  650. std::unique_ptr<IRContext> context =
  651. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  652. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  653. Module* module = context->module();
  654. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  655. << text << std::endl;
  656. const Function* f = spvtest::GetFunction(module, 2);
  657. ScalarEvolutionAnalysis analysis{context.get()};
  658. const Instruction* loads[2] = {nullptr, nullptr};
  659. const Instruction* stores[2] = {nullptr, nullptr};
  660. int load_count = 0;
  661. int store_count = 0;
  662. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 31)) {
  663. if (inst.opcode() == SpvOp::SpvOpLoad) {
  664. loads[load_count] = &inst;
  665. ++load_count;
  666. }
  667. if (inst.opcode() == SpvOp::SpvOpStore) {
  668. stores[store_count] = &inst;
  669. ++store_count;
  670. }
  671. }
  672. EXPECT_EQ(load_count, 2);
  673. EXPECT_EQ(store_count, 2);
  674. Instruction* load_access_chain =
  675. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  676. Instruction* store_access_chain =
  677. context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
  678. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  679. load_access_chain->GetSingleWordInOperand(1));
  680. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  681. store_access_chain->GetSingleWordInOperand(1));
  682. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  683. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  684. load_access_chain =
  685. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  686. store_access_chain =
  687. context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
  688. load_child = context->get_def_use_mgr()->GetDef(
  689. load_access_chain->GetSingleWordInOperand(1));
  690. store_child = context->get_def_use_mgr()->GetDef(
  691. store_access_chain->GetSingleWordInOperand(1));
  692. SENode* second_store =
  693. analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child));
  694. SENode* second_load =
  695. analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child));
  696. SENode* combined_add = analysis.SimplifyExpression(
  697. analysis.CreateAddNode(second_load, second_store));
  698. // We're checking that the two recurrent expression have been correctly
  699. // folded. In store_simplified they will have been folded as the entire
  700. // expression was simplified as one. In combined_add the two expressions have
  701. // been simplified one after the other which means the recurrent expressions
  702. // aren't exactly the same but should still be folded as they are with respect
  703. // to the same loop.
  704. EXPECT_EQ(combined_add, store_simplified);
  705. }
/*
Generated from the following GLSL + --eliminate-local-multi-store
#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int loop_invariant;
void main(void) {
  for (int i = 0; i < 10; --i) {
    array[i] = array[i];
  }
}
*/
  715. TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) {
  716. const std::string text = R"(
  717. OpCapability Shader
  718. %1 = OpExtInstImport "GLSL.std.450"
  719. OpMemoryModel Logical GLSL450
  720. OpEntryPoint Fragment %2 "main" %3 %4
  721. OpExecutionMode %2 OriginUpperLeft
  722. OpSource GLSL 430
  723. OpName %2 "main"
  724. OpName %5 "i"
  725. OpName %3 "array"
  726. OpName %4 "loop_invariant"
  727. OpDecorate %3 Location 1
  728. OpDecorate %4 Flat
  729. OpDecorate %4 Location 2
  730. %6 = OpTypeVoid
  731. %7 = OpTypeFunction %6
  732. %8 = OpTypeInt 32 1
  733. %9 = OpTypePointer Function %8
  734. %10 = OpConstant %8 0
  735. %11 = OpConstant %8 10
  736. %12 = OpTypeBool
  737. %13 = OpTypeFloat 32
  738. %14 = OpTypeInt 32 0
  739. %15 = OpConstant %14 10
  740. %16 = OpTypeArray %13 %15
  741. %17 = OpTypePointer Output %16
  742. %3 = OpVariable %17 Output
  743. %18 = OpTypePointer Output %13
  744. %19 = OpConstant %8 1
  745. %20 = OpTypePointer Input %8
  746. %4 = OpVariable %20 Input
  747. %2 = OpFunction %6 None %7
  748. %21 = OpLabel
  749. %5 = OpVariable %9 Function
  750. OpStore %5 %10
  751. OpBranch %22
  752. %22 = OpLabel
  753. %23 = OpPhi %8 %10 %21 %24 %25
  754. OpLoopMerge %26 %25 None
  755. OpBranch %27
  756. %27 = OpLabel
  757. %28 = OpSLessThan %12 %23 %11
  758. OpBranchConditional %28 %29 %26
  759. %29 = OpLabel
  760. %30 = OpAccessChain %18 %3 %23
  761. %31 = OpLoad %13 %30
  762. %32 = OpAccessChain %18 %3 %23
  763. OpStore %32 %31
  764. OpBranch %25
  765. %25 = OpLabel
  766. %24 = OpISub %8 %23 %19
  767. OpStore %5 %24
  768. OpBranch %22
  769. %26 = OpLabel
  770. OpReturn
  771. OpFunctionEnd
  772. )";
  773. std::unique_ptr<IRContext> context =
  774. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  775. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  776. Module* module = context->module();
  777. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  778. << text << std::endl;
  779. const Function* f = spvtest::GetFunction(module, 2);
  780. ScalarEvolutionAnalysis analysis{context.get()};
  781. const Instruction* loads[1] = {nullptr};
  782. int load_count = 0;
  783. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) {
  784. if (inst.opcode() == SpvOp::SpvOpLoad) {
  785. loads[load_count] = &inst;
  786. ++load_count;
  787. }
  788. }
  789. EXPECT_EQ(load_count, 1);
  790. Instruction* load_access_chain =
  791. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  792. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  793. load_access_chain->GetSingleWordInOperand(1));
  794. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  795. EXPECT_TRUE(load_node);
  796. EXPECT_EQ(load_node->GetType(), SENode::RecurrentAddExpr);
  797. EXPECT_TRUE(load_node->AsSERecurrentNode());
  798. SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient();
  799. SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset();
  800. EXPECT_EQ(child_1->GetType(), SENode::Constant);
  801. EXPECT_EQ(child_2->GetType(), SENode::Constant);
  802. EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1);
  803. EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u);
  804. SERecurrentNode* load_simplified =
  805. analysis.SimplifyExpression(load_node)->AsSERecurrentNode();
  806. EXPECT_TRUE(load_simplified);
  807. EXPECT_EQ(load_node, load_simplified);
  808. EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr);
  809. EXPECT_TRUE(load_simplified->AsSERecurrentNode());
  810. SENode* simplified_child_1 =
  811. load_simplified->AsSERecurrentNode()->GetCoefficient();
  812. SENode* simplified_child_2 =
  813. load_simplified->AsSERecurrentNode()->GetOffset();
  814. EXPECT_EQ(child_1, simplified_child_1);
  815. EXPECT_EQ(child_2, simplified_child_2);
  816. }
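  /*
  Generated from the following GLSL + --eliminate-local-multi-store
  (in the SPIR-V below, N is loaded once, as %31, and reused by every index)
  #version 430
  layout(location = 1) out float array[10];
  layout(location = 2) flat in int N;
  void main(void) {
    for (int i = 0; i < 10; --i) {
      array[i + 2 * N] = array[i + N];
      array[2 * i + 2 * N + 1] = array[2 * i + N + 1];
    }
  }
  */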
  826. TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) {
  827. const std::string text = R"(
  828. OpCapability Shader
  829. %1 = OpExtInstImport "GLSL.std.450"
  830. OpMemoryModel Logical GLSL450
  831. OpEntryPoint Fragment %2 "main" %3 %4
  832. OpExecutionMode %2 OriginUpperLeft
  833. OpSource GLSL 430
  834. OpName %2 "main"
  835. OpName %5 "i"
  836. OpName %3 "array"
  837. OpName %4 "N"
  838. OpDecorate %3 Location 1
  839. OpDecorate %4 Flat
  840. OpDecorate %4 Location 2
  841. %6 = OpTypeVoid
  842. %7 = OpTypeFunction %6
  843. %8 = OpTypeInt 32 1
  844. %9 = OpTypePointer Function %8
  845. %10 = OpConstant %8 0
  846. %11 = OpConstant %8 10
  847. %12 = OpTypeBool
  848. %13 = OpTypeFloat 32
  849. %14 = OpTypeInt 32 0
  850. %15 = OpConstant %14 10
  851. %16 = OpTypeArray %13 %15
  852. %17 = OpTypePointer Output %16
  853. %3 = OpVariable %17 Output
  854. %18 = OpConstant %8 2
  855. %19 = OpTypePointer Input %8
  856. %4 = OpVariable %19 Input
  857. %20 = OpTypePointer Output %13
  858. %21 = OpConstant %8 1
  859. %2 = OpFunction %6 None %7
  860. %22 = OpLabel
  861. %5 = OpVariable %9 Function
  862. OpStore %5 %10
  863. OpBranch %23
  864. %23 = OpLabel
  865. %24 = OpPhi %8 %10 %22 %25 %26
  866. OpLoopMerge %27 %26 None
  867. OpBranch %28
  868. %28 = OpLabel
  869. %29 = OpSLessThan %12 %24 %11
  870. OpBranchConditional %29 %30 %27
  871. %30 = OpLabel
  872. %31 = OpLoad %8 %4
  873. %32 = OpIMul %8 %18 %31
  874. %33 = OpIAdd %8 %24 %32
  875. %35 = OpIAdd %8 %24 %31
  876. %36 = OpAccessChain %20 %3 %35
  877. %37 = OpLoad %13 %36
  878. %38 = OpAccessChain %20 %3 %33
  879. OpStore %38 %37
  880. %39 = OpIMul %8 %18 %24
  881. %41 = OpIMul %8 %18 %31
  882. %42 = OpIAdd %8 %39 %41
  883. %43 = OpIAdd %8 %42 %21
  884. %44 = OpIMul %8 %18 %24
  885. %46 = OpIAdd %8 %44 %31
  886. %47 = OpIAdd %8 %46 %21
  887. %48 = OpAccessChain %20 %3 %47
  888. %49 = OpLoad %13 %48
  889. %50 = OpAccessChain %20 %3 %43
  890. OpStore %50 %49
  891. OpBranch %26
  892. %26 = OpLabel
  893. %25 = OpISub %8 %24 %21
  894. OpStore %5 %25
  895. OpBranch %23
  896. %27 = OpLabel
  897. OpReturn
  898. OpFunctionEnd
  899. )";
  900. std::unique_ptr<IRContext> context =
  901. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  902. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  903. Module* module = context->module();
  904. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  905. << text << std::endl;
  906. const Function* f = spvtest::GetFunction(module, 2);
  907. ScalarEvolutionAnalysis analysis{context.get()};
  908. std::vector<const Instruction*> loads{};
  909. std::vector<const Instruction*> stores{};
  910. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
  911. if (inst.opcode() == SpvOp::SpvOpLoad) {
  912. loads.push_back(&inst);
  913. }
  914. if (inst.opcode() == SpvOp::SpvOpStore) {
  915. stores.push_back(&inst);
  916. }
  917. }
  918. EXPECT_EQ(loads.size(), 3u);
  919. EXPECT_EQ(stores.size(), 2u);
  920. {
  921. Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
  922. stores[0]->GetSingleWordInOperand(0));
  923. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  924. store_access_chain->GetSingleWordInOperand(1));
  925. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  926. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  927. Instruction* load_access_chain =
  928. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  929. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  930. load_access_chain->GetSingleWordInOperand(1));
  931. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  932. SENode* load_simplified = analysis.SimplifyExpression(load_node);
  933. SENode* difference =
  934. analysis.CreateSubtraction(store_simplified, load_simplified);
  935. SENode* difference_simplified = analysis.SimplifyExpression(difference);
  936. // Check that i+2*N - i*N, turns into just N when both sides have already
  937. // been simplified into a single recurrent expression.
  938. EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
  939. // Check that the inverse, i*N - i+2*N turns into -N.
  940. SENode* difference_inverse = analysis.SimplifyExpression(
  941. analysis.CreateSubtraction(load_simplified, store_simplified));
  942. EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
  943. EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
  944. EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  945. }
  946. {
  947. Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
  948. stores[1]->GetSingleWordInOperand(0));
  949. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  950. store_access_chain->GetSingleWordInOperand(1));
  951. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  952. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  953. Instruction* load_access_chain =
  954. context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
  955. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  956. load_access_chain->GetSingleWordInOperand(1));
  957. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  958. SENode* load_simplified = analysis.SimplifyExpression(load_node);
  959. SENode* difference =
  960. analysis.CreateSubtraction(store_simplified, load_simplified);
  961. SENode* difference_simplified = analysis.SimplifyExpression(difference);
  962. // Check that 2*i + 2*N + 1 - 2*i + N + 1, turns into just N when both
  963. // sides have already been simplified into a single recurrent expression.
  964. EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
  965. // Check that the inverse, (2*i + N + 1) - (2*i + 2*N + 1) turns into -N.
  966. SENode* difference_inverse = analysis.SimplifyExpression(
  967. analysis.CreateSubtraction(load_simplified, store_simplified));
  968. EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
  969. EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
  970. EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  971. }
  972. }
  973. /* Generated from the following GLSL + --eliminate-local-multi-store
  974. #version 430
  975. layout(location = 1) out float array[10];
  976. layout(location = 2) flat in int N;
  977. void main(void) {
  978. int step = 0;
  979. for (int i = 0; i < N; i += step) {
  980. step++;
  981. }
  982. }
  983. */
  984. TEST_F(ScalarAnalysisTest, InductionWithVariantStep) {
  985. const std::string text = R"(
  986. OpCapability Shader
  987. %1 = OpExtInstImport "GLSL.std.450"
  988. OpMemoryModel Logical GLSL450
  989. OpEntryPoint Fragment %2 "main" %3 %4
  990. OpExecutionMode %2 OriginUpperLeft
  991. OpSource GLSL 430
  992. OpName %2 "main"
  993. OpName %5 "step"
  994. OpName %6 "i"
  995. OpName %3 "N"
  996. OpName %4 "array"
  997. OpDecorate %3 Flat
  998. OpDecorate %3 Location 2
  999. OpDecorate %4 Location 1
  1000. %7 = OpTypeVoid
  1001. %8 = OpTypeFunction %7
  1002. %9 = OpTypeInt 32 1
  1003. %10 = OpTypePointer Function %9
  1004. %11 = OpConstant %9 0
  1005. %12 = OpTypePointer Input %9
  1006. %3 = OpVariable %12 Input
  1007. %13 = OpTypeBool
  1008. %14 = OpConstant %9 1
  1009. %15 = OpTypeFloat 32
  1010. %16 = OpTypeInt 32 0
  1011. %17 = OpConstant %16 10
  1012. %18 = OpTypeArray %15 %17
  1013. %19 = OpTypePointer Output %18
  1014. %4 = OpVariable %19 Output
  1015. %2 = OpFunction %7 None %8
  1016. %20 = OpLabel
  1017. %5 = OpVariable %10 Function
  1018. %6 = OpVariable %10 Function
  1019. OpStore %5 %11
  1020. OpStore %6 %11
  1021. OpBranch %21
  1022. %21 = OpLabel
  1023. %22 = OpPhi %9 %11 %20 %23 %24
  1024. %25 = OpPhi %9 %11 %20 %26 %24
  1025. OpLoopMerge %27 %24 None
  1026. OpBranch %28
  1027. %28 = OpLabel
  1028. %29 = OpLoad %9 %3
  1029. %30 = OpSLessThan %13 %25 %29
  1030. OpBranchConditional %30 %31 %27
  1031. %31 = OpLabel
  1032. %23 = OpIAdd %9 %22 %14
  1033. OpStore %5 %23
  1034. OpBranch %24
  1035. %24 = OpLabel
  1036. %26 = OpIAdd %9 %25 %23
  1037. OpStore %6 %26
  1038. OpBranch %21
  1039. %27 = OpLabel
  1040. OpReturn
  1041. OpFunctionEnd
  1042. )";
  1043. std::unique_ptr<IRContext> context =
  1044. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  1045. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  1046. Module* module = context->module();
  1047. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  1048. << text << std::endl;
  1049. const Function* f = spvtest::GetFunction(module, 2);
  1050. ScalarEvolutionAnalysis analysis{context.get()};
  1051. std::vector<const Instruction*> phis{};
  1052. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
  1053. if (inst.opcode() == SpvOp::SpvOpPhi) {
  1054. phis.push_back(&inst);
  1055. }
  1056. }
  1057. EXPECT_EQ(phis.size(), 2u);
  1058. SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]);
  1059. SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]);
  1060. phi_node_1->DumpDot(std::cout, true);
  1061. EXPECT_NE(phi_node_1, nullptr);
  1062. EXPECT_NE(phi_node_2, nullptr);
  1063. EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr);
  1064. EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute);
  1065. SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1);
  1066. SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2);
  1067. EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr);
  1068. EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute);
  1069. }
  1070. } // namespace
  1071. } // namespace opt
  1072. } // namespace spvtools