copy_prop_arrays.cpp 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/copy_prop_arrays.h"
  15. #include <utility>
  16. #include "source/opt/ir_builder.h"
  17. namespace spvtools {
  18. namespace opt {
  19. namespace {
  20. constexpr uint32_t kLoadPointerInOperand = 0;
  21. constexpr uint32_t kStorePointerInOperand = 0;
  22. constexpr uint32_t kStoreObjectInOperand = 1;
  23. constexpr uint32_t kCompositeExtractObjectInOperand = 0;
  24. constexpr uint32_t kTypePointerStorageClassInIdx = 0;
  25. constexpr uint32_t kTypePointerPointeeInIdx = 1;
  26. bool IsDebugDeclareOrValue(Instruction* di) {
  27. auto dbg_opcode = di->GetCommonDebugOpcode();
  28. return dbg_opcode == CommonDebugInfoDebugDeclare ||
  29. dbg_opcode == CommonDebugInfoDebugValue;
  30. }
  31. } // namespace
  32. Pass::Status CopyPropagateArrays::Process() {
  33. bool modified = false;
  34. for (Function& function : *get_module()) {
  35. if (function.IsDeclaration()) {
  36. continue;
  37. }
  38. BasicBlock* entry_bb = &*function.begin();
  39. for (auto var_inst = entry_bb->begin();
  40. var_inst->opcode() == spv::Op::OpVariable; ++var_inst) {
  41. if (!IsPointerToArrayType(var_inst->type_id())) {
  42. continue;
  43. }
  44. // Find the only store to the entire memory location, if it exists.
  45. Instruction* store_inst = FindStoreInstruction(&*var_inst);
  46. if (!store_inst) {
  47. continue;
  48. }
  49. std::unique_ptr<MemoryObject> source_object =
  50. FindSourceObjectIfPossible(&*var_inst, store_inst);
  51. if (source_object != nullptr) {
  52. if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
  53. modified = true;
  54. PropagateObject(&*var_inst, source_object.get(), store_inst);
  55. }
  56. }
  57. }
  58. }
  59. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  60. }
  61. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  62. CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
  63. Instruction* store_inst) {
  64. assert(var_inst->opcode() == spv::Op::OpVariable && "Expecting a variable.");
  65. // Check that the variable is a composite object where |store_inst|
  66. // dominates all of its loads.
  67. if (!store_inst) {
  68. return nullptr;
  69. }
  70. // Look at the loads to ensure they are dominated by the store.
  71. if (!HasValidReferencesOnly(var_inst, store_inst)) {
  72. return nullptr;
  73. }
  74. // If so, look at the store to see if it is the copy of an object.
  75. std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
  76. store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
  77. if (!source) {
  78. return nullptr;
  79. }
  80. // Ensure that |source| does not change between the point at which it is
  81. // loaded, and the position in which |var_inst| is loaded.
  82. //
  83. // For now we will go with the easy to implement approach, and check that the
  84. // entire variable (not just the specific component) is never written to.
  85. if (!HasNoStores(source->GetVariable())) {
  86. return nullptr;
  87. }
  88. return source;
  89. }
  90. Instruction* CopyPropagateArrays::FindStoreInstruction(
  91. const Instruction* var_inst) const {
  92. Instruction* store_inst = nullptr;
  93. get_def_use_mgr()->WhileEachUser(
  94. var_inst, [&store_inst, var_inst](Instruction* use) {
  95. if (use->opcode() == spv::Op::OpStore &&
  96. use->GetSingleWordInOperand(kStorePointerInOperand) ==
  97. var_inst->result_id()) {
  98. if (store_inst == nullptr) {
  99. store_inst = use;
  100. } else {
  101. store_inst = nullptr;
  102. return false;
  103. }
  104. }
  105. return true;
  106. });
  107. return store_inst;
  108. }
  109. void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
  110. MemoryObject* source,
  111. Instruction* insertion_point) {
  112. assert(var_inst->opcode() == spv::Op::OpVariable &&
  113. "This function propagates variables.");
  114. Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
  115. context()->KillNamesAndDecorates(var_inst);
  116. UpdateUses(var_inst, new_access_chain);
  117. }
  118. Instruction* CopyPropagateArrays::BuildNewAccessChain(
  119. Instruction* insertion_point,
  120. CopyPropagateArrays::MemoryObject* source) const {
  121. InstructionBuilder builder(
  122. context(), insertion_point,
  123. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  124. if (source->AccessChain().size() == 0) {
  125. return source->GetVariable();
  126. }
  127. source->BuildConstants();
  128. std::vector<uint32_t> access_ids(source->AccessChain().size());
  129. std::transform(
  130. source->AccessChain().cbegin(), source->AccessChain().cend(),
  131. access_ids.begin(), [](const AccessChainEntry& entry) {
  132. assert(entry.is_result_id && "Constants needs to be built first.");
  133. return entry.result_id;
  134. });
  135. return builder.AddAccessChain(source->GetPointerTypeId(this),
  136. source->GetVariable()->result_id(), access_ids);
  137. }
  138. bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
  139. return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
  140. if (use->opcode() == spv::Op::OpLoad) {
  141. return true;
  142. } else if (use->opcode() == spv::Op::OpAccessChain) {
  143. return HasNoStores(use);
  144. } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
  145. return true;
  146. } else if (use->opcode() == spv::Op::OpStore) {
  147. return false;
  148. } else if (use->opcode() == spv::Op::OpImageTexelPointer) {
  149. return true;
  150. } else if (use->opcode() == spv::Op::OpEntryPoint) {
  151. return true;
  152. }
  153. // Some other instruction. Be conservative.
  154. return false;
  155. });
  156. }
  157. bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
  158. Instruction* store_inst) {
  159. BasicBlock* store_block = context()->get_instr_block(store_inst);
  160. DominatorAnalysis* dominator_analysis =
  161. context()->GetDominatorAnalysis(store_block->GetParent());
  162. return get_def_use_mgr()->WhileEachUser(
  163. ptr_inst,
  164. [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
  165. if (use->opcode() == spv::Op::OpLoad ||
  166. use->opcode() == spv::Op::OpImageTexelPointer) {
  167. // TODO: If there are many load in the same BB as |store_inst| the
  168. // time to do the multiple traverses can add up. Consider collecting
  169. // those loads and doing a single traversal.
  170. return dominator_analysis->Dominates(store_inst, use);
  171. } else if (use->opcode() == spv::Op::OpAccessChain) {
  172. return HasValidReferencesOnly(use, store_inst);
  173. } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
  174. return true;
  175. } else if (use->opcode() == spv::Op::OpStore) {
  176. // If we are storing to part of the object it is not an candidate.
  177. return ptr_inst->opcode() == spv::Op::OpVariable &&
  178. store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
  179. ptr_inst->result_id();
  180. } else if (IsDebugDeclareOrValue(use)) {
  181. return true;
  182. }
  183. // Some other instruction. Be conservative.
  184. return false;
  185. });
  186. }
  187. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  188. CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
  189. Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
  190. switch (result_inst->opcode()) {
  191. case spv::Op::OpLoad:
  192. return BuildMemoryObjectFromLoad(result_inst);
  193. case spv::Op::OpCompositeExtract:
  194. return BuildMemoryObjectFromExtract(result_inst);
  195. case spv::Op::OpCompositeConstruct:
  196. return BuildMemoryObjectFromCompositeConstruct(result_inst);
  197. case spv::Op::OpCopyObject:
  198. return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
  199. case spv::Op::OpCompositeInsert:
  200. return BuildMemoryObjectFromInsert(result_inst);
  201. default:
  202. return nullptr;
  203. }
  204. }
  205. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  206. CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
  207. std::vector<uint32_t> components_in_reverse;
  208. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  209. Instruction* current_inst = def_use_mgr->GetDef(
  210. load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
  211. // Build the access chain for the memory object by collecting the indices used
  212. // in the OpAccessChain instructions. If we find a variable index, then
  213. // return |nullptr| because we cannot know for sure which memory location is
  214. // used.
  215. //
  216. // It is built in reverse order because the different |OpAccessChain|
  217. // instructions are visited in reverse order from which they are applied.
  218. while (current_inst->opcode() == spv::Op::OpAccessChain) {
  219. for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
  220. uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
  221. components_in_reverse.push_back(element_index_id);
  222. }
  223. current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
  224. }
  225. // If the address in the load is not constructed from an |OpVariable|
  226. // instruction followed by a series of |OpAccessChain| instructions, then
  227. // return |nullptr| because we cannot identify the owner or access chain
  228. // exactly.
  229. if (current_inst->opcode() != spv::Op::OpVariable) {
  230. return nullptr;
  231. }
  232. // Build the memory object. Use |rbegin| and |rend| to put the access chain
  233. // back in the correct order.
  234. return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
  235. new MemoryObject(current_inst, components_in_reverse.rbegin(),
  236. components_in_reverse.rend()));
  237. }
  238. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  239. CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
  240. assert(extract_inst->opcode() == spv::Op::OpCompositeExtract &&
  241. "Expecting an OpCompositeExtract instruction.");
  242. std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
  243. extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
  244. if (!result) {
  245. return nullptr;
  246. }
  247. // Copy the indices of the extract instruction to |OpAccessChain| indices.
  248. std::vector<AccessChainEntry> components;
  249. for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
  250. components.push_back({false, {extract_inst->GetSingleWordInOperand(i)}});
  251. }
  252. result->PushIndirection(components);
  253. return result;
  254. }
  255. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  256. CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
  257. Instruction* conststruct_inst) {
  258. assert(conststruct_inst->opcode() == spv::Op::OpCompositeConstruct &&
  259. "Expecting an OpCompositeConstruct instruction.");
  260. // If every operand in the instruction are part of the same memory object, and
  261. // are being combined in the same order, then the result is the same as the
  262. // parent.
  263. std::unique_ptr<MemoryObject> memory_object =
  264. GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
  265. if (!memory_object) {
  266. return nullptr;
  267. }
  268. if (!memory_object->IsMember()) {
  269. return nullptr;
  270. }
  271. AccessChainEntry last_access = memory_object->AccessChain().back();
  272. if (!IsAccessChainIndexValidAndEqualTo(last_access, 0)) {
  273. return nullptr;
  274. }
  275. memory_object->PopIndirection();
  276. if (memory_object->GetNumberOfMembers() !=
  277. conststruct_inst->NumInOperands()) {
  278. return nullptr;
  279. }
  280. for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
  281. std::unique_ptr<MemoryObject> member_object =
  282. GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
  283. if (!member_object) {
  284. return nullptr;
  285. }
  286. if (!member_object->IsMember()) {
  287. return nullptr;
  288. }
  289. if (!memory_object->Contains(member_object.get())) {
  290. return nullptr;
  291. }
  292. last_access = member_object->AccessChain().back();
  293. if (!IsAccessChainIndexValidAndEqualTo(last_access, i)) {
  294. return nullptr;
  295. }
  296. }
  297. return memory_object;
  298. }
  299. std::unique_ptr<CopyPropagateArrays::MemoryObject>
  300. CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
  301. assert(insert_inst->opcode() == spv::Op::OpCompositeInsert &&
  302. "Expecting an OpCompositeInsert instruction.");
  303. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  304. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  305. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  306. const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
  307. uint32_t number_of_elements = 0;
  308. if (const analysis::Struct* struct_type = result_type->AsStruct()) {
  309. number_of_elements =
  310. static_cast<uint32_t>(struct_type->element_types().size());
  311. } else if (const analysis::Array* array_type = result_type->AsArray()) {
  312. const analysis::Constant* length_const =
  313. const_mgr->FindDeclaredConstant(array_type->LengthId());
  314. number_of_elements = length_const->GetU32();
  315. } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
  316. number_of_elements = vector_type->element_count();
  317. } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
  318. number_of_elements = matrix_type->element_count();
  319. }
  320. if (number_of_elements == 0) {
  321. return nullptr;
  322. }
  323. if (insert_inst->NumInOperands() != 3) {
  324. return nullptr;
  325. }
  326. if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
  327. return nullptr;
  328. }
  329. std::unique_ptr<MemoryObject> memory_object =
  330. GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
  331. if (!memory_object) {
  332. return nullptr;
  333. }
  334. if (!memory_object->IsMember()) {
  335. return nullptr;
  336. }
  337. AccessChainEntry last_access = memory_object->AccessChain().back();
  338. if (!IsAccessChainIndexValidAndEqualTo(last_access, number_of_elements - 1)) {
  339. return nullptr;
  340. }
  341. memory_object->PopIndirection();
  342. Instruction* current_insert =
  343. def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
  344. for (uint32_t i = number_of_elements - 1; i > 0; --i) {
  345. if (current_insert->opcode() != spv::Op::OpCompositeInsert) {
  346. return nullptr;
  347. }
  348. if (current_insert->NumInOperands() != 3) {
  349. return nullptr;
  350. }
  351. if (current_insert->GetSingleWordInOperand(2) != i - 1) {
  352. return nullptr;
  353. }
  354. std::unique_ptr<MemoryObject> current_memory_object =
  355. GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
  356. if (!current_memory_object) {
  357. return nullptr;
  358. }
  359. if (!current_memory_object->IsMember()) {
  360. return nullptr;
  361. }
  362. if (memory_object->AccessChain().size() + 1 !=
  363. current_memory_object->AccessChain().size()) {
  364. return nullptr;
  365. }
  366. if (!memory_object->Contains(current_memory_object.get())) {
  367. return nullptr;
  368. }
  369. AccessChainEntry current_last_access =
  370. current_memory_object->AccessChain().back();
  371. if (!IsAccessChainIndexValidAndEqualTo(current_last_access, i - 1)) {
  372. return nullptr;
  373. }
  374. current_insert =
  375. def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
  376. }
  377. return memory_object;
  378. }
  379. bool CopyPropagateArrays::IsAccessChainIndexValidAndEqualTo(
  380. const AccessChainEntry& entry, uint32_t value) const {
  381. if (!entry.is_result_id) {
  382. return entry.immediate == value;
  383. }
  384. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  385. const analysis::Constant* constant =
  386. const_mgr->FindDeclaredConstant(entry.result_id);
  387. if (!constant || !constant->type()->AsInteger()) {
  388. return false;
  389. }
  390. return constant->GetU32() == value;
  391. }
  392. bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
  393. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  394. analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
  395. if (pointer_type) {
  396. return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
  397. pointer_type->pointee_type()->kind() == analysis::Type::kImage;
  398. }
  399. return false;
  400. }
  401. bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
  402. uint32_t type_id) {
  403. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  404. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  405. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  406. analysis::Type* type = type_mgr->GetType(type_id);
  407. if (type->AsRuntimeArray()) {
  408. return false;
  409. }
  410. if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
  411. // If the type is not an aggregate, then the desired type must be the
  412. // same as the current type. No work to do, and we can do that.
  413. return true;
  414. }
  415. return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
  416. const_mgr,
  417. type](Instruction* use,
  418. uint32_t) {
  419. if (IsDebugDeclareOrValue(use)) return true;
  420. switch (use->opcode()) {
  421. case spv::Op::OpLoad: {
  422. analysis::Pointer* pointer_type = type->AsPointer();
  423. uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
  424. if (new_type_id != use->type_id()) {
  425. return CanUpdateUses(use, new_type_id);
  426. }
  427. return true;
  428. }
  429. case spv::Op::OpAccessChain: {
  430. analysis::Pointer* pointer_type = type->AsPointer();
  431. const analysis::Type* pointee_type = pointer_type->pointee_type();
  432. std::vector<uint32_t> access_chain;
  433. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  434. const analysis::Constant* index_const =
  435. const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
  436. if (index_const) {
  437. access_chain.push_back(index_const->GetU32());
  438. } else {
  439. // Variable index means the type is a type where every element
  440. // is the same type. Use element 0 to get the type.
  441. access_chain.push_back(0);
  442. // We are trying to access a struct with variable indices.
  443. // This cannot happen.
  444. if (pointee_type->kind() == analysis::Type::kStruct) {
  445. return false;
  446. }
  447. }
  448. }
  449. const analysis::Type* new_pointee_type =
  450. type_mgr->GetMemberType(pointee_type, access_chain);
  451. analysis::Pointer pointerTy(new_pointee_type,
  452. pointer_type->storage_class());
  453. uint32_t new_pointer_type_id =
  454. context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
  455. if (new_pointer_type_id == 0) {
  456. return false;
  457. }
  458. if (new_pointer_type_id != use->type_id()) {
  459. return CanUpdateUses(use, new_pointer_type_id);
  460. }
  461. return true;
  462. }
  463. case spv::Op::OpCompositeExtract: {
  464. std::vector<uint32_t> access_chain;
  465. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  466. access_chain.push_back(use->GetSingleWordInOperand(i));
  467. }
  468. const analysis::Type* new_type =
  469. type_mgr->GetMemberType(type, access_chain);
  470. uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
  471. if (new_type_id == 0) {
  472. return false;
  473. }
  474. if (new_type_id != use->type_id()) {
  475. return CanUpdateUses(use, new_type_id);
  476. }
  477. return true;
  478. }
  479. case spv::Op::OpStore:
  480. // If needed, we can create an element-by-element copy to change the
  481. // type of the value being stored. This way we can always handled
  482. // stores.
  483. return true;
  484. case spv::Op::OpImageTexelPointer:
  485. case spv::Op::OpName:
  486. return true;
  487. default:
  488. return use->IsDecoration();
  489. }
  490. });
  491. }
  492. void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
  493. Instruction* new_ptr_inst) {
  494. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  495. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  496. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  497. std::vector<std::pair<Instruction*, uint32_t> > uses;
  498. def_use_mgr->ForEachUse(original_ptr_inst,
  499. [&uses](Instruction* use, uint32_t index) {
  500. uses.push_back({use, index});
  501. });
  502. for (auto pair : uses) {
  503. Instruction* use = pair.first;
  504. uint32_t index = pair.second;
  505. if (use->IsCommonDebugInstr()) {
  506. switch (use->GetCommonDebugOpcode()) {
  507. case CommonDebugInfoDebugDeclare: {
  508. if (new_ptr_inst->opcode() == spv::Op::OpVariable ||
  509. new_ptr_inst->opcode() == spv::Op::OpFunctionParameter) {
  510. context()->ForgetUses(use);
  511. use->SetOperand(index, {new_ptr_inst->result_id()});
  512. context()->AnalyzeUses(use);
  513. } else {
  514. // Based on the spec, we cannot use a pointer other than OpVariable
  515. // or OpFunctionParameter for DebugDeclare. We have to use
  516. // DebugValue with Deref.
  517. context()->ForgetUses(use);
  518. // Change DebugDeclare to DebugValue.
  519. use->SetOperand(index - 2,
  520. {static_cast<uint32_t>(CommonDebugInfoDebugValue)});
  521. use->SetOperand(index, {new_ptr_inst->result_id()});
  522. // Add Deref operation.
  523. Instruction* dbg_expr =
  524. def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
  525. auto* deref_expr_instr =
  526. context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
  527. use->SetOperand(index + 1, {deref_expr_instr->result_id()});
  528. context()->AnalyzeUses(deref_expr_instr);
  529. context()->AnalyzeUses(use);
  530. }
  531. break;
  532. }
  533. case CommonDebugInfoDebugValue:
  534. context()->ForgetUses(use);
  535. use->SetOperand(index, {new_ptr_inst->result_id()});
  536. context()->AnalyzeUses(use);
  537. break;
  538. default:
  539. assert(false && "Don't know how to rewrite instruction");
  540. break;
  541. }
  542. continue;
  543. }
  544. switch (use->opcode()) {
  545. case spv::Op::OpLoad: {
  546. // Replace the actual use.
  547. context()->ForgetUses(use);
  548. use->SetOperand(index, {new_ptr_inst->result_id()});
  549. // Update the type.
  550. Instruction* pointer_type_inst =
  551. def_use_mgr->GetDef(new_ptr_inst->type_id());
  552. uint32_t new_type_id =
  553. pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
  554. if (new_type_id != use->type_id()) {
  555. use->SetResultType(new_type_id);
  556. context()->AnalyzeUses(use);
  557. UpdateUses(use, use);
  558. } else {
  559. context()->AnalyzeUses(use);
  560. }
  561. } break;
  562. case spv::Op::OpAccessChain: {
  563. // Update the actual use.
  564. context()->ForgetUses(use);
  565. use->SetOperand(index, {new_ptr_inst->result_id()});
  566. // Convert the ids on the OpAccessChain to indices that can be used to
  567. // get the specific member.
  568. std::vector<uint32_t> access_chain;
  569. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  570. const analysis::Constant* index_const =
  571. const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
  572. if (index_const) {
  573. access_chain.push_back(index_const->GetU32());
  574. } else {
  575. // Variable index means the type is an type where every element
  576. // is the same type. Use element 0 to get the type.
  577. access_chain.push_back(0);
  578. }
  579. }
  580. Instruction* pointer_type_inst =
  581. get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
  582. uint32_t new_pointee_type_id = GetMemberTypeId(
  583. pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
  584. access_chain);
  585. spv::StorageClass storage_class = static_cast<spv::StorageClass>(
  586. pointer_type_inst->GetSingleWordInOperand(
  587. kTypePointerStorageClassInIdx));
  588. uint32_t new_pointer_type_id =
  589. type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
  590. if (new_pointer_type_id != use->type_id()) {
  591. use->SetResultType(new_pointer_type_id);
  592. context()->AnalyzeUses(use);
  593. UpdateUses(use, use);
  594. } else {
  595. context()->AnalyzeUses(use);
  596. }
  597. } break;
  598. case spv::Op::OpCompositeExtract: {
  599. // Update the actual use.
  600. context()->ForgetUses(use);
  601. use->SetOperand(index, {new_ptr_inst->result_id()});
  602. uint32_t new_type_id = new_ptr_inst->type_id();
  603. std::vector<uint32_t> access_chain;
  604. for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
  605. access_chain.push_back(use->GetSingleWordInOperand(i));
  606. }
  607. new_type_id = GetMemberTypeId(new_type_id, access_chain);
  608. if (new_type_id != use->type_id()) {
  609. use->SetResultType(new_type_id);
  610. context()->AnalyzeUses(use);
  611. UpdateUses(use, use);
  612. } else {
  613. context()->AnalyzeUses(use);
  614. }
  615. } break;
  616. case spv::Op::OpStore:
  617. // If the use is the pointer, then it is the single store to that
  618. // variable. We do not want to replace it. Instead, it will become
  619. // dead after all of the loads are removed, and ADCE will get rid of it.
  620. //
  621. // If the use is the object being stored, we will create a copy of the
  622. // object turning it into the correct type. The copy is done by
  623. // decomposing the object into the base type, which must be the same,
  624. // and then rebuilding them.
  625. if (index == 1) {
  626. Instruction* target_pointer = def_use_mgr->GetDef(
  627. use->GetSingleWordInOperand(kStorePointerInOperand));
  628. Instruction* pointer_type =
  629. def_use_mgr->GetDef(target_pointer->type_id());
  630. uint32_t pointee_type_id =
  631. pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
  632. uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
  633. context()->ForgetUses(use);
  634. use->SetInOperand(index, {copy});
  635. context()->AnalyzeUses(use);
  636. }
  637. break;
  638. case spv::Op::OpDecorate:
  639. // We treat an OpImageTexelPointer as a load. The result type should
  640. // always have the Image storage class, and should not need to be
  641. // updated.
  642. case spv::Op::OpImageTexelPointer:
  643. // Replace the actual use.
  644. context()->ForgetUses(use);
  645. use->SetOperand(index, {new_ptr_inst->result_id()});
  646. context()->AnalyzeUses(use);
  647. break;
  648. default:
  649. assert(false && "Don't know how to rewrite instruction");
  650. break;
  651. }
  652. }
  653. }
  654. uint32_t CopyPropagateArrays::GetMemberTypeId(
  655. uint32_t id, const std::vector<uint32_t>& access_chain) const {
  656. for (uint32_t element_index : access_chain) {
  657. Instruction* type_inst = get_def_use_mgr()->GetDef(id);
  658. switch (type_inst->opcode()) {
  659. case spv::Op::OpTypeArray:
  660. case spv::Op::OpTypeRuntimeArray:
  661. case spv::Op::OpTypeMatrix:
  662. case spv::Op::OpTypeVector:
  663. id = type_inst->GetSingleWordInOperand(0);
  664. break;
  665. case spv::Op::OpTypeStruct:
  666. id = type_inst->GetSingleWordInOperand(element_index);
  667. break;
  668. default:
  669. break;
  670. }
  671. assert(id != 0 &&
  672. "Tried to extract from an object where it cannot be done.");
  673. }
  674. return id;
  675. }
  676. void CopyPropagateArrays::MemoryObject::PushIndirection(
  677. const std::vector<AccessChainEntry>& access_chain) {
  678. access_chain_.insert(access_chain_.end(), access_chain.begin(),
  679. access_chain.end());
  680. }
  681. uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
  682. IRContext* context = variable_inst_->context();
  683. analysis::TypeManager* type_mgr = context->get_type_mgr();
  684. const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
  685. type = type->AsPointer()->pointee_type();
  686. std::vector<uint32_t> access_indices = GetAccessIds();
  687. type = type_mgr->GetMemberType(type, access_indices);
  688. if (const analysis::Struct* struct_type = type->AsStruct()) {
  689. return static_cast<uint32_t>(struct_type->element_types().size());
  690. } else if (const analysis::Array* array_type = type->AsArray()) {
  691. const analysis::Constant* length_const =
  692. context->get_constant_mgr()->FindDeclaredConstant(
  693. array_type->LengthId());
  694. assert(length_const->type()->AsInteger());
  695. return length_const->GetU32();
  696. } else if (const analysis::Vector* vector_type = type->AsVector()) {
  697. return vector_type->element_count();
  698. } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
  699. return matrix_type->element_count();
  700. } else {
  701. return 0;
  702. }
  703. }
  704. template <class iterator>
  705. CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
  706. iterator begin, iterator end)
  707. : variable_inst_(var_inst) {
  708. std::transform(begin, end, std::back_inserter(access_chain_),
  709. [](uint32_t id) {
  710. return AccessChainEntry{true, {id}};
  711. });
  712. }
  713. std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
  714. analysis::ConstantManager* const_mgr =
  715. variable_inst_->context()->get_constant_mgr();
  716. std::vector<uint32_t> indices(AccessChain().size());
  717. std::transform(AccessChain().cbegin(), AccessChain().cend(), indices.begin(),
  718. [&const_mgr](const AccessChainEntry& entry) {
  719. if (entry.is_result_id) {
  720. const analysis::Constant* constant =
  721. const_mgr->FindDeclaredConstant(entry.result_id);
  722. return constant == nullptr ? 0 : constant->GetU32();
  723. }
  724. return entry.immediate;
  725. });
  726. return indices;
  727. }
  728. bool CopyPropagateArrays::MemoryObject::Contains(
  729. CopyPropagateArrays::MemoryObject* other) {
  730. if (this->GetVariable() != other->GetVariable()) {
  731. return false;
  732. }
  733. if (AccessChain().size() > other->AccessChain().size()) {
  734. return false;
  735. }
  736. for (uint32_t i = 0; i < AccessChain().size(); i++) {
  737. if (AccessChain()[i] != other->AccessChain()[i]) {
  738. return false;
  739. }
  740. }
  741. return true;
  742. }
  743. void CopyPropagateArrays::MemoryObject::BuildConstants() {
  744. for (auto& entry : access_chain_) {
  745. if (entry.is_result_id) {
  746. continue;
  747. }
  748. auto context = variable_inst_->context();
  749. analysis::Integer int_type(32, false);
  750. const analysis::Type* uint32_type =
  751. context->get_type_mgr()->GetRegisteredType(&int_type);
  752. analysis::ConstantManager* const_mgr = context->get_constant_mgr();
  753. const analysis::Constant* index_const =
  754. const_mgr->GetConstant(uint32_type, {entry.immediate});
  755. entry.result_id =
  756. const_mgr->GetDefiningInstruction(index_const)->result_id();
  757. entry.is_result_id = true;
  758. }
  759. }
  760. } // namespace opt
  761. } // namespace spvtools