// scalar_analysis.cpp
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/scalar_analysis.h"
  15. #include <memory>
  16. #include <vector>
  17. #include "gmock/gmock.h"
  18. #include "source/opt/pass.h"
  19. #include "test/opt/assembly_builder.h"
  20. #include "test/opt/function_utils.h"
  21. #include "test/opt/pass_fixture.h"
  22. #include "test/opt/pass_utils.h"
  23. namespace spvtools {
  24. namespace opt {
  25. namespace {
  26. using ::testing::UnorderedElementsAre;
  27. using ScalarAnalysisTest = PassTest<::testing::Test>;
  28. /*
  29. Generated from the following GLSL + --eliminate-local-multi-store
  30. #version 410 core
  31. layout (location = 1) out float array[10];
  32. void main() {
  33. for (int i = 0; i < 10; ++i) {
  34. array[i] = array[i+1];
  35. }
  36. }
  37. */
  38. TEST_F(ScalarAnalysisTest, BasicEvolutionTest) {
  39. const std::string text = R"(
  40. OpCapability Shader
  41. %1 = OpExtInstImport "GLSL.std.450"
  42. OpMemoryModel Logical GLSL450
  43. OpEntryPoint Fragment %4 "main" %24
  44. OpExecutionMode %4 OriginUpperLeft
  45. OpSource GLSL 410
  46. OpName %4 "main"
  47. OpName %24 "array"
  48. OpDecorate %24 Location 1
  49. %2 = OpTypeVoid
  50. %3 = OpTypeFunction %2
  51. %6 = OpTypeInt 32 1
  52. %7 = OpTypePointer Function %6
  53. %9 = OpConstant %6 0
  54. %16 = OpConstant %6 10
  55. %17 = OpTypeBool
  56. %19 = OpTypeFloat 32
  57. %20 = OpTypeInt 32 0
  58. %21 = OpConstant %20 10
  59. %22 = OpTypeArray %19 %21
  60. %23 = OpTypePointer Output %22
  61. %24 = OpVariable %23 Output
  62. %27 = OpConstant %6 1
  63. %29 = OpTypePointer Output %19
  64. %4 = OpFunction %2 None %3
  65. %5 = OpLabel
  66. OpBranch %10
  67. %10 = OpLabel
  68. %35 = OpPhi %6 %9 %5 %34 %13
  69. OpLoopMerge %12 %13 None
  70. OpBranch %14
  71. %14 = OpLabel
  72. %18 = OpSLessThan %17 %35 %16
  73. OpBranchConditional %18 %11 %12
  74. %11 = OpLabel
  75. %28 = OpIAdd %6 %35 %27
  76. %30 = OpAccessChain %29 %24 %28
  77. %31 = OpLoad %19 %30
  78. %32 = OpAccessChain %29 %24 %35
  79. OpStore %32 %31
  80. OpBranch %13
  81. %13 = OpLabel
  82. %34 = OpIAdd %6 %35 %27
  83. OpBranch %10
  84. %12 = OpLabel
  85. OpReturn
  86. OpFunctionEnd
  87. )";
  88. // clang-format on
  89. std::unique_ptr<IRContext> context =
  90. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  91. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  92. Module* module = context->module();
  93. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  94. << text << std::endl;
  95. const Function* f = spvtest::GetFunction(module, 4);
  96. ScalarEvolutionAnalysis analysis{context.get()};
  97. const Instruction* store = nullptr;
  98. const Instruction* load = nullptr;
  99. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) {
  100. if (inst.opcode() == spv::Op::OpStore) {
  101. store = &inst;
  102. }
  103. if (inst.opcode() == spv::Op::OpLoad) {
  104. load = &inst;
  105. }
  106. }
  107. EXPECT_NE(load, nullptr);
  108. EXPECT_NE(store, nullptr);
  109. Instruction* access_chain =
  110. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  111. Instruction* child = context->get_def_use_mgr()->GetDef(
  112. access_chain->GetSingleWordInOperand(1));
  113. const SENode* node = analysis.AnalyzeInstruction(child);
  114. EXPECT_NE(node, nullptr);
  115. // Unsimplified node should have the form of ADD(REC(0,1), 1)
  116. EXPECT_EQ(node->GetType(), SENode::Add);
  117. const SENode* child_1 = node->GetChild(0);
  118. EXPECT_TRUE(child_1->GetType() == SENode::Constant ||
  119. child_1->GetType() == SENode::RecurrentAddExpr);
  120. const SENode* child_2 = node->GetChild(1);
  121. EXPECT_TRUE(child_2->GetType() == SENode::Constant ||
  122. child_2->GetType() == SENode::RecurrentAddExpr);
  123. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  124. // Simplified should be in the form of REC(1,1)
  125. EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);
  126. EXPECT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant);
  127. EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(),
  128. 1);
  129. EXPECT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant);
  130. EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(),
  131. 1);
  132. EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1));
  133. }
  134. /*
  135. Generated from the following GLSL + --eliminate-local-multi-store
  136. #version 410 core
  137. layout (location = 1) out float array[10];
  138. layout (location = 2) flat in int loop_invariant;
  139. void main() {
  140. for (int i = 0; i < 10; ++i) {
  141. array[i] = array[i+loop_invariant];
  142. }
  143. }
  144. */
  145. TEST_F(ScalarAnalysisTest, LoadTest) {
  146. const std::string text = R"(
  147. OpCapability Shader
  148. %1 = OpExtInstImport "GLSL.std.450"
  149. OpMemoryModel Logical GLSL450
  150. OpEntryPoint Fragment %2 "main" %3 %4
  151. OpExecutionMode %2 OriginUpperLeft
  152. OpSource GLSL 430
  153. OpName %2 "main"
  154. OpName %3 "array"
  155. OpName %4 "loop_invariant"
  156. OpDecorate %3 Location 1
  157. OpDecorate %4 Flat
  158. OpDecorate %4 Location 2
  159. %5 = OpTypeVoid
  160. %6 = OpTypeFunction %5
  161. %7 = OpTypeInt 32 1
  162. %8 = OpTypePointer Function %7
  163. %9 = OpConstant %7 0
  164. %10 = OpConstant %7 10
  165. %11 = OpTypeBool
  166. %12 = OpTypeFloat 32
  167. %13 = OpTypeInt 32 0
  168. %14 = OpConstant %13 10
  169. %15 = OpTypeArray %12 %14
  170. %16 = OpTypePointer Output %15
  171. %3 = OpVariable %16 Output
  172. %17 = OpTypePointer Input %7
  173. %4 = OpVariable %17 Input
  174. %18 = OpTypePointer Output %12
  175. %19 = OpConstant %7 1
  176. %2 = OpFunction %5 None %6
  177. %20 = OpLabel
  178. OpBranch %21
  179. %21 = OpLabel
  180. %22 = OpPhi %7 %9 %20 %23 %24
  181. OpLoopMerge %25 %24 None
  182. OpBranch %26
  183. %26 = OpLabel
  184. %27 = OpSLessThan %11 %22 %10
  185. OpBranchConditional %27 %28 %25
  186. %28 = OpLabel
  187. %29 = OpLoad %7 %4
  188. %30 = OpIAdd %7 %22 %29
  189. %31 = OpAccessChain %18 %3 %30
  190. %32 = OpLoad %12 %31
  191. %33 = OpAccessChain %18 %3 %22
  192. OpStore %33 %32
  193. OpBranch %24
  194. %24 = OpLabel
  195. %23 = OpIAdd %7 %22 %19
  196. OpBranch %21
  197. %25 = OpLabel
  198. OpReturn
  199. OpFunctionEnd
  200. )";
  201. // clang-format on
  202. std::unique_ptr<IRContext> context =
  203. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  204. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  205. Module* module = context->module();
  206. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  207. << text << std::endl;
  208. const Function* f = spvtest::GetFunction(module, 2);
  209. ScalarEvolutionAnalysis analysis{context.get()};
  210. const Instruction* load = nullptr;
  211. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 28)) {
  212. if (inst.opcode() == spv::Op::OpLoad) {
  213. load = &inst;
  214. }
  215. }
  216. EXPECT_NE(load, nullptr);
  217. Instruction* access_chain =
  218. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  219. Instruction* child = context->get_def_use_mgr()->GetDef(
  220. access_chain->GetSingleWordInOperand(1));
  221. // const SENode* node =
  222. // analysis.GetNodeFromInstruction(child->unique_id());
  223. const SENode* node = analysis.AnalyzeInstruction(child);
  224. EXPECT_NE(node, nullptr);
  225. // Unsimplified node should have the form of ADD(REC(0,1), X)
  226. EXPECT_EQ(node->GetType(), SENode::Add);
  227. const SENode* child_1 = node->GetChild(0);
  228. EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown ||
  229. child_1->GetType() == SENode::RecurrentAddExpr);
  230. const SENode* child_2 = node->GetChild(1);
  231. EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown ||
  232. child_2->GetType() == SENode::RecurrentAddExpr);
  233. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  234. EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);
  235. const SERecurrentNode* rec = simplified->AsSERecurrentNode();
  236. EXPECT_NE(rec->GetChild(0), rec->GetChild(1));
  237. EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown);
  238. EXPECT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant);
  239. EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1u);
  240. }
  241. /*
  242. Generated from the following GLSL + --eliminate-local-multi-store
  243. #version 410 core
  244. layout (location = 1) out float array[10];
  245. layout (location = 2) flat in int loop_invariant;
  246. void main() {
  247. array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant -
  248. loop_invariant+ 16 * 3];
  249. }
  250. */
  251. TEST_F(ScalarAnalysisTest, SimplifySimple) {
  252. const std::string text = R"(
  253. OpCapability Shader
  254. %1 = OpExtInstImport "GLSL.std.450"
  255. OpMemoryModel Logical GLSL450
  256. OpEntryPoint Fragment %2 "main" %3 %4
  257. OpExecutionMode %2 OriginUpperLeft
  258. OpSource GLSL 430
  259. OpName %2 "main"
  260. OpName %3 "array"
  261. OpName %4 "loop_invariant"
  262. OpDecorate %3 Location 1
  263. OpDecorate %4 Flat
  264. OpDecorate %4 Location 2
  265. %5 = OpTypeVoid
  266. %6 = OpTypeFunction %5
  267. %7 = OpTypeFloat 32
  268. %8 = OpTypeInt 32 0
  269. %9 = OpConstant %8 10
  270. %10 = OpTypeArray %7 %9
  271. %11 = OpTypePointer Output %10
  272. %3 = OpVariable %11 Output
  273. %12 = OpTypeInt 32 1
  274. %13 = OpConstant %12 0
  275. %14 = OpTypePointer Input %12
  276. %4 = OpVariable %14 Input
  277. %15 = OpConstant %12 2
  278. %16 = OpConstant %12 4
  279. %17 = OpConstant %12 5
  280. %18 = OpConstant %12 24
  281. %19 = OpConstant %12 48
  282. %20 = OpTypePointer Output %7
  283. %2 = OpFunction %5 None %6
  284. %21 = OpLabel
  285. %22 = OpLoad %12 %4
  286. %23 = OpIMul %12 %22 %15
  287. %24 = OpIAdd %12 %23 %16
  288. %25 = OpIAdd %12 %24 %17
  289. %26 = OpISub %12 %25 %18
  290. %28 = OpISub %12 %26 %22
  291. %30 = OpISub %12 %28 %22
  292. %31 = OpIAdd %12 %30 %19
  293. %32 = OpAccessChain %20 %3 %31
  294. %33 = OpLoad %7 %32
  295. %34 = OpAccessChain %20 %3 %13
  296. OpStore %34 %33
  297. OpReturn
  298. OpFunctionEnd
  299. )";
  300. // clang-format on
  301. std::unique_ptr<IRContext> context =
  302. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  303. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  304. Module* module = context->module();
  305. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  306. << text << std::endl;
  307. const Function* f = spvtest::GetFunction(module, 2);
  308. ScalarEvolutionAnalysis analysis{context.get()};
  309. const Instruction* load = nullptr;
  310. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
  311. if (inst.opcode() == spv::Op::OpLoad && inst.result_id() == 33) {
  312. load = &inst;
  313. }
  314. }
  315. EXPECT_NE(load, nullptr);
  316. Instruction* access_chain =
  317. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  318. Instruction* child = context->get_def_use_mgr()->GetDef(
  319. access_chain->GetSingleWordInOperand(1));
  320. const SENode* node = analysis.AnalyzeInstruction(child);
  321. // Unsimplified is a very large graph with an add at the top.
  322. EXPECT_NE(node, nullptr);
  323. EXPECT_EQ(node->GetType(), SENode::Add);
  324. // Simplified node should resolve down to a constant expression as the loads
  325. // will eliminate themselves.
  326. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  327. EXPECT_EQ(simplified->GetType(), SENode::Constant);
  328. EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33u);
  329. }
  330. /*
  331. Generated from the following GLSL + --eliminate-local-multi-store
  332. #version 410 core
  333. layout(location = 0) in vec4 c;
  334. layout (location = 1) out float array[10];
  335. void main() {
  336. int N = int(c.x);
  337. for (int i = 0; i < 10; ++i) {
  338. array[i] = array[i];
  339. array[i] = array[i-1];
  340. array[i] = array[i+1];
  341. array[i+1] = array[i+1];
  342. array[i+N] = array[i+N];
  343. array[i] = array[i+N];
  344. }
  345. }
  346. */
  347. TEST_F(ScalarAnalysisTest, Simplify) {
  348. const std::string text = R"( OpCapability Shader
  349. %1 = OpExtInstImport "GLSL.std.450"
  350. OpMemoryModel Logical GLSL450
  351. OpEntryPoint Fragment %4 "main" %12 %33
  352. OpExecutionMode %4 OriginUpperLeft
  353. OpSource GLSL 410
  354. OpName %4 "main"
  355. OpName %8 "N"
  356. OpName %12 "c"
  357. OpName %19 "i"
  358. OpName %33 "array"
  359. OpDecorate %12 Location 0
  360. OpDecorate %33 Location 1
  361. %2 = OpTypeVoid
  362. %3 = OpTypeFunction %2
  363. %6 = OpTypeInt 32 1
  364. %7 = OpTypePointer Function %6
  365. %9 = OpTypeFloat 32
  366. %10 = OpTypeVector %9 4
  367. %11 = OpTypePointer Input %10
  368. %12 = OpVariable %11 Input
  369. %13 = OpTypeInt 32 0
  370. %14 = OpConstant %13 0
  371. %15 = OpTypePointer Input %9
  372. %20 = OpConstant %6 0
  373. %27 = OpConstant %6 10
  374. %28 = OpTypeBool
  375. %30 = OpConstant %13 10
  376. %31 = OpTypeArray %9 %30
  377. %32 = OpTypePointer Output %31
  378. %33 = OpVariable %32 Output
  379. %36 = OpTypePointer Output %9
  380. %42 = OpConstant %6 1
  381. %4 = OpFunction %2 None %3
  382. %5 = OpLabel
  383. %8 = OpVariable %7 Function
  384. %19 = OpVariable %7 Function
  385. %16 = OpAccessChain %15 %12 %14
  386. %17 = OpLoad %9 %16
  387. %18 = OpConvertFToS %6 %17
  388. OpStore %8 %18
  389. OpStore %19 %20
  390. OpBranch %21
  391. %21 = OpLabel
  392. %78 = OpPhi %6 %20 %5 %77 %24
  393. OpLoopMerge %23 %24 None
  394. OpBranch %25
  395. %25 = OpLabel
  396. %29 = OpSLessThan %28 %78 %27
  397. OpBranchConditional %29 %22 %23
  398. %22 = OpLabel
  399. %37 = OpAccessChain %36 %33 %78
  400. %38 = OpLoad %9 %37
  401. %39 = OpAccessChain %36 %33 %78
  402. OpStore %39 %38
  403. %43 = OpISub %6 %78 %42
  404. %44 = OpAccessChain %36 %33 %43
  405. %45 = OpLoad %9 %44
  406. %46 = OpAccessChain %36 %33 %78
  407. OpStore %46 %45
  408. %49 = OpIAdd %6 %78 %42
  409. %50 = OpAccessChain %36 %33 %49
  410. %51 = OpLoad %9 %50
  411. %52 = OpAccessChain %36 %33 %78
  412. OpStore %52 %51
  413. %54 = OpIAdd %6 %78 %42
  414. %56 = OpIAdd %6 %78 %42
  415. %57 = OpAccessChain %36 %33 %56
  416. %58 = OpLoad %9 %57
  417. %59 = OpAccessChain %36 %33 %54
  418. OpStore %59 %58
  419. %62 = OpIAdd %6 %78 %18
  420. %65 = OpIAdd %6 %78 %18
  421. %66 = OpAccessChain %36 %33 %65
  422. %67 = OpLoad %9 %66
  423. %68 = OpAccessChain %36 %33 %62
  424. OpStore %68 %67
  425. %72 = OpIAdd %6 %78 %18
  426. %73 = OpAccessChain %36 %33 %72
  427. %74 = OpLoad %9 %73
  428. %75 = OpAccessChain %36 %33 %78
  429. OpStore %75 %74
  430. OpBranch %24
  431. %24 = OpLabel
  432. %77 = OpIAdd %6 %78 %42
  433. OpStore %19 %77
  434. OpBranch %21
  435. %23 = OpLabel
  436. OpReturn
  437. OpFunctionEnd
  438. )";
  439. // clang-format on
  440. std::unique_ptr<IRContext> context =
  441. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  442. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  443. Module* module = context->module();
  444. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  445. << text << std::endl;
  446. const Function* f = spvtest::GetFunction(module, 4);
  447. ScalarEvolutionAnalysis analysis{context.get()};
  448. const Instruction* loads[6];
  449. const Instruction* stores[6];
  450. int load_count = 0;
  451. int store_count = 0;
  452. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) {
  453. if (inst.opcode() == spv::Op::OpLoad) {
  454. loads[load_count] = &inst;
  455. ++load_count;
  456. }
  457. if (inst.opcode() == spv::Op::OpStore) {
  458. stores[store_count] = &inst;
  459. ++store_count;
  460. }
  461. }
  462. EXPECT_EQ(load_count, 6);
  463. EXPECT_EQ(store_count, 6);
  464. Instruction* load_access_chain;
  465. Instruction* store_access_chain;
  466. Instruction* load_child;
  467. Instruction* store_child;
  468. SENode* load_node;
  469. SENode* store_node;
  470. SENode* subtract_node;
  471. SENode* simplified_node;
  472. // Testing [i] - [i] == 0
  473. load_access_chain =
  474. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  475. store_access_chain =
  476. context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
  477. load_child = context->get_def_use_mgr()->GetDef(
  478. load_access_chain->GetSingleWordInOperand(1));
  479. store_child = context->get_def_use_mgr()->GetDef(
  480. store_access_chain->GetSingleWordInOperand(1));
  481. load_node = analysis.AnalyzeInstruction(load_child);
  482. store_node = analysis.AnalyzeInstruction(store_child);
  483. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  484. simplified_node = analysis.SimplifyExpression(subtract_node);
  485. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  486. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  487. // Testing [i] - [i-1] == 1
  488. load_access_chain =
  489. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  490. store_access_chain =
  491. context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
  492. load_child = context->get_def_use_mgr()->GetDef(
  493. load_access_chain->GetSingleWordInOperand(1));
  494. store_child = context->get_def_use_mgr()->GetDef(
  495. store_access_chain->GetSingleWordInOperand(1));
  496. load_node = analysis.AnalyzeInstruction(load_child);
  497. store_node = analysis.AnalyzeInstruction(store_child);
  498. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  499. simplified_node = analysis.SimplifyExpression(subtract_node);
  500. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  501. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u);
  502. // Testing [i] - [i+1] == -1
  503. load_access_chain =
  504. context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
  505. store_access_chain =
  506. context->get_def_use_mgr()->GetDef(stores[2]->GetSingleWordInOperand(0));
  507. load_child = context->get_def_use_mgr()->GetDef(
  508. load_access_chain->GetSingleWordInOperand(1));
  509. store_child = context->get_def_use_mgr()->GetDef(
  510. store_access_chain->GetSingleWordInOperand(1));
  511. load_node = analysis.AnalyzeInstruction(load_child);
  512. store_node = analysis.AnalyzeInstruction(store_child);
  513. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  514. simplified_node = analysis.SimplifyExpression(subtract_node);
  515. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  516. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1);
  517. // Testing [i+1] - [i+1] == 0
  518. load_access_chain =
  519. context->get_def_use_mgr()->GetDef(loads[3]->GetSingleWordInOperand(0));
  520. store_access_chain =
  521. context->get_def_use_mgr()->GetDef(stores[3]->GetSingleWordInOperand(0));
  522. load_child = context->get_def_use_mgr()->GetDef(
  523. load_access_chain->GetSingleWordInOperand(1));
  524. store_child = context->get_def_use_mgr()->GetDef(
  525. store_access_chain->GetSingleWordInOperand(1));
  526. load_node = analysis.AnalyzeInstruction(load_child);
  527. store_node = analysis.AnalyzeInstruction(store_child);
  528. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  529. simplified_node = analysis.SimplifyExpression(subtract_node);
  530. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  531. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  532. // Testing [i+N] - [i+N] == 0
  533. load_access_chain =
  534. context->get_def_use_mgr()->GetDef(loads[4]->GetSingleWordInOperand(0));
  535. store_access_chain =
  536. context->get_def_use_mgr()->GetDef(stores[4]->GetSingleWordInOperand(0));
  537. load_child = context->get_def_use_mgr()->GetDef(
  538. load_access_chain->GetSingleWordInOperand(1));
  539. store_child = context->get_def_use_mgr()->GetDef(
  540. store_access_chain->GetSingleWordInOperand(1));
  541. load_node = analysis.AnalyzeInstruction(load_child);
  542. store_node = analysis.AnalyzeInstruction(store_child);
  543. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  544. simplified_node = analysis.SimplifyExpression(subtract_node);
  545. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  546. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  547. // Testing [i] - [i+N] == -N
  548. load_access_chain =
  549. context->get_def_use_mgr()->GetDef(loads[5]->GetSingleWordInOperand(0));
  550. store_access_chain =
  551. context->get_def_use_mgr()->GetDef(stores[5]->GetSingleWordInOperand(0));
  552. load_child = context->get_def_use_mgr()->GetDef(
  553. load_access_chain->GetSingleWordInOperand(1));
  554. store_child = context->get_def_use_mgr()->GetDef(
  555. store_access_chain->GetSingleWordInOperand(1));
  556. load_node = analysis.AnalyzeInstruction(load_child);
  557. store_node = analysis.AnalyzeInstruction(store_child);
  558. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  559. simplified_node = analysis.SimplifyExpression(subtract_node);
  560. EXPECT_EQ(simplified_node->GetType(), SENode::Negative);
  561. }
  562. /*
  563. Generated from the following GLSL + --eliminate-local-multi-store
  564. #version 430
  565. layout(location = 1) out float array[10];
  566. layout(location = 2) flat in int loop_invariant;
  567. void main(void) {
  568. for (int i = 0; i < 10; ++i) {
  569. array[i * 2 + i * 5] = array[i * i * 2];
  570. array[i * 2] = array[i * 5];
  571. }
  572. }
  573. */
  574. TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) {
  575. const std::string text = R"(
  576. OpCapability Shader
  577. %1 = OpExtInstImport "GLSL.std.450"
  578. OpMemoryModel Logical GLSL450
  579. OpEntryPoint Fragment %2 "main" %3 %4
  580. OpExecutionMode %2 OriginUpperLeft
  581. OpSource GLSL 430
  582. OpName %2 "main"
  583. OpName %5 "i"
  584. OpName %3 "array"
  585. OpName %4 "loop_invariant"
  586. OpDecorate %3 Location 1
  587. OpDecorate %4 Flat
  588. OpDecorate %4 Location 2
  589. %6 = OpTypeVoid
  590. %7 = OpTypeFunction %6
  591. %8 = OpTypeInt 32 1
  592. %9 = OpTypePointer Function %8
  593. %10 = OpConstant %8 0
  594. %11 = OpConstant %8 10
  595. %12 = OpTypeBool
  596. %13 = OpTypeFloat 32
  597. %14 = OpTypeInt 32 0
  598. %15 = OpConstant %14 10
  599. %16 = OpTypeArray %13 %15
  600. %17 = OpTypePointer Output %16
  601. %3 = OpVariable %17 Output
  602. %18 = OpConstant %8 2
  603. %19 = OpConstant %8 5
  604. %20 = OpTypePointer Output %13
  605. %21 = OpConstant %8 1
  606. %22 = OpTypePointer Input %8
  607. %4 = OpVariable %22 Input
  608. %2 = OpFunction %6 None %7
  609. %23 = OpLabel
  610. %5 = OpVariable %9 Function
  611. OpStore %5 %10
  612. OpBranch %24
  613. %24 = OpLabel
  614. %25 = OpPhi %8 %10 %23 %26 %27
  615. OpLoopMerge %28 %27 None
  616. OpBranch %29
  617. %29 = OpLabel
  618. %30 = OpSLessThan %12 %25 %11
  619. OpBranchConditional %30 %31 %28
  620. %31 = OpLabel
  621. %32 = OpIMul %8 %25 %18
  622. %33 = OpIMul %8 %25 %19
  623. %34 = OpIAdd %8 %32 %33
  624. %35 = OpIMul %8 %25 %25
  625. %36 = OpIMul %8 %35 %18
  626. %37 = OpAccessChain %20 %3 %36
  627. %38 = OpLoad %13 %37
  628. %39 = OpAccessChain %20 %3 %34
  629. OpStore %39 %38
  630. %40 = OpIMul %8 %25 %18
  631. %41 = OpIMul %8 %25 %19
  632. %42 = OpAccessChain %20 %3 %41
  633. %43 = OpLoad %13 %42
  634. %44 = OpAccessChain %20 %3 %40
  635. OpStore %44 %43
  636. OpBranch %27
  637. %27 = OpLabel
  638. %26 = OpIAdd %8 %25 %21
  639. OpStore %5 %26
  640. OpBranch %24
  641. %28 = OpLabel
  642. OpReturn
  643. OpFunctionEnd
  644. )";
  645. std::unique_ptr<IRContext> context =
  646. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  647. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  648. Module* module = context->module();
  649. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  650. << text << std::endl;
  651. const Function* f = spvtest::GetFunction(module, 2);
  652. ScalarEvolutionAnalysis analysis{context.get()};
  653. const Instruction* loads[2] = {nullptr, nullptr};
  654. const Instruction* stores[2] = {nullptr, nullptr};
  655. int load_count = 0;
  656. int store_count = 0;
  657. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 31)) {
  658. if (inst.opcode() == spv::Op::OpLoad) {
  659. loads[load_count] = &inst;
  660. ++load_count;
  661. }
  662. if (inst.opcode() == spv::Op::OpStore) {
  663. stores[store_count] = &inst;
  664. ++store_count;
  665. }
  666. }
  667. EXPECT_EQ(load_count, 2);
  668. EXPECT_EQ(store_count, 2);
  669. Instruction* load_access_chain =
  670. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  671. Instruction* store_access_chain =
  672. context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
  673. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  674. load_access_chain->GetSingleWordInOperand(1));
  675. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  676. store_access_chain->GetSingleWordInOperand(1));
  677. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  678. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  679. load_access_chain =
  680. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  681. store_access_chain =
  682. context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
  683. load_child = context->get_def_use_mgr()->GetDef(
  684. load_access_chain->GetSingleWordInOperand(1));
  685. store_child = context->get_def_use_mgr()->GetDef(
  686. store_access_chain->GetSingleWordInOperand(1));
  687. SENode* second_store =
  688. analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child));
  689. SENode* second_load =
  690. analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child));
  691. SENode* combined_add = analysis.SimplifyExpression(
  692. analysis.CreateAddNode(second_load, second_store));
  693. // We're checking that the two recurrent expression have been correctly
  694. // folded. In store_simplified they will have been folded as the entire
  695. // expression was simplified as one. In combined_add the two expressions have
  696. // been simplified one after the other which means the recurrent expressions
  697. // aren't exactly the same but should still be folded as they are with respect
  698. // to the same loop.
  699. EXPECT_EQ(combined_add, store_simplified);
  700. }
/*
Generated from the following GLSL + --eliminate-local-multi-store
#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int loop_invariant;
void main(void) {
  for (int i = 0; i < 10; --i) {
    array[i] = array[i];
  }
}
*/
  710. TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) {
  711. const std::string text = R"(
  712. OpCapability Shader
  713. %1 = OpExtInstImport "GLSL.std.450"
  714. OpMemoryModel Logical GLSL450
  715. OpEntryPoint Fragment %2 "main" %3 %4
  716. OpExecutionMode %2 OriginUpperLeft
  717. OpSource GLSL 430
  718. OpName %2 "main"
  719. OpName %5 "i"
  720. OpName %3 "array"
  721. OpName %4 "loop_invariant"
  722. OpDecorate %3 Location 1
  723. OpDecorate %4 Flat
  724. OpDecorate %4 Location 2
  725. %6 = OpTypeVoid
  726. %7 = OpTypeFunction %6
  727. %8 = OpTypeInt 32 1
  728. %9 = OpTypePointer Function %8
  729. %10 = OpConstant %8 0
  730. %11 = OpConstant %8 10
  731. %12 = OpTypeBool
  732. %13 = OpTypeFloat 32
  733. %14 = OpTypeInt 32 0
  734. %15 = OpConstant %14 10
  735. %16 = OpTypeArray %13 %15
  736. %17 = OpTypePointer Output %16
  737. %3 = OpVariable %17 Output
  738. %18 = OpTypePointer Output %13
  739. %19 = OpConstant %8 1
  740. %20 = OpTypePointer Input %8
  741. %4 = OpVariable %20 Input
  742. %2 = OpFunction %6 None %7
  743. %21 = OpLabel
  744. %5 = OpVariable %9 Function
  745. OpStore %5 %10
  746. OpBranch %22
  747. %22 = OpLabel
  748. %23 = OpPhi %8 %10 %21 %24 %25
  749. OpLoopMerge %26 %25 None
  750. OpBranch %27
  751. %27 = OpLabel
  752. %28 = OpSLessThan %12 %23 %11
  753. OpBranchConditional %28 %29 %26
  754. %29 = OpLabel
  755. %30 = OpAccessChain %18 %3 %23
  756. %31 = OpLoad %13 %30
  757. %32 = OpAccessChain %18 %3 %23
  758. OpStore %32 %31
  759. OpBranch %25
  760. %25 = OpLabel
  761. %24 = OpISub %8 %23 %19
  762. OpStore %5 %24
  763. OpBranch %22
  764. %26 = OpLabel
  765. OpReturn
  766. OpFunctionEnd
  767. )";
  768. std::unique_ptr<IRContext> context =
  769. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  770. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  771. Module* module = context->module();
  772. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  773. << text << std::endl;
  774. const Function* f = spvtest::GetFunction(module, 2);
  775. ScalarEvolutionAnalysis analysis{context.get()};
  776. const Instruction* loads[1] = {nullptr};
  777. int load_count = 0;
  778. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) {
  779. if (inst.opcode() == spv::Op::OpLoad) {
  780. loads[load_count] = &inst;
  781. ++load_count;
  782. }
  783. }
  784. EXPECT_EQ(load_count, 1);
  785. Instruction* load_access_chain =
  786. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  787. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  788. load_access_chain->GetSingleWordInOperand(1));
  789. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  790. EXPECT_TRUE(load_node);
  791. EXPECT_EQ(load_node->GetType(), SENode::RecurrentAddExpr);
  792. EXPECT_TRUE(load_node->AsSERecurrentNode());
  793. SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient();
  794. SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset();
  795. EXPECT_EQ(child_1->GetType(), SENode::Constant);
  796. EXPECT_EQ(child_2->GetType(), SENode::Constant);
  797. EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1);
  798. EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u);
  799. SERecurrentNode* load_simplified =
  800. analysis.SimplifyExpression(load_node)->AsSERecurrentNode();
  801. EXPECT_TRUE(load_simplified);
  802. EXPECT_EQ(load_node, load_simplified);
  803. EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr);
  804. EXPECT_TRUE(load_simplified->AsSERecurrentNode());
  805. SENode* simplified_child_1 =
  806. load_simplified->AsSERecurrentNode()->GetCoefficient();
  807. SENode* simplified_child_2 =
  808. load_simplified->AsSERecurrentNode()->GetOffset();
  809. EXPECT_EQ(child_1, simplified_child_1);
  810. EXPECT_EQ(child_2, simplified_child_2);
  811. }
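  /*
  Generated from the following GLSL + --eliminate-local-multi-store
  (GLSL reconstructed from the SPIR-V below)
  #version 430
  layout(location = 1) out float array[10];
  layout(location = 2) flat in int N;
  void main(void) {
    for (int i = 0; i < 10; --i) {
      array[i + 2 * N] = array[i + N];
      array[2 * i + 2 * N + 1] = array[2 * i + N + 1];
    }
  }
  */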
  821. TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) {
  822. const std::string text = R"(
  823. OpCapability Shader
  824. %1 = OpExtInstImport "GLSL.std.450"
  825. OpMemoryModel Logical GLSL450
  826. OpEntryPoint Fragment %2 "main" %3 %4
  827. OpExecutionMode %2 OriginUpperLeft
  828. OpSource GLSL 430
  829. OpName %2 "main"
  830. OpName %5 "i"
  831. OpName %3 "array"
  832. OpName %4 "N"
  833. OpDecorate %3 Location 1
  834. OpDecorate %4 Flat
  835. OpDecorate %4 Location 2
  836. %6 = OpTypeVoid
  837. %7 = OpTypeFunction %6
  838. %8 = OpTypeInt 32 1
  839. %9 = OpTypePointer Function %8
  840. %10 = OpConstant %8 0
  841. %11 = OpConstant %8 10
  842. %12 = OpTypeBool
  843. %13 = OpTypeFloat 32
  844. %14 = OpTypeInt 32 0
  845. %15 = OpConstant %14 10
  846. %16 = OpTypeArray %13 %15
  847. %17 = OpTypePointer Output %16
  848. %3 = OpVariable %17 Output
  849. %18 = OpConstant %8 2
  850. %19 = OpTypePointer Input %8
  851. %4 = OpVariable %19 Input
  852. %20 = OpTypePointer Output %13
  853. %21 = OpConstant %8 1
  854. %2 = OpFunction %6 None %7
  855. %22 = OpLabel
  856. %5 = OpVariable %9 Function
  857. OpStore %5 %10
  858. OpBranch %23
  859. %23 = OpLabel
  860. %24 = OpPhi %8 %10 %22 %25 %26
  861. OpLoopMerge %27 %26 None
  862. OpBranch %28
  863. %28 = OpLabel
  864. %29 = OpSLessThan %12 %24 %11
  865. OpBranchConditional %29 %30 %27
  866. %30 = OpLabel
  867. %31 = OpLoad %8 %4
  868. %32 = OpIMul %8 %18 %31
  869. %33 = OpIAdd %8 %24 %32
  870. %35 = OpIAdd %8 %24 %31
  871. %36 = OpAccessChain %20 %3 %35
  872. %37 = OpLoad %13 %36
  873. %38 = OpAccessChain %20 %3 %33
  874. OpStore %38 %37
  875. %39 = OpIMul %8 %18 %24
  876. %41 = OpIMul %8 %18 %31
  877. %42 = OpIAdd %8 %39 %41
  878. %43 = OpIAdd %8 %42 %21
  879. %44 = OpIMul %8 %18 %24
  880. %46 = OpIAdd %8 %44 %31
  881. %47 = OpIAdd %8 %46 %21
  882. %48 = OpAccessChain %20 %3 %47
  883. %49 = OpLoad %13 %48
  884. %50 = OpAccessChain %20 %3 %43
  885. OpStore %50 %49
  886. OpBranch %26
  887. %26 = OpLabel
  888. %25 = OpISub %8 %24 %21
  889. OpStore %5 %25
  890. OpBranch %23
  891. %27 = OpLabel
  892. OpReturn
  893. OpFunctionEnd
  894. )";
  895. std::unique_ptr<IRContext> context =
  896. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  897. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  898. Module* module = context->module();
  899. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  900. << text << std::endl;
  901. const Function* f = spvtest::GetFunction(module, 2);
  902. ScalarEvolutionAnalysis analysis{context.get()};
  903. std::vector<const Instruction*> loads{};
  904. std::vector<const Instruction*> stores{};
  905. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
  906. if (inst.opcode() == spv::Op::OpLoad) {
  907. loads.push_back(&inst);
  908. }
  909. if (inst.opcode() == spv::Op::OpStore) {
  910. stores.push_back(&inst);
  911. }
  912. }
  913. EXPECT_EQ(loads.size(), 3u);
  914. EXPECT_EQ(stores.size(), 2u);
  915. {
  916. Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
  917. stores[0]->GetSingleWordInOperand(0));
  918. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  919. store_access_chain->GetSingleWordInOperand(1));
  920. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  921. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  922. Instruction* load_access_chain =
  923. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  924. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  925. load_access_chain->GetSingleWordInOperand(1));
  926. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  927. SENode* load_simplified = analysis.SimplifyExpression(load_node);
  928. SENode* difference =
  929. analysis.CreateSubtraction(store_simplified, load_simplified);
  930. SENode* difference_simplified = analysis.SimplifyExpression(difference);
  931. // Check that i+2*N - i*N, turns into just N when both sides have already
  932. // been simplified into a single recurrent expression.
  933. EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
  934. // Check that the inverse, i*N - i+2*N turns into -N.
  935. SENode* difference_inverse = analysis.SimplifyExpression(
  936. analysis.CreateSubtraction(load_simplified, store_simplified));
  937. EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
  938. EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
  939. EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  940. }
  941. {
  942. Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
  943. stores[1]->GetSingleWordInOperand(0));
  944. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  945. store_access_chain->GetSingleWordInOperand(1));
  946. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  947. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  948. Instruction* load_access_chain =
  949. context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
  950. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  951. load_access_chain->GetSingleWordInOperand(1));
  952. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  953. SENode* load_simplified = analysis.SimplifyExpression(load_node);
  954. SENode* difference =
  955. analysis.CreateSubtraction(store_simplified, load_simplified);
  956. SENode* difference_simplified = analysis.SimplifyExpression(difference);
  957. // Check that 2*i + 2*N + 1 - 2*i + N + 1, turns into just N when both
  958. // sides have already been simplified into a single recurrent expression.
  959. EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
  960. // Check that the inverse, (2*i + N + 1) - (2*i + 2*N + 1) turns into -N.
  961. SENode* difference_inverse = analysis.SimplifyExpression(
  962. analysis.CreateSubtraction(load_simplified, store_simplified));
  963. EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
  964. EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
  965. EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  966. }
  967. }
  968. /* Generated from the following GLSL + --eliminate-local-multi-store
  969. #version 430
  970. layout(location = 1) out float array[10];
  971. layout(location = 2) flat in int N;
  972. void main(void) {
  973. int step = 0;
  974. for (int i = 0; i < N; i += step) {
  975. step++;
  976. }
  977. }
  978. */
  979. TEST_F(ScalarAnalysisTest, InductionWithVariantStep) {
  980. const std::string text = R"(
  981. OpCapability Shader
  982. %1 = OpExtInstImport "GLSL.std.450"
  983. OpMemoryModel Logical GLSL450
  984. OpEntryPoint Fragment %2 "main" %3 %4
  985. OpExecutionMode %2 OriginUpperLeft
  986. OpSource GLSL 430
  987. OpName %2 "main"
  988. OpName %5 "step"
  989. OpName %6 "i"
  990. OpName %3 "N"
  991. OpName %4 "array"
  992. OpDecorate %3 Flat
  993. OpDecorate %3 Location 2
  994. OpDecorate %4 Location 1
  995. %7 = OpTypeVoid
  996. %8 = OpTypeFunction %7
  997. %9 = OpTypeInt 32 1
  998. %10 = OpTypePointer Function %9
  999. %11 = OpConstant %9 0
  1000. %12 = OpTypePointer Input %9
  1001. %3 = OpVariable %12 Input
  1002. %13 = OpTypeBool
  1003. %14 = OpConstant %9 1
  1004. %15 = OpTypeFloat 32
  1005. %16 = OpTypeInt 32 0
  1006. %17 = OpConstant %16 10
  1007. %18 = OpTypeArray %15 %17
  1008. %19 = OpTypePointer Output %18
  1009. %4 = OpVariable %19 Output
  1010. %2 = OpFunction %7 None %8
  1011. %20 = OpLabel
  1012. %5 = OpVariable %10 Function
  1013. %6 = OpVariable %10 Function
  1014. OpStore %5 %11
  1015. OpStore %6 %11
  1016. OpBranch %21
  1017. %21 = OpLabel
  1018. %22 = OpPhi %9 %11 %20 %23 %24
  1019. %25 = OpPhi %9 %11 %20 %26 %24
  1020. OpLoopMerge %27 %24 None
  1021. OpBranch %28
  1022. %28 = OpLabel
  1023. %29 = OpLoad %9 %3
  1024. %30 = OpSLessThan %13 %25 %29
  1025. OpBranchConditional %30 %31 %27
  1026. %31 = OpLabel
  1027. %23 = OpIAdd %9 %22 %14
  1028. OpStore %5 %23
  1029. OpBranch %24
  1030. %24 = OpLabel
  1031. %26 = OpIAdd %9 %25 %23
  1032. OpStore %6 %26
  1033. OpBranch %21
  1034. %27 = OpLabel
  1035. OpReturn
  1036. OpFunctionEnd
  1037. )";
  1038. std::unique_ptr<IRContext> context =
  1039. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  1040. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  1041. Module* module = context->module();
  1042. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  1043. << text << std::endl;
  1044. const Function* f = spvtest::GetFunction(module, 2);
  1045. ScalarEvolutionAnalysis analysis{context.get()};
  1046. std::vector<const Instruction*> phis{};
  1047. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
  1048. if (inst.opcode() == spv::Op::OpPhi) {
  1049. phis.push_back(&inst);
  1050. }
  1051. }
  1052. EXPECT_EQ(phis.size(), 2u);
  1053. SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]);
  1054. SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]);
  1055. EXPECT_NE(phi_node_1, nullptr);
  1056. EXPECT_NE(phi_node_2, nullptr);
  1057. EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr);
  1058. EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute);
  1059. SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1);
  1060. SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2);
  1061. EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr);
  1062. EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute);
  1063. }
  1064. } // namespace
  1065. } // namespace opt
  1066. } // namespace spvtools