scalar_replacement_pass.cpp 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/scalar_replacement_pass.h"
  15. #include <algorithm>
  16. #include <queue>
  17. #include <tuple>
  18. #include <utility>
  19. #include "source/enum_string_mapping.h"
  20. #include "source/extensions.h"
  21. #include "source/opt/reflect.h"
  22. #include "source/opt/types.h"
  23. #include "source/util/make_unique.h"
  24. namespace spvtools {
  25. namespace opt {
  26. Pass::Status ScalarReplacementPass::Process() {
  27. Status status = Status::SuccessWithoutChange;
  28. for (auto& f : *get_module()) {
  29. Status functionStatus = ProcessFunction(&f);
  30. if (functionStatus == Status::Failure)
  31. return functionStatus;
  32. else if (functionStatus == Status::SuccessWithChange)
  33. status = functionStatus;
  34. }
  35. return status;
  36. }
  37. Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  38. std::queue<Instruction*> worklist;
  39. BasicBlock& entry = *function->begin();
  40. for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
  41. // Function storage class OpVariables must appear as the first instructions
  42. // of the entry block.
  43. if (iter->opcode() != SpvOpVariable) break;
  44. Instruction* varInst = &*iter;
  45. if (CanReplaceVariable(varInst)) {
  46. worklist.push(varInst);
  47. }
  48. }
  49. Status status = Status::SuccessWithoutChange;
  50. while (!worklist.empty()) {
  51. Instruction* varInst = worklist.front();
  52. worklist.pop();
  53. Status var_status = ReplaceVariable(varInst, &worklist);
  54. if (var_status == Status::Failure)
  55. return var_status;
  56. else if (var_status == Status::SuccessWithChange)
  57. status = var_status;
  58. }
  59. return status;
  60. }
  61. Pass::Status ScalarReplacementPass::ReplaceVariable(
  62. Instruction* inst, std::queue<Instruction*>* worklist) {
  63. std::vector<Instruction*> replacements;
  64. if (!CreateReplacementVariables(inst, &replacements)) {
  65. return Status::Failure;
  66. }
  67. std::vector<Instruction*> dead;
  68. if (get_def_use_mgr()->WhileEachUser(
  69. inst, [this, &replacements, &dead](Instruction* user) {
  70. if (!IsAnnotationInst(user->opcode())) {
  71. switch (user->opcode()) {
  72. case SpvOpLoad:
  73. ReplaceWholeLoad(user, replacements);
  74. dead.push_back(user);
  75. break;
  76. case SpvOpStore:
  77. ReplaceWholeStore(user, replacements);
  78. dead.push_back(user);
  79. break;
  80. case SpvOpAccessChain:
  81. case SpvOpInBoundsAccessChain:
  82. if (ReplaceAccessChain(user, replacements))
  83. dead.push_back(user);
  84. else
  85. return false;
  86. break;
  87. case SpvOpName:
  88. case SpvOpMemberName:
  89. break;
  90. default:
  91. assert(false && "Unexpected opcode");
  92. break;
  93. }
  94. }
  95. return true;
  96. }))
  97. dead.push_back(inst);
  98. // If there are no dead instructions to clean up, return with no changes.
  99. if (dead.empty()) return Status::SuccessWithoutChange;
  100. // Clean up some dead code.
  101. while (!dead.empty()) {
  102. Instruction* toKill = dead.back();
  103. dead.pop_back();
  104. context()->KillInst(toKill);
  105. }
  106. // Attempt to further scalarize.
  107. for (auto var : replacements) {
  108. if (var->opcode() == SpvOpVariable) {
  109. if (get_def_use_mgr()->NumUsers(var) == 0) {
  110. context()->KillInst(var);
  111. } else if (CanReplaceVariable(var)) {
  112. worklist->push(var);
  113. }
  114. }
  115. }
  116. return Status::SuccessWithChange;
  117. }
  118. void ScalarReplacementPass::ReplaceWholeLoad(
  119. Instruction* load, const std::vector<Instruction*>& replacements) {
  120. // Replaces the load of the entire composite with a load from each replacement
  121. // variable followed by a composite construction.
  122. BasicBlock* block = context()->get_instr_block(load);
  123. std::vector<Instruction*> loads;
  124. loads.reserve(replacements.size());
  125. BasicBlock::iterator where(load);
  126. for (auto var : replacements) {
  127. // Create a load of each replacement variable.
  128. if (var->opcode() != SpvOpVariable) {
  129. loads.push_back(var);
  130. continue;
  131. }
  132. Instruction* type = GetStorageType(var);
  133. uint32_t loadId = TakeNextId();
  134. std::unique_ptr<Instruction> newLoad(
  135. new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
  136. std::initializer_list<Operand>{
  137. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  138. // Copy memory access attributes which start at index 1. Index 0 is the
  139. // pointer to load.
  140. for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
  141. Operand copy(load->GetInOperand(i));
  142. newLoad->AddOperand(std::move(copy));
  143. }
  144. where = where.InsertBefore(std::move(newLoad));
  145. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  146. context()->set_instr_block(&*where, block);
  147. loads.push_back(&*where);
  148. }
  149. // Construct a new composite.
  150. uint32_t compositeId = TakeNextId();
  151. where = load;
  152. std::unique_ptr<Instruction> compositeConstruct(new Instruction(
  153. context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
  154. for (auto l : loads) {
  155. Operand op(SPV_OPERAND_TYPE_ID,
  156. std::initializer_list<uint32_t>{l->result_id()});
  157. compositeConstruct->AddOperand(std::move(op));
  158. }
  159. where = where.InsertBefore(std::move(compositeConstruct));
  160. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  161. context()->set_instr_block(&*where, block);
  162. context()->ReplaceAllUsesWith(load->result_id(), compositeId);
  163. }
  164. void ScalarReplacementPass::ReplaceWholeStore(
  165. Instruction* store, const std::vector<Instruction*>& replacements) {
  166. // Replaces a store to the whole composite with a series of extract and stores
  167. // to each element.
  168. uint32_t storeInput = store->GetSingleWordInOperand(1u);
  169. BasicBlock* block = context()->get_instr_block(store);
  170. BasicBlock::iterator where(store);
  171. uint32_t elementIndex = 0;
  172. for (auto var : replacements) {
  173. // Create the extract.
  174. if (var->opcode() != SpvOpVariable) {
  175. elementIndex++;
  176. continue;
  177. }
  178. Instruction* type = GetStorageType(var);
  179. uint32_t extractId = TakeNextId();
  180. std::unique_ptr<Instruction> extract(new Instruction(
  181. context(), SpvOpCompositeExtract, type->result_id(), extractId,
  182. std::initializer_list<Operand>{
  183. {SPV_OPERAND_TYPE_ID, {storeInput}},
  184. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
  185. auto iter = where.InsertBefore(std::move(extract));
  186. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  187. context()->set_instr_block(&*iter, block);
  188. // Create the store.
  189. std::unique_ptr<Instruction> newStore(
  190. new Instruction(context(), SpvOpStore, 0, 0,
  191. std::initializer_list<Operand>{
  192. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  193. {SPV_OPERAND_TYPE_ID, {extractId}}}));
  194. // Copy memory access attributes which start at index 2. Index 0 is the
  195. // pointer and index 1 is the data.
  196. for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
  197. Operand copy(store->GetInOperand(i));
  198. newStore->AddOperand(std::move(copy));
  199. }
  200. iter = where.InsertBefore(std::move(newStore));
  201. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  202. context()->set_instr_block(&*iter, block);
  203. }
  204. }
  205. bool ScalarReplacementPass::ReplaceAccessChain(
  206. Instruction* chain, const std::vector<Instruction*>& replacements) {
  207. // Replaces the access chain with either another access chain (with one fewer
  208. // indexes) or a direct use of the replacement variable.
  209. uint32_t indexId = chain->GetSingleWordInOperand(1u);
  210. const Instruction* index = get_def_use_mgr()->GetDef(indexId);
  211. int64_t indexValue = context()
  212. ->get_constant_mgr()
  213. ->GetConstantFromInst(index)
  214. ->GetSignExtendedValue();
  215. if (indexValue < 0 ||
  216. indexValue >= static_cast<int64_t>(replacements.size())) {
  217. // Out of bounds access, this is illegal IR. Notice that OpAccessChain
  218. // indexing is 0-based, so we should also reject index == size-of-array.
  219. return false;
  220. } else {
  221. const Instruction* var = replacements[static_cast<size_t>(indexValue)];
  222. if (chain->NumInOperands() > 2) {
  223. // Replace input access chain with another access chain.
  224. BasicBlock::iterator chainIter(chain);
  225. uint32_t replacementId = TakeNextId();
  226. std::unique_ptr<Instruction> replacementChain(new Instruction(
  227. context(), chain->opcode(), chain->type_id(), replacementId,
  228. std::initializer_list<Operand>{
  229. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  230. // Add the remaining indexes.
  231. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
  232. Operand copy(chain->GetInOperand(i));
  233. replacementChain->AddOperand(std::move(copy));
  234. }
  235. auto iter = chainIter.InsertBefore(std::move(replacementChain));
  236. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  237. context()->set_instr_block(&*iter, context()->get_instr_block(chain));
  238. context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  239. } else {
  240. // Replace with a use of the variable.
  241. context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  242. }
  243. }
  244. return true;
  245. }
  246. bool ScalarReplacementPass::CreateReplacementVariables(
  247. Instruction* inst, std::vector<Instruction*>* replacements) {
  248. Instruction* type = GetStorageType(inst);
  249. std::unique_ptr<std::unordered_set<int64_t>> components_used =
  250. GetUsedComponents(inst);
  251. uint32_t elem = 0;
  252. switch (type->opcode()) {
  253. case SpvOpTypeStruct:
  254. type->ForEachInOperand(
  255. [this, inst, &elem, replacements, &components_used](uint32_t* id) {
  256. if (!components_used || components_used->count(elem)) {
  257. CreateVariable(*id, inst, elem, replacements);
  258. } else {
  259. replacements->push_back(CreateNullConstant(*id));
  260. }
  261. elem++;
  262. });
  263. break;
  264. case SpvOpTypeArray:
  265. for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
  266. if (!components_used || components_used->count(i)) {
  267. CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
  268. replacements);
  269. } else {
  270. replacements->push_back(
  271. CreateNullConstant(type->GetSingleWordInOperand(0u)));
  272. }
  273. }
  274. break;
  275. case SpvOpTypeMatrix:
  276. case SpvOpTypeVector:
  277. for (uint32_t i = 0; i != GetNumElements(type); ++i) {
  278. CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
  279. }
  280. break;
  281. default:
  282. assert(false && "Unexpected type.");
  283. break;
  284. }
  285. TransferAnnotations(inst, replacements);
  286. return std::find(replacements->begin(), replacements->end(), nullptr) ==
  287. replacements->end();
  288. }
  289. void ScalarReplacementPass::TransferAnnotations(
  290. const Instruction* source, std::vector<Instruction*>* replacements) {
  291. // Only transfer invariant and restrict decorations on the variable. There are
  292. // no type or member decorations that are necessary to transfer.
  293. for (auto inst :
  294. get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
  295. assert(inst->opcode() == SpvOpDecorate);
  296. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  297. if (decoration == SpvDecorationInvariant ||
  298. decoration == SpvDecorationRestrict) {
  299. for (auto var : *replacements) {
  300. std::unique_ptr<Instruction> annotation(
  301. new Instruction(context(), SpvOpDecorate, 0, 0,
  302. std::initializer_list<Operand>{
  303. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  304. {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
  305. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  306. Operand copy(inst->GetInOperand(i));
  307. annotation->AddOperand(std::move(copy));
  308. }
  309. context()->AddAnnotationInst(std::move(annotation));
  310. get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
  311. }
  312. }
  313. }
  314. }
  315. void ScalarReplacementPass::CreateVariable(
  316. uint32_t typeId, Instruction* varInst, uint32_t index,
  317. std::vector<Instruction*>* replacements) {
  318. uint32_t ptrId = GetOrCreatePointerType(typeId);
  319. uint32_t id = TakeNextId();
  320. std::unique_ptr<Instruction> variable(new Instruction(
  321. context(), SpvOpVariable, ptrId, id,
  322. std::initializer_list<Operand>{
  323. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
  324. BasicBlock* block = context()->get_instr_block(varInst);
  325. block->begin().InsertBefore(std::move(variable));
  326. Instruction* inst = &*block->begin();
  327. // If varInst was initialized, make sure to initialize its replacement.
  328. GetOrCreateInitialValue(varInst, index, inst);
  329. get_def_use_mgr()->AnalyzeInstDefUse(inst);
  330. context()->set_instr_block(inst, block);
  331. // Copy decorations from the member to the new variable.
  332. Instruction* typeInst = GetStorageType(varInst);
  333. for (auto dec_inst :
  334. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  335. uint32_t decoration;
  336. if (dec_inst->opcode() != SpvOpMemberDecorate) {
  337. continue;
  338. }
  339. if (dec_inst->GetSingleWordInOperand(1) != index) {
  340. continue;
  341. }
  342. decoration = dec_inst->GetSingleWordInOperand(2u);
  343. switch (decoration) {
  344. case SpvDecorationRelaxedPrecision: {
  345. std::unique_ptr<Instruction> new_dec_inst(
  346. new Instruction(context(), SpvOpDecorate, 0, 0, {}));
  347. new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
  348. for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
  349. new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
  350. }
  351. context()->AddAnnotationInst(std::move(new_dec_inst));
  352. } break;
  353. default:
  354. break;
  355. }
  356. }
  357. replacements->push_back(inst);
  358. }
  359. uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
  360. auto iter = pointee_to_pointer_.find(id);
  361. if (iter != pointee_to_pointer_.end()) return iter->second;
  362. analysis::Type* pointeeTy;
  363. std::unique_ptr<analysis::Pointer> pointerTy;
  364. std::tie(pointeeTy, pointerTy) =
  365. context()->get_type_mgr()->GetTypeAndPointerType(id,
  366. SpvStorageClassFunction);
  367. uint32_t ptrId = 0;
  368. if (pointeeTy->IsUniqueType()) {
  369. // Non-ambiguous type, just ask the type manager for an id.
  370. ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
  371. pointee_to_pointer_[id] = ptrId;
  372. return ptrId;
  373. }
  374. // Ambiguous type. We must perform a linear search to try and find the right
  375. // type.
  376. for (auto global : context()->types_values()) {
  377. if (global.opcode() == SpvOpTypePointer &&
  378. global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
  379. global.GetSingleWordInOperand(1u) == id) {
  380. if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
  381. // Only reuse a decoration-less pointer of the correct type.
  382. ptrId = global.result_id();
  383. break;
  384. }
  385. }
  386. }
  387. if (ptrId != 0) {
  388. pointee_to_pointer_[id] = ptrId;
  389. return ptrId;
  390. }
  391. ptrId = TakeNextId();
  392. context()->AddType(MakeUnique<Instruction>(
  393. context(), SpvOpTypePointer, 0, ptrId,
  394. std::initializer_list<Operand>{
  395. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  396. {SPV_OPERAND_TYPE_ID, {id}}}));
  397. Instruction* ptr = &*--context()->types_values_end();
  398. get_def_use_mgr()->AnalyzeInstDefUse(ptr);
  399. pointee_to_pointer_[id] = ptrId;
  400. // Register with the type manager if necessary.
  401. context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
  402. return ptrId;
  403. }
  404. void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
  405. uint32_t index,
  406. Instruction* newVar) {
  407. assert(source->opcode() == SpvOpVariable);
  408. if (source->NumInOperands() < 2) return;
  409. uint32_t initId = source->GetSingleWordInOperand(1u);
  410. uint32_t storageId = GetStorageType(newVar)->result_id();
  411. Instruction* init = get_def_use_mgr()->GetDef(initId);
  412. uint32_t newInitId = 0;
  413. // TODO(dnovillo): Refactor this with constant propagation.
  414. if (init->opcode() == SpvOpConstantNull) {
  415. // Initialize to appropriate NULL.
  416. auto iter = type_to_null_.find(storageId);
  417. if (iter == type_to_null_.end()) {
  418. newInitId = TakeNextId();
  419. type_to_null_[storageId] = newInitId;
  420. context()->AddGlobalValue(
  421. MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
  422. newInitId, std::initializer_list<Operand>{}));
  423. Instruction* newNull = &*--context()->types_values_end();
  424. get_def_use_mgr()->AnalyzeInstDefUse(newNull);
  425. } else {
  426. newInitId = iter->second;
  427. }
  428. } else if (IsSpecConstantInst(init->opcode())) {
  429. // Create a new constant extract.
  430. newInitId = TakeNextId();
  431. context()->AddGlobalValue(MakeUnique<Instruction>(
  432. context(), SpvOpSpecConstantOp, storageId, newInitId,
  433. std::initializer_list<Operand>{
  434. {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
  435. {SPV_OPERAND_TYPE_ID, {init->result_id()}},
  436. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
  437. Instruction* newSpecConst = &*--context()->types_values_end();
  438. get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
  439. } else if (init->opcode() == SpvOpConstantComposite) {
  440. // Get the appropriate index constant.
  441. newInitId = init->GetSingleWordInOperand(index);
  442. Instruction* element = get_def_use_mgr()->GetDef(newInitId);
  443. if (element->opcode() == SpvOpUndef) {
  444. // Undef is not a valid initializer for a variable.
  445. newInitId = 0;
  446. }
  447. } else {
  448. assert(false);
  449. }
  450. if (newInitId != 0) {
  451. newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
  452. }
  453. }
  454. uint64_t ScalarReplacementPass::GetArrayLength(
  455. const Instruction* arrayType) const {
  456. assert(arrayType->opcode() == SpvOpTypeArray);
  457. const Instruction* length =
  458. get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
  459. return context()
  460. ->get_constant_mgr()
  461. ->GetConstantFromInst(length)
  462. ->GetZeroExtendedValue();
  463. }
  464. uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
  465. assert(type->opcode() == SpvOpTypeVector ||
  466. type->opcode() == SpvOpTypeMatrix);
  467. const Operand& op = type->GetInOperand(1u);
  468. assert(op.words.size() <= 2);
  469. uint64_t len = 0;
  470. for (size_t i = 0; i != op.words.size(); ++i) {
  471. len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
  472. }
  473. return len;
  474. }
  475. bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
  476. const Instruction* inst = get_def_use_mgr()->GetDef(id);
  477. assert(inst);
  478. return spvOpcodeIsSpecConstant(inst->opcode());
  479. }
  480. Instruction* ScalarReplacementPass::GetStorageType(
  481. const Instruction* inst) const {
  482. assert(inst->opcode() == SpvOpVariable);
  483. uint32_t ptrTypeId = inst->type_id();
  484. uint32_t typeId =
  485. get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u);
  486. return get_def_use_mgr()->GetDef(typeId);
  487. }
  488. bool ScalarReplacementPass::CanReplaceVariable(
  489. const Instruction* varInst) const {
  490. assert(varInst->opcode() == SpvOpVariable);
  491. // Can only replace function scope variables.
  492. if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
  493. return false;
  494. }
  495. if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
  496. return false;
  497. }
  498. const Instruction* typeInst = GetStorageType(varInst);
  499. if (!CheckType(typeInst)) {
  500. return false;
  501. }
  502. if (!CheckAnnotations(varInst)) {
  503. return false;
  504. }
  505. if (!CheckUses(varInst)) {
  506. return false;
  507. }
  508. return true;
  509. }
  510. bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
  511. if (!CheckTypeAnnotations(typeInst)) {
  512. return false;
  513. }
  514. switch (typeInst->opcode()) {
  515. case SpvOpTypeStruct:
  516. // Don't bother with empty structs or very large structs.
  517. if (typeInst->NumInOperands() == 0 ||
  518. IsLargerThanSizeLimit(typeInst->NumInOperands())) {
  519. return false;
  520. }
  521. return true;
  522. case SpvOpTypeArray:
  523. if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
  524. return false;
  525. }
  526. if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
  527. return false;
  528. }
  529. return true;
  530. // TODO(alanbaker): Develop some heuristics for when this should be
  531. // re-enabled.
  532. //// Specifically including matrix and vector in an attempt to reduce the
  533. //// number of vector registers required.
  534. // case SpvOpTypeMatrix:
  535. // case SpvOpTypeVector:
  536. // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
  537. // return true;
  538. case SpvOpTypeRuntimeArray:
  539. default:
  540. return false;
  541. }
  542. }
  543. bool ScalarReplacementPass::CheckTypeAnnotations(
  544. const Instruction* typeInst) const {
  545. for (auto inst :
  546. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  547. uint32_t decoration;
  548. if (inst->opcode() == SpvOpDecorate) {
  549. decoration = inst->GetSingleWordInOperand(1u);
  550. } else {
  551. assert(inst->opcode() == SpvOpMemberDecorate);
  552. decoration = inst->GetSingleWordInOperand(2u);
  553. }
  554. switch (decoration) {
  555. case SpvDecorationRowMajor:
  556. case SpvDecorationColMajor:
  557. case SpvDecorationArrayStride:
  558. case SpvDecorationMatrixStride:
  559. case SpvDecorationCPacked:
  560. case SpvDecorationInvariant:
  561. case SpvDecorationRestrict:
  562. case SpvDecorationOffset:
  563. case SpvDecorationAlignment:
  564. case SpvDecorationAlignmentId:
  565. case SpvDecorationMaxByteOffset:
  566. case SpvDecorationRelaxedPrecision:
  567. break;
  568. default:
  569. return false;
  570. }
  571. }
  572. return true;
  573. }
  574. bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
  575. for (auto inst :
  576. get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
  577. assert(inst->opcode() == SpvOpDecorate);
  578. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  579. switch (decoration) {
  580. case SpvDecorationInvariant:
  581. case SpvDecorationRestrict:
  582. case SpvDecorationAlignment:
  583. case SpvDecorationAlignmentId:
  584. case SpvDecorationMaxByteOffset:
  585. break;
  586. default:
  587. return false;
  588. }
  589. }
  590. return true;
  591. }
  592. bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
  593. VariableStats stats = {0, 0};
  594. bool ok = CheckUses(inst, &stats);
  595. // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
  596. // SRoA is costly, such as when the structure has many (unaccessed?)
  597. // members.
  598. return ok;
  599. }
  600. bool ScalarReplacementPass::CheckUses(const Instruction* inst,
  601. VariableStats* stats) const {
  602. bool ok = true;
  603. get_def_use_mgr()->ForEachUse(
  604. inst, [this, stats, &ok](const Instruction* user, uint32_t index) {
  605. // Annotations are check as a group separately.
  606. if (!IsAnnotationInst(user->opcode())) {
  607. switch (user->opcode()) {
  608. case SpvOpAccessChain:
  609. case SpvOpInBoundsAccessChain:
  610. if (index == 2u && user->NumInOperands() > 1) {
  611. uint32_t id = user->GetSingleWordInOperand(1u);
  612. const Instruction* opInst = get_def_use_mgr()->GetDef(id);
  613. if (!IsCompileTimeConstantInst(opInst->opcode())) {
  614. ok = false;
  615. } else {
  616. if (!CheckUsesRelaxed(user)) ok = false;
  617. }
  618. stats->num_partial_accesses++;
  619. } else {
  620. ok = false;
  621. }
  622. break;
  623. case SpvOpLoad:
  624. if (!CheckLoad(user, index)) ok = false;
  625. stats->num_full_accesses++;
  626. break;
  627. case SpvOpStore:
  628. if (!CheckStore(user, index)) ok = false;
  629. stats->num_full_accesses++;
  630. break;
  631. case SpvOpName:
  632. case SpvOpMemberName:
  633. break;
  634. default:
  635. ok = false;
  636. break;
  637. }
  638. }
  639. });
  640. return ok;
  641. }
  642. bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
  643. bool ok = true;
  644. get_def_use_mgr()->ForEachUse(
  645. inst, [this, &ok](const Instruction* user, uint32_t index) {
  646. switch (user->opcode()) {
  647. case SpvOpAccessChain:
  648. case SpvOpInBoundsAccessChain:
  649. if (index != 2u) {
  650. ok = false;
  651. } else {
  652. if (!CheckUsesRelaxed(user)) ok = false;
  653. }
  654. break;
  655. case SpvOpLoad:
  656. if (!CheckLoad(user, index)) ok = false;
  657. break;
  658. case SpvOpStore:
  659. if (!CheckStore(user, index)) ok = false;
  660. break;
  661. default:
  662. ok = false;
  663. break;
  664. }
  665. });
  666. return ok;
  667. }
  668. bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
  669. uint32_t index) const {
  670. if (index != 2u) return false;
  671. if (inst->NumInOperands() >= 2 &&
  672. inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
  673. return false;
  674. return true;
  675. }
  676. bool ScalarReplacementPass::CheckStore(const Instruction* inst,
  677. uint32_t index) const {
  678. if (index != 0u) return false;
  679. if (inst->NumInOperands() >= 3 &&
  680. inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
  681. return false;
  682. return true;
  683. }
  684. bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
  685. if (max_num_elements_ == 0) {
  686. return false;
  687. }
  688. return length > max_num_elements_;
  689. }
  690. std::unique_ptr<std::unordered_set<int64_t>>
  691. ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
  692. std::unique_ptr<std::unordered_set<int64_t>> result(
  693. new std::unordered_set<int64_t>());
  694. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  695. def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
  696. this](Instruction* use) {
  697. switch (use->opcode()) {
  698. case SpvOpLoad: {
  699. // Look for extract from the load.
  700. std::vector<uint32_t> t;
  701. if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
  702. if (use2->opcode() != SpvOpCompositeExtract) {
  703. return false;
  704. }
  705. t.push_back(use2->GetSingleWordInOperand(1));
  706. return true;
  707. })) {
  708. result->insert(t.begin(), t.end());
  709. return true;
  710. } else {
  711. result.reset(nullptr);
  712. return false;
  713. }
  714. }
  715. case SpvOpName:
  716. case SpvOpMemberName:
  717. case SpvOpStore:
  718. // No components are used.
  719. return true;
  720. case SpvOpAccessChain:
  721. case SpvOpInBoundsAccessChain: {
  722. // Add the first index it if is a constant.
  723. // TODO: Could be improved by checking if the address is used in a load.
  724. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  725. uint32_t index_id = use->GetSingleWordInOperand(1);
  726. const analysis::Constant* index_const =
  727. const_mgr->FindDeclaredConstant(index_id);
  728. if (index_const) {
  729. result->insert(index_const->GetSignExtendedValue());
  730. return true;
  731. } else {
  732. // Could be any element. Assuming all are used.
  733. result.reset(nullptr);
  734. return false;
  735. }
  736. }
  737. default:
  738. // We do not know what is happening. Have to assume the worst.
  739. result.reset(nullptr);
  740. return false;
  741. }
  742. });
  743. return result;
  744. }
  745. Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
  746. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  747. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  748. const analysis::Type* type = type_mgr->GetType(type_id);
  749. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  750. Instruction* null_inst =
  751. const_mgr->GetDefiningInstruction(null_const, type_id);
  752. if (null_inst != nullptr) {
  753. context()->UpdateDefUse(null_inst);
  754. }
  755. return null_inst;
  756. }
  757. } // namespace opt
  758. } // namespace spvtools