copy_prop_arrays.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/copy_prop_arrays.h"
  15. #include <utility>
  16. #include "source/opt/ir_builder.h"
  17. namespace spvtools {
  18. namespace opt {
  namespace {

  // In-operand positions used when decoding instructions in this pass.

  // OpLoad: in-operand 0 is the pointer being read.
  constexpr uint32_t kLoadPointerInOperand = 0;

  // OpStore: in-operand 0 is the destination pointer, 1 is the stored object.
  constexpr uint32_t kStorePointerInOperand = 0;
  constexpr uint32_t kStoreObjectInOperand = 1;

  // OpCompositeExtract: in-operand 0 is the composite being read from.
  constexpr uint32_t kCompositeExtractObjectInOperand = 0;

  // OpTypePointer: in-operand 0 is the storage class, 1 is the pointee type.
  constexpr uint32_t kTypePointerStorageClassInIdx = 0;
  constexpr uint32_t kTypePointerPointeeInIdx = 1;

  }  // namespace
  27. Pass::Status CopyPropagateArrays::Process() {
  28. bool modified = false;
  29. for (Function& function : *get_module()) {
  30. BasicBlock* entry_bb = &*function.begin();
  31. for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
  32. ++var_inst) {
  33. if (!IsPointerToArrayType(var_inst->type_id())) {
  34. continue;
  35. }
  36. // Find the only store to the entire memory location, if it exists.
  37. Instruction* store_inst = FindStoreInstruction(&*var_inst);
  38. if (!store_inst) {
  39. continue;
  40. }
  41. std::unique_ptr<MemoryObject> source_object =
  42. FindSourceObjectIfPossible(&*var_inst, store_inst);
  43. if (source_object != nullptr) {
  44. if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
  45. modified = true;
  46. PropagateObject(&*var_inst, source_object.get(), store_inst);
  47. }
  48. }
  49. }
  50. }
  51. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  52. }
  53. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  54. CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
  55. Instruction* store_inst) {
  56. assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
  57. // Check that the variable is a composite object where |store_inst|
  58. // dominates all of its loads.
  59. if (!store_inst) {
  60. return nullptr;
  61. }
  62. // Look at the loads to ensure they are dominated by the store.
  63. if (!HasValidReferencesOnly(var_inst, store_inst)) {
  64. return nullptr;
  65. }
  66. // If so, look at the store to see if it is the copy of an object.
  67. std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
  68. store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
  69. if (!source) {
  70. return nullptr;
  71. }
  72. // Ensure that |source| does not change between the point at which it is
  73. // loaded, and the position in which |var_inst| is loaded.
  74. //
  75. // For now we will go with the easy to implement approach, and check that the
  76. // entire variable (not just the specific component) is never written to.
  77. if (!HasNoStores(source->GetVariable())) {
  78. return nullptr;
  79. }
  80. return source;
  81. }
  82. Instruction* CopyPropagateArrays::FindStoreInstruction(
  83. const Instruction* var_inst) const {
  84. Instruction* store_inst = nullptr;
  85. get_def_use_mgr()->WhileEachUser(
  86. var_inst, [&store_inst, var_inst](Instruction* use) {
  87. if (use->opcode() == SpvOpStore &&
  88. use->GetSingleWordInOperand(kStorePointerInOperand) ==
  89. var_inst->result_id()) {
  90. if (store_inst == nullptr) {
  91. store_inst = use;
  92. } else {
  93. store_inst = nullptr;
  94. return false;
  95. }
  96. }
  97. return true;
  98. });
  99. return store_inst;
  100. }
  101. void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
  102. MemoryObject* source,
  103. Instruction* insertion_point) {
  104. assert(var_inst->opcode() == SpvOpVariable &&
  105. "This function propagates variables.");
  106. Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
  107. context()->KillNamesAndDecorates(var_inst);
  108. UpdateUses(var_inst, new_access_chain);
  109. }
  110. Instruction* CopyPropagateArrays::BuildNewAccessChain(
  111. Instruction* insertion_point,
  112. CopyPropagateArrays::MemoryObject* source) const {
  113. InstructionBuilder builder(
  114. context(), insertion_point,
  115. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  116. if (source->AccessChain().size() == 0) {
  117. return source->GetVariable();
  118. }
  119. return builder.AddAccessChain(source->GetPointerTypeId(this),
  120. source->GetVariable()->result_id(),
  121. source->AccessChain());
  122. }
  123. bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
  124. return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
  125. if (use->opcode() == SpvOpLoad) {
  126. return true;
  127. } else if (use->opcode() == SpvOpAccessChain) {
  128. return HasNoStores(use);
  129. } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
  130. return true;
  131. } else if (use->opcode() == SpvOpStore) {
  132. return false;
  133. } else if (use->opcode() == SpvOpImageTexelPointer) {
  134. return true;
  135. }
  136. // Some other instruction. Be conservative.
  137. return false;
  138. });
  139. }
  140. bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
  141. Instruction* store_inst) {
  142. BasicBlock* store_block = context()->get_instr_block(store_inst);
  143. DominatorAnalysis* dominator_analysis =
  144. context()->GetDominatorAnalysis(store_block->GetParent());
  145. return get_def_use_mgr()->WhileEachUser(
  146. ptr_inst,
  147. [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
  148. if (use->opcode() == SpvOpLoad ||
  149. use->opcode() == SpvOpImageTexelPointer) {
  150. // TODO: If there are many load in the same BB as |store_inst| the
  151. // time to do the multiple traverses can add up. Consider collecting
  152. // those loads and doing a single traversal.
  153. return dominator_analysis->Dominates(store_inst, use);
  154. } else if (use->opcode() == SpvOpAccessChain) {
  155. return HasValidReferencesOnly(use, store_inst);
  156. } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
  157. return true;
  158. } else if (use->opcode() == SpvOpStore) {
  159. // If we are storing to part of the object it is not an candidate.
  160. return ptr_inst->opcode() == SpvOpVariable &&
  161. store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
  162. ptr_inst->result_id();
  163. }
  164. // Some other instruction. Be conservative.
  165. return false;
  166. });
  167. }
  168. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  169. CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
  170. Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
  171. switch (result_inst->opcode()) {
  172. case SpvOpLoad:
  173. return BuildMemoryObjectFromLoad(result_inst);
  174. case SpvOpCompositeExtract:
  175. return BuildMemoryObjectFromExtract(result_inst);
  176. case SpvOpCompositeConstruct:
  177. return BuildMemoryObjectFromCompositeConstruct(result_inst);
  178. case SpvOpCopyObject:
  179. return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
  180. case SpvOpCompositeInsert:
  181. return BuildMemoryObjectFromInsert(result_inst);
  182. default:
  183. return nullptr;
  184. }
  185. }
  186. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  187. CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
  188. std::vector<uint32_t> components_in_reverse;
  189. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  190. Instruction* current_inst = def_use_mgr->GetDef(
  191. load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
  192. // Build the access chain for the memory object by collecting the indices used
  193. // in the OpAccessChain instructions. If we find a variable index, then
  194. // return |nullptr| because we cannot know for sure which memory location is
  195. // used.
  196. //
  197. // It is built in reverse order because the different |OpAccessChain|
  198. // instructions are visited in reverse order from which they are applied.
  199. while (current_inst->opcode() == SpvOpAccessChain) {
  200. for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
  201. uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
  202. components_in_reverse.push_back(element_index_id);
  203. }
  204. current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
  205. }
  206. // If the address in the load is not constructed from an |OpVariable|
  207. // instruction followed by a series of |OpAccessChain| instructions, then
  208. // return |nullptr| because we cannot identify the owner or access chain
  209. // exactly.
  210. if (current_inst->opcode() != SpvOpVariable) {
  211. return nullptr;
  212. }
  213. // Build the memory object. Use |rbegin| and |rend| to put the access chain
  214. // back in the correct order.
  215. return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
  216. new MemoryObject(current_inst, components_in_reverse.rbegin(),
  217. components_in_reverse.rend()));
  218. }
  219. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  220. CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
  221. assert(extract_inst->opcode() == SpvOpCompositeExtract &&
  222. "Expecting an OpCompositeExtract instruction.");
  223. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  224. std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
  225. extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
  226. if (result) {
  227. analysis::Integer int_type(32, false);
  228. const analysis::Type* uint32_type =
  229. context()->get_type_mgr()->GetRegisteredType(&int_type);
  230. std::vector<uint32_t> components;
  231. // Convert the indices in the extract instruction to a series of ids that
  232. // can be used by the |OpAccessChain| instruction.
  233. for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
  234. uint32_t index = extract_inst->GetSingleWordInOperand(i);
  235. const analysis::Constant* index_const =
  236. const_mgr->GetConstant(uint32_type, {index});
  237. components.push_back(
  238. const_mgr->GetDefiningInstruction(index_const)->result_id());
  239. }
  240. result->GetMember(components);
  241. return result;
  242. }
  243. return nullptr;
  244. }
  245. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  246. CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
  247. Instruction* conststruct_inst) {
  248. assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
  249. "Expecting an OpCompositeConstruct instruction.");
  250. // If every operand in the instruction are part of the same memory object, and
  251. // are being combined in the same order, then the result is the same as the
  252. // parent.
  253. std::unique_ptr<MemoryObject> memory_object =
  254. GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
  255. if (!memory_object) {
  256. return nullptr;
  257. }
  258. if (!memory_object->IsMember()) {
  259. return nullptr;
  260. }
  261. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  262. const analysis::Constant* last_access =
  263. const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
  264. if (!last_access || !last_access->type()->AsInteger()) {
  265. return nullptr;
  266. }
  267. if (last_access->GetU32() != 0) {
  268. return nullptr;
  269. }
  270. memory_object->GetParent();
  271. if (memory_object->GetNumberOfMembers() !=
  272. conststruct_inst->NumInOperands()) {
  273. return nullptr;
  274. }
  275. for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
  276. std::unique_ptr<MemoryObject> member_object =
  277. GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
  278. if (!member_object) {
  279. return nullptr;
  280. }
  281. if (!member_object->IsMember()) {
  282. return nullptr;
  283. }
  284. if (!memory_object->Contains(member_object.get())) {
  285. return nullptr;
  286. }
  287. last_access =
  288. const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
  289. if (!last_access || !last_access->type()->AsInteger()) {
  290. return nullptr;
  291. }
  292. if (last_access->GetU32() != i) {
  293. return nullptr;
  294. }
  295. }
  296. return memory_object;
  297. }
  298. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  299. CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
  300. assert(insert_inst->opcode() == SpvOpCompositeInsert &&
  301. "Expecting an OpCompositeInsert instruction.");
  302. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  303. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  304. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  305. const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
  306. uint32_t number_of_elements = 0;
  307. if (const analysis::Struct* struct_type = result_type->AsStruct()) {
  308. number_of_elements =
  309. static_cast<uint32_t>(struct_type->element_types().size());
  310. } else if (const analysis::Array* array_type = result_type->AsArray()) {
  311. const analysis::Constant* length_const =
  312. const_mgr->FindDeclaredConstant(array_type->LengthId());
  313. number_of_elements = length_const->GetU32();
  314. } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
  315. number_of_elements = vector_type->element_count();
  316. } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
  317. number_of_elements = matrix_type->element_count();
  318. }
  319. if (number_of_elements == 0) {
  320. return nullptr;
  321. }
  322. if (insert_inst->NumInOperands() != 3) {
  323. return nullptr;
  324. }
  325. if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
  326. return nullptr;
  327. }
  328. std::unique_ptr<MemoryObject> memory_object =
  329. GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
  330. if (!memory_object) {
  331. return nullptr;
  332. }
  333. if (!memory_object->IsMember()) {
  334. return nullptr;
  335. }
  336. const analysis::Constant* last_access =
  337. const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
  338. if (!last_access || !last_access->type()->AsInteger()) {
  339. return nullptr;
  340. }
  341. if (last_access->GetU32() != number_of_elements - 1) {
  342. return nullptr;
  343. }
  344. memory_object->GetParent();
  345. Instruction* current_insert =
  346. def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
  347. for (uint32_t i = number_of_elements - 1; i > 0; --i) {
  348. if (current_insert->opcode() != SpvOpCompositeInsert) {
  349. return nullptr;
  350. }
  351. if (current_insert->NumInOperands() != 3) {
  352. return nullptr;
  353. }
  354. if (current_insert->GetSingleWordInOperand(2) != i - 1) {
  355. return nullptr;
  356. }
  357. std::unique_ptr<MemoryObject> current_memory_object =
  358. GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
  359. if (!current_memory_object) {
  360. return nullptr;
  361. }
  362. if (!current_memory_object->IsMember()) {
  363. return nullptr;
  364. }
  365. if (memory_object->AccessChain().size() + 1 !=
  366. current_memory_object->AccessChain().size()) {
  367. return nullptr;
  368. }
  369. if (!memory_object->Contains(current_memory_object.get())) {
  370. return nullptr;
  371. }
  372. const analysis::Constant* current_last_access =
  373. const_mgr->FindDeclaredConstant(
  374. current_memory_object->AccessChain().back());
  375. if (!current_last_access || !current_last_access->type()->AsInteger()) {
  376. return nullptr;
  377. }
  378. if (current_last_access->GetU32() != i - 1) {
  379. return nullptr;
  380. }
  381. current_insert =
  382. def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
  383. }
  384. return memory_object;
  385. }
  386. bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
  387. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  388. analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
  389. if (pointer_type) {
  390. return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
  391. pointer_type->pointee_type()->kind() == analysis::Type::kImage;
  392. }
  393. return false;
  394. }
  395. bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
  396. uint32_t type_id) {
  397. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  398. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  399. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  400. analysis::Type* type = type_mgr->GetType(type_id);
  401. if (type->AsRuntimeArray()) {
  402. return false;
  403. }
  404. if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
  405. // If the type is not an aggregate, then the desired type must be the
  406. // same as the current type. No work to do, and we can do that.
  407. return true;
  408. }
  409. return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
  410. const_mgr,
  411. type](Instruction* use,
  412. uint32_t) {
  413. switch (use->opcode()) {
  414. case SpvOpLoad: {
  415. analysis::Pointer* pointer_type = type->AsPointer();
  416. uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
  417. if (new_type_id != use->type_id()) {
  418. return CanUpdateUses(use, new_type_id);
  419. }
  420. return true;
  421. }
  422. case SpvOpAccessChain: {
  423. analysis::Pointer* pointer_type = type->AsPointer();
  424. const analysis::Type* pointee_type = pointer_type->pointee_type();
  425. std::vector<uint32_t> access_chain;
  426. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  427. const analysis::Constant* index_const =
  428. const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
  429. if (index_const) {
  430. access_chain.push_back(index_const->GetU32());
  431. } else {
  432. // Variable index means the type is a type where every element
  433. // is the same type. Use element 0 to get the type.
  434. access_chain.push_back(0);
  435. }
  436. }
  437. const analysis::Type* new_pointee_type =
  438. type_mgr->GetMemberType(pointee_type, access_chain);
  439. analysis::Pointer pointerTy(new_pointee_type,
  440. pointer_type->storage_class());
  441. uint32_t new_pointer_type_id =
  442. context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
  443. if (new_pointer_type_id == 0) {
  444. return false;
  445. }
  446. if (new_pointer_type_id != use->type_id()) {
  447. return CanUpdateUses(use, new_pointer_type_id);
  448. }
  449. return true;
  450. }
  451. case SpvOpCompositeExtract: {
  452. std::vector<uint32_t> access_chain;
  453. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  454. access_chain.push_back(use->GetSingleWordInOperand(i));
  455. }
  456. const analysis::Type* new_type =
  457. type_mgr->GetMemberType(type, access_chain);
  458. uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
  459. if (new_type_id == 0) {
  460. return false;
  461. }
  462. if (new_type_id != use->type_id()) {
  463. return CanUpdateUses(use, new_type_id);
  464. }
  465. return true;
  466. }
  467. case SpvOpStore:
  468. // If needed, we can create an element-by-element copy to change the
  469. // type of the value being stored. This way we can always handled
  470. // stores.
  471. return true;
  472. case SpvOpImageTexelPointer:
  473. case SpvOpName:
  474. return true;
  475. default:
  476. return use->IsDecoration();
  477. }
  478. });
  479. }
  480. void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
  481. Instruction* new_ptr_inst) {
  482. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  483. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  484. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  485. std::vector<std::pair<Instruction*, uint32_t> > uses;
  486. def_use_mgr->ForEachUse(original_ptr_inst,
  487. [&uses](Instruction* use, uint32_t index) {
  488. uses.push_back({use, index});
  489. });
  490. for (auto pair : uses) {
  491. Instruction* use = pair.first;
  492. uint32_t index = pair.second;
  493. switch (use->opcode()) {
  494. case SpvOpLoad: {
  495. // Replace the actual use.
  496. context()->ForgetUses(use);
  497. use->SetOperand(index, {new_ptr_inst->result_id()});
  498. // Update the type.
  499. Instruction* pointer_type_inst =
  500. def_use_mgr->GetDef(new_ptr_inst->type_id());
  501. uint32_t new_type_id =
  502. pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
  503. if (new_type_id != use->type_id()) {
  504. use->SetResultType(new_type_id);
  505. context()->AnalyzeUses(use);
  506. UpdateUses(use, use);
  507. } else {
  508. context()->AnalyzeUses(use);
  509. }
  510. } break;
  511. case SpvOpAccessChain: {
  512. // Update the actual use.
  513. context()->ForgetUses(use);
  514. use->SetOperand(index, {new_ptr_inst->result_id()});
  515. // Convert the ids on the OpAccessChain to indices that can be used to
  516. // get the specific member.
  517. std::vector<uint32_t> access_chain;
  518. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  519. const analysis::Constant* index_const =
  520. const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
  521. if (index_const) {
  522. access_chain.push_back(index_const->GetU32());
  523. } else {
  524. // Variable index means the type is an type where every element
  525. // is the same type. Use element 0 to get the type.
  526. access_chain.push_back(0);
  527. }
  528. }
  529. Instruction* pointer_type_inst =
  530. get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
  531. uint32_t new_pointee_type_id = GetMemberTypeId(
  532. pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
  533. access_chain);
  534. SpvStorageClass storage_class = static_cast<SpvStorageClass>(
  535. pointer_type_inst->GetSingleWordInOperand(
  536. kTypePointerStorageClassInIdx));
  537. uint32_t new_pointer_type_id =
  538. type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
  539. if (new_pointer_type_id != use->type_id()) {
  540. use->SetResultType(new_pointer_type_id);
  541. context()->AnalyzeUses(use);
  542. UpdateUses(use, use);
  543. } else {
  544. context()->AnalyzeUses(use);
  545. }
  546. } break;
  547. case SpvOpCompositeExtract: {
  548. // Update the actual use.
  549. context()->ForgetUses(use);
  550. use->SetOperand(index, {new_ptr_inst->result_id()});
  551. uint32_t new_type_id = new_ptr_inst->type_id();
  552. std::vector<uint32_t> access_chain;
  553. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  554. access_chain.push_back(use->GetSingleWordInOperand(i));
  555. }
  556. new_type_id = GetMemberTypeId(new_type_id, access_chain);
  557. if (new_type_id != use->type_id()) {
  558. use->SetResultType(new_type_id);
  559. context()->AnalyzeUses(use);
  560. UpdateUses(use, use);
  561. } else {
  562. context()->AnalyzeUses(use);
  563. }
  564. } break;
  565. case SpvOpStore:
  566. // If the use is the pointer, then it is the single store to that
  567. // variable. We do not want to replace it. Instead, it will become
  568. // dead after all of the loads are removed, and ADCE will get rid of it.
  569. //
  570. // If the use is the object being stored, we will create a copy of the
  571. // object turning it into the correct type. The copy is done by
  572. // decomposing the object into the base type, which must be the same,
  573. // and then rebuilding them.
  574. if (index == 1) {
  575. Instruction* target_pointer = def_use_mgr->GetDef(
  576. use->GetSingleWordInOperand(kStorePointerInOperand));
  577. Instruction* pointer_type =
  578. def_use_mgr->GetDef(target_pointer->type_id());
  579. uint32_t pointee_type_id =
  580. pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
  581. uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
  582. context()->ForgetUses(use);
  583. use->SetInOperand(index, {copy});
  584. context()->AnalyzeUses(use);
  585. }
  586. break;
  587. case SpvOpImageTexelPointer:
  588. // We treat an OpImageTexelPointer as a load. The result type should
  589. // always have the Image storage class, and should not need to be
  590. // updated.
  591. // Replace the actual use.
  592. context()->ForgetUses(use);
  593. use->SetOperand(index, {new_ptr_inst->result_id()});
  594. context()->AnalyzeUses(use);
  595. break;
  596. default:
  597. assert(false && "Don't know how to rewrite instruction");
  598. break;
  599. }
  600. }
  601. }
  602. uint32_t CopyPropagateArrays::GetMemberTypeId(
  603. uint32_t id, const std::vector<uint32_t>& access_chain) const {
  604. for (uint32_t element_index : access_chain) {
  605. Instruction* type_inst = get_def_use_mgr()->GetDef(id);
  606. switch (type_inst->opcode()) {
  607. case SpvOpTypeArray:
  608. case SpvOpTypeRuntimeArray:
  609. case SpvOpTypeMatrix:
  610. case SpvOpTypeVector:
  611. id = type_inst->GetSingleWordInOperand(0);
  612. break;
  613. case SpvOpTypeStruct:
  614. id = type_inst->GetSingleWordInOperand(element_index);
  615. break;
  616. default:
  617. break;
  618. }
  619. assert(id != 0 &&
  620. "Tried to extract from an object where it cannot be done.");
  621. }
  622. return id;
  623. }
  624. void CopyPropagateArrays::MemoryObject::GetMember(
  625. const std::vector<uint32_t>& access_chain) {
  626. access_chain_.insert(access_chain_.end(), access_chain.begin(),
  627. access_chain.end());
  628. }
  629. uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
  630. IRContext* context = variable_inst_->context();
  631. analysis::TypeManager* type_mgr = context->get_type_mgr();
  632. const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
  633. type = type->AsPointer()->pointee_type();
  634. std::vector<uint32_t> access_indices = GetAccessIds();
  635. type = type_mgr->GetMemberType(type, access_indices);
  636. if (const analysis::Struct* struct_type = type->AsStruct()) {
  637. return static_cast<uint32_t>(struct_type->element_types().size());
  638. } else if (const analysis::Array* array_type = type->AsArray()) {
  639. const analysis::Constant* length_const =
  640. context->get_constant_mgr()->FindDeclaredConstant(
  641. array_type->LengthId());
  642. assert(length_const->type()->AsInteger());
  643. return length_const->GetU32();
  644. } else if (const analysis::Vector* vector_type = type->AsVector()) {
  645. return vector_type->element_count();
  646. } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
  647. return matrix_type->element_count();
  648. } else {
  649. return 0;
  650. }
  651. }
  652. template <class iterator>
  653. CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
  654. iterator begin, iterator end)
  655. : variable_inst_(var_inst), access_chain_(begin, end) {}
  656. std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
  657. analysis::ConstantManager* const_mgr =
  658. variable_inst_->context()->get_constant_mgr();
  659. std::vector<uint32_t> access_indices;
  660. for (uint32_t id : AccessChain()) {
  661. const analysis::Constant* element_index_const =
  662. const_mgr->FindDeclaredConstant(id);
  663. if (!element_index_const) {
  664. access_indices.push_back(0);
  665. } else {
  666. access_indices.push_back(element_index_const->GetU32());
  667. }
  668. }
  669. return access_indices;
  670. }
  671. bool CopyPropagateArrays::MemoryObject::Contains(
  672. CopyPropagateArrays::MemoryObject* other) {
  673. if (this->GetVariable() != other->GetVariable()) {
  674. return false;
  675. }
  676. if (AccessChain().size() > other->AccessChain().size()) {
  677. return false;
  678. }
  679. for (uint32_t i = 0; i < AccessChain().size(); i++) {
  680. if (AccessChain()[i] != other->AccessChain()[i]) {
  681. return false;
  682. }
  683. }
  684. return true;
  685. }
  686. } // namespace opt
  687. } // namespace spvtools