scalar_replacement_pass.cpp 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/scalar_replacement_pass.h"
  15. #include <algorithm>
  16. #include <queue>
  17. #include <tuple>
  18. #include <utility>
  19. #include "source/enum_string_mapping.h"
  20. #include "source/extensions.h"
  21. #include "source/opt/reflect.h"
  22. #include "source/opt/types.h"
  23. #include "source/util/make_unique.h"
  24. namespace spvtools {
  25. namespace opt {
  26. Pass::Status ScalarReplacementPass::Process() {
  27. Status status = Status::SuccessWithoutChange;
  28. for (auto& f : *get_module()) {
  29. Status functionStatus = ProcessFunction(&f);
  30. if (functionStatus == Status::Failure)
  31. return functionStatus;
  32. else if (functionStatus == Status::SuccessWithChange)
  33. status = functionStatus;
  34. }
  35. return status;
  36. }
  37. Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  38. std::queue<Instruction*> worklist;
  39. BasicBlock& entry = *function->begin();
  40. for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
  41. // Function storage class OpVariables must appear as the first instructions
  42. // of the entry block.
  43. if (iter->opcode() != SpvOpVariable) break;
  44. Instruction* varInst = &*iter;
  45. if (CanReplaceVariable(varInst)) {
  46. worklist.push(varInst);
  47. }
  48. }
  49. Status status = Status::SuccessWithoutChange;
  50. while (!worklist.empty()) {
  51. Instruction* varInst = worklist.front();
  52. worklist.pop();
  53. Status var_status = ReplaceVariable(varInst, &worklist);
  54. if (var_status == Status::Failure)
  55. return var_status;
  56. else if (var_status == Status::SuccessWithChange)
  57. status = var_status;
  58. }
  59. return status;
  60. }
  61. Pass::Status ScalarReplacementPass::ReplaceVariable(
  62. Instruction* inst, std::queue<Instruction*>* worklist) {
  63. std::vector<Instruction*> replacements;
  64. if (!CreateReplacementVariables(inst, &replacements)) {
  65. return Status::Failure;
  66. }
  67. std::vector<Instruction*> dead;
  68. bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
  69. inst, [this, &replacements, &dead](Instruction* user) {
  70. if (!IsAnnotationInst(user->opcode())) {
  71. switch (user->opcode()) {
  72. case SpvOpLoad:
  73. if (ReplaceWholeLoad(user, replacements)) {
  74. dead.push_back(user);
  75. } else {
  76. return false;
  77. }
  78. break;
  79. case SpvOpStore:
  80. if (ReplaceWholeStore(user, replacements)) {
  81. dead.push_back(user);
  82. } else {
  83. return false;
  84. }
  85. break;
  86. case SpvOpAccessChain:
  87. case SpvOpInBoundsAccessChain:
  88. if (ReplaceAccessChain(user, replacements))
  89. dead.push_back(user);
  90. else
  91. return false;
  92. break;
  93. case SpvOpName:
  94. case SpvOpMemberName:
  95. break;
  96. default:
  97. assert(false && "Unexpected opcode");
  98. break;
  99. }
  100. }
  101. return true;
  102. });
  103. if (replaced_all_uses) {
  104. dead.push_back(inst);
  105. } else {
  106. return Status::Failure;
  107. }
  108. // If there are no dead instructions to clean up, return with no changes.
  109. if (dead.empty()) return Status::SuccessWithoutChange;
  110. // Clean up some dead code.
  111. while (!dead.empty()) {
  112. Instruction* toKill = dead.back();
  113. dead.pop_back();
  114. context()->KillInst(toKill);
  115. }
  116. // Attempt to further scalarize.
  117. for (auto var : replacements) {
  118. if (var->opcode() == SpvOpVariable) {
  119. if (get_def_use_mgr()->NumUsers(var) == 0) {
  120. context()->KillInst(var);
  121. } else if (CanReplaceVariable(var)) {
  122. worklist->push(var);
  123. }
  124. }
  125. }
  126. return Status::SuccessWithChange;
  127. }
  128. bool ScalarReplacementPass::ReplaceWholeLoad(
  129. Instruction* load, const std::vector<Instruction*>& replacements) {
  130. // Replaces the load of the entire composite with a load from each replacement
  131. // variable followed by a composite construction.
  132. BasicBlock* block = context()->get_instr_block(load);
  133. std::vector<Instruction*> loads;
  134. loads.reserve(replacements.size());
  135. BasicBlock::iterator where(load);
  136. for (auto var : replacements) {
  137. // Create a load of each replacement variable.
  138. if (var->opcode() != SpvOpVariable) {
  139. loads.push_back(var);
  140. continue;
  141. }
  142. Instruction* type = GetStorageType(var);
  143. uint32_t loadId = TakeNextId();
  144. if (loadId == 0) {
  145. return false;
  146. }
  147. std::unique_ptr<Instruction> newLoad(
  148. new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
  149. std::initializer_list<Operand>{
  150. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  151. // Copy memory access attributes which start at index 1. Index 0 is the
  152. // pointer to load.
  153. for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
  154. Operand copy(load->GetInOperand(i));
  155. newLoad->AddOperand(std::move(copy));
  156. }
  157. where = where.InsertBefore(std::move(newLoad));
  158. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  159. context()->set_instr_block(&*where, block);
  160. loads.push_back(&*where);
  161. }
  162. // Construct a new composite.
  163. uint32_t compositeId = TakeNextId();
  164. if (compositeId == 0) {
  165. return false;
  166. }
  167. where = load;
  168. std::unique_ptr<Instruction> compositeConstruct(new Instruction(
  169. context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
  170. for (auto l : loads) {
  171. Operand op(SPV_OPERAND_TYPE_ID,
  172. std::initializer_list<uint32_t>{l->result_id()});
  173. compositeConstruct->AddOperand(std::move(op));
  174. }
  175. where = where.InsertBefore(std::move(compositeConstruct));
  176. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  177. context()->set_instr_block(&*where, block);
  178. context()->ReplaceAllUsesWith(load->result_id(), compositeId);
  179. return true;
  180. }
  181. bool ScalarReplacementPass::ReplaceWholeStore(
  182. Instruction* store, const std::vector<Instruction*>& replacements) {
  183. // Replaces a store to the whole composite with a series of extract and stores
  184. // to each element.
  185. uint32_t storeInput = store->GetSingleWordInOperand(1u);
  186. BasicBlock* block = context()->get_instr_block(store);
  187. BasicBlock::iterator where(store);
  188. uint32_t elementIndex = 0;
  189. for (auto var : replacements) {
  190. // Create the extract.
  191. if (var->opcode() != SpvOpVariable) {
  192. elementIndex++;
  193. continue;
  194. }
  195. Instruction* type = GetStorageType(var);
  196. uint32_t extractId = TakeNextId();
  197. if (extractId == 0) {
  198. return false;
  199. }
  200. std::unique_ptr<Instruction> extract(new Instruction(
  201. context(), SpvOpCompositeExtract, type->result_id(), extractId,
  202. std::initializer_list<Operand>{
  203. {SPV_OPERAND_TYPE_ID, {storeInput}},
  204. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
  205. auto iter = where.InsertBefore(std::move(extract));
  206. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  207. context()->set_instr_block(&*iter, block);
  208. // Create the store.
  209. std::unique_ptr<Instruction> newStore(
  210. new Instruction(context(), SpvOpStore, 0, 0,
  211. std::initializer_list<Operand>{
  212. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  213. {SPV_OPERAND_TYPE_ID, {extractId}}}));
  214. // Copy memory access attributes which start at index 2. Index 0 is the
  215. // pointer and index 1 is the data.
  216. for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
  217. Operand copy(store->GetInOperand(i));
  218. newStore->AddOperand(std::move(copy));
  219. }
  220. iter = where.InsertBefore(std::move(newStore));
  221. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  222. context()->set_instr_block(&*iter, block);
  223. }
  224. return true;
  225. }
  226. bool ScalarReplacementPass::ReplaceAccessChain(
  227. Instruction* chain, const std::vector<Instruction*>& replacements) {
  228. // Replaces the access chain with either another access chain (with one fewer
  229. // indexes) or a direct use of the replacement variable.
  230. uint32_t indexId = chain->GetSingleWordInOperand(1u);
  231. const Instruction* index = get_def_use_mgr()->GetDef(indexId);
  232. int64_t indexValue = context()
  233. ->get_constant_mgr()
  234. ->GetConstantFromInst(index)
  235. ->GetSignExtendedValue();
  236. if (indexValue < 0 ||
  237. indexValue >= static_cast<int64_t>(replacements.size())) {
  238. // Out of bounds access, this is illegal IR. Notice that OpAccessChain
  239. // indexing is 0-based, so we should also reject index == size-of-array.
  240. return false;
  241. } else {
  242. const Instruction* var = replacements[static_cast<size_t>(indexValue)];
  243. if (chain->NumInOperands() > 2) {
  244. // Replace input access chain with another access chain.
  245. BasicBlock::iterator chainIter(chain);
  246. uint32_t replacementId = TakeNextId();
  247. if (replacementId == 0) {
  248. return false;
  249. }
  250. std::unique_ptr<Instruction> replacementChain(new Instruction(
  251. context(), chain->opcode(), chain->type_id(), replacementId,
  252. std::initializer_list<Operand>{
  253. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  254. // Add the remaining indexes.
  255. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
  256. Operand copy(chain->GetInOperand(i));
  257. replacementChain->AddOperand(std::move(copy));
  258. }
  259. auto iter = chainIter.InsertBefore(std::move(replacementChain));
  260. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  261. context()->set_instr_block(&*iter, context()->get_instr_block(chain));
  262. context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  263. } else {
  264. // Replace with a use of the variable.
  265. context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  266. }
  267. }
  268. return true;
  269. }
  270. bool ScalarReplacementPass::CreateReplacementVariables(
  271. Instruction* inst, std::vector<Instruction*>* replacements) {
  272. Instruction* type = GetStorageType(inst);
  273. std::unique_ptr<std::unordered_set<int64_t>> components_used =
  274. GetUsedComponents(inst);
  275. uint32_t elem = 0;
  276. switch (type->opcode()) {
  277. case SpvOpTypeStruct:
  278. type->ForEachInOperand(
  279. [this, inst, &elem, replacements, &components_used](uint32_t* id) {
  280. if (!components_used || components_used->count(elem)) {
  281. CreateVariable(*id, inst, elem, replacements);
  282. } else {
  283. replacements->push_back(CreateNullConstant(*id));
  284. }
  285. elem++;
  286. });
  287. break;
  288. case SpvOpTypeArray:
  289. for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
  290. if (!components_used || components_used->count(i)) {
  291. CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
  292. replacements);
  293. } else {
  294. replacements->push_back(
  295. CreateNullConstant(type->GetSingleWordInOperand(0u)));
  296. }
  297. }
  298. break;
  299. case SpvOpTypeMatrix:
  300. case SpvOpTypeVector:
  301. for (uint32_t i = 0; i != GetNumElements(type); ++i) {
  302. CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
  303. }
  304. break;
  305. default:
  306. assert(false && "Unexpected type.");
  307. break;
  308. }
  309. TransferAnnotations(inst, replacements);
  310. return std::find(replacements->begin(), replacements->end(), nullptr) ==
  311. replacements->end();
  312. }
  313. void ScalarReplacementPass::TransferAnnotations(
  314. const Instruction* source, std::vector<Instruction*>* replacements) {
  315. // Only transfer invariant and restrict decorations on the variable. There are
  316. // no type or member decorations that are necessary to transfer.
  317. for (auto inst :
  318. get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
  319. assert(inst->opcode() == SpvOpDecorate);
  320. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  321. if (decoration == SpvDecorationInvariant ||
  322. decoration == SpvDecorationRestrict) {
  323. for (auto var : *replacements) {
  324. if (var == nullptr) {
  325. continue;
  326. }
  327. std::unique_ptr<Instruction> annotation(
  328. new Instruction(context(), SpvOpDecorate, 0, 0,
  329. std::initializer_list<Operand>{
  330. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  331. {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
  332. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  333. Operand copy(inst->GetInOperand(i));
  334. annotation->AddOperand(std::move(copy));
  335. }
  336. context()->AddAnnotationInst(std::move(annotation));
  337. get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
  338. }
  339. }
  340. }
  341. }
  342. void ScalarReplacementPass::CreateVariable(
  343. uint32_t typeId, Instruction* varInst, uint32_t index,
  344. std::vector<Instruction*>* replacements) {
  345. uint32_t ptrId = GetOrCreatePointerType(typeId);
  346. uint32_t id = TakeNextId();
  347. if (id == 0) {
  348. replacements->push_back(nullptr);
  349. }
  350. std::unique_ptr<Instruction> variable(new Instruction(
  351. context(), SpvOpVariable, ptrId, id,
  352. std::initializer_list<Operand>{
  353. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
  354. BasicBlock* block = context()->get_instr_block(varInst);
  355. block->begin().InsertBefore(std::move(variable));
  356. Instruction* inst = &*block->begin();
  357. // If varInst was initialized, make sure to initialize its replacement.
  358. GetOrCreateInitialValue(varInst, index, inst);
  359. get_def_use_mgr()->AnalyzeInstDefUse(inst);
  360. context()->set_instr_block(inst, block);
  361. // Copy decorations from the member to the new variable.
  362. Instruction* typeInst = GetStorageType(varInst);
  363. for (auto dec_inst :
  364. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  365. uint32_t decoration;
  366. if (dec_inst->opcode() != SpvOpMemberDecorate) {
  367. continue;
  368. }
  369. if (dec_inst->GetSingleWordInOperand(1) != index) {
  370. continue;
  371. }
  372. decoration = dec_inst->GetSingleWordInOperand(2u);
  373. switch (decoration) {
  374. case SpvDecorationRelaxedPrecision: {
  375. std::unique_ptr<Instruction> new_dec_inst(
  376. new Instruction(context(), SpvOpDecorate, 0, 0, {}));
  377. new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
  378. for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
  379. new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
  380. }
  381. context()->AddAnnotationInst(std::move(new_dec_inst));
  382. } break;
  383. default:
  384. break;
  385. }
  386. }
  387. replacements->push_back(inst);
  388. }
  389. uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
  390. auto iter = pointee_to_pointer_.find(id);
  391. if (iter != pointee_to_pointer_.end()) return iter->second;
  392. analysis::Type* pointeeTy;
  393. std::unique_ptr<analysis::Pointer> pointerTy;
  394. std::tie(pointeeTy, pointerTy) =
  395. context()->get_type_mgr()->GetTypeAndPointerType(id,
  396. SpvStorageClassFunction);
  397. uint32_t ptrId = 0;
  398. if (pointeeTy->IsUniqueType()) {
  399. // Non-ambiguous type, just ask the type manager for an id.
  400. ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
  401. pointee_to_pointer_[id] = ptrId;
  402. return ptrId;
  403. }
  404. // Ambiguous type. We must perform a linear search to try and find the right
  405. // type.
  406. for (auto global : context()->types_values()) {
  407. if (global.opcode() == SpvOpTypePointer &&
  408. global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
  409. global.GetSingleWordInOperand(1u) == id) {
  410. if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
  411. // Only reuse a decoration-less pointer of the correct type.
  412. ptrId = global.result_id();
  413. break;
  414. }
  415. }
  416. }
  417. if (ptrId != 0) {
  418. pointee_to_pointer_[id] = ptrId;
  419. return ptrId;
  420. }
  421. ptrId = TakeNextId();
  422. context()->AddType(MakeUnique<Instruction>(
  423. context(), SpvOpTypePointer, 0, ptrId,
  424. std::initializer_list<Operand>{
  425. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  426. {SPV_OPERAND_TYPE_ID, {id}}}));
  427. Instruction* ptr = &*--context()->types_values_end();
  428. get_def_use_mgr()->AnalyzeInstDefUse(ptr);
  429. pointee_to_pointer_[id] = ptrId;
  430. // Register with the type manager if necessary.
  431. context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
  432. return ptrId;
  433. }
  434. void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
  435. uint32_t index,
  436. Instruction* newVar) {
  437. assert(source->opcode() == SpvOpVariable);
  438. if (source->NumInOperands() < 2) return;
  439. uint32_t initId = source->GetSingleWordInOperand(1u);
  440. uint32_t storageId = GetStorageType(newVar)->result_id();
  441. Instruction* init = get_def_use_mgr()->GetDef(initId);
  442. uint32_t newInitId = 0;
  443. // TODO(dnovillo): Refactor this with constant propagation.
  444. if (init->opcode() == SpvOpConstantNull) {
  445. // Initialize to appropriate NULL.
  446. auto iter = type_to_null_.find(storageId);
  447. if (iter == type_to_null_.end()) {
  448. newInitId = TakeNextId();
  449. type_to_null_[storageId] = newInitId;
  450. context()->AddGlobalValue(
  451. MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
  452. newInitId, std::initializer_list<Operand>{}));
  453. Instruction* newNull = &*--context()->types_values_end();
  454. get_def_use_mgr()->AnalyzeInstDefUse(newNull);
  455. } else {
  456. newInitId = iter->second;
  457. }
  458. } else if (IsSpecConstantInst(init->opcode())) {
  459. // Create a new constant extract.
  460. newInitId = TakeNextId();
  461. context()->AddGlobalValue(MakeUnique<Instruction>(
  462. context(), SpvOpSpecConstantOp, storageId, newInitId,
  463. std::initializer_list<Operand>{
  464. {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
  465. {SPV_OPERAND_TYPE_ID, {init->result_id()}},
  466. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
  467. Instruction* newSpecConst = &*--context()->types_values_end();
  468. get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
  469. } else if (init->opcode() == SpvOpConstantComposite) {
  470. // Get the appropriate index constant.
  471. newInitId = init->GetSingleWordInOperand(index);
  472. Instruction* element = get_def_use_mgr()->GetDef(newInitId);
  473. if (element->opcode() == SpvOpUndef) {
  474. // Undef is not a valid initializer for a variable.
  475. newInitId = 0;
  476. }
  477. } else {
  478. assert(false);
  479. }
  480. if (newInitId != 0) {
  481. newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
  482. }
  483. }
  484. uint64_t ScalarReplacementPass::GetArrayLength(
  485. const Instruction* arrayType) const {
  486. assert(arrayType->opcode() == SpvOpTypeArray);
  487. const Instruction* length =
  488. get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
  489. return context()
  490. ->get_constant_mgr()
  491. ->GetConstantFromInst(length)
  492. ->GetZeroExtendedValue();
  493. }
  494. uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
  495. assert(type->opcode() == SpvOpTypeVector ||
  496. type->opcode() == SpvOpTypeMatrix);
  497. const Operand& op = type->GetInOperand(1u);
  498. assert(op.words.size() <= 2);
  499. uint64_t len = 0;
  500. for (size_t i = 0; i != op.words.size(); ++i) {
  501. len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
  502. }
  503. return len;
  504. }
  505. bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
  506. const Instruction* inst = get_def_use_mgr()->GetDef(id);
  507. assert(inst);
  508. return spvOpcodeIsSpecConstant(inst->opcode());
  509. }
  510. Instruction* ScalarReplacementPass::GetStorageType(
  511. const Instruction* inst) const {
  512. assert(inst->opcode() == SpvOpVariable);
  513. uint32_t ptrTypeId = inst->type_id();
  514. uint32_t typeId =
  515. get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u);
  516. return get_def_use_mgr()->GetDef(typeId);
  517. }
  518. bool ScalarReplacementPass::CanReplaceVariable(
  519. const Instruction* varInst) const {
  520. assert(varInst->opcode() == SpvOpVariable);
  521. // Can only replace function scope variables.
  522. if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
  523. return false;
  524. }
  525. if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
  526. return false;
  527. }
  528. const Instruction* typeInst = GetStorageType(varInst);
  529. if (!CheckType(typeInst)) {
  530. return false;
  531. }
  532. if (!CheckAnnotations(varInst)) {
  533. return false;
  534. }
  535. if (!CheckUses(varInst)) {
  536. return false;
  537. }
  538. return true;
  539. }
  540. bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
  541. if (!CheckTypeAnnotations(typeInst)) {
  542. return false;
  543. }
  544. switch (typeInst->opcode()) {
  545. case SpvOpTypeStruct:
  546. // Don't bother with empty structs or very large structs.
  547. if (typeInst->NumInOperands() == 0 ||
  548. IsLargerThanSizeLimit(typeInst->NumInOperands())) {
  549. return false;
  550. }
  551. return true;
  552. case SpvOpTypeArray:
  553. if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
  554. return false;
  555. }
  556. if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
  557. return false;
  558. }
  559. return true;
  560. // TODO(alanbaker): Develop some heuristics for when this should be
  561. // re-enabled.
  562. //// Specifically including matrix and vector in an attempt to reduce the
  563. //// number of vector registers required.
  564. // case SpvOpTypeMatrix:
  565. // case SpvOpTypeVector:
  566. // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
  567. // return true;
  568. case SpvOpTypeRuntimeArray:
  569. default:
  570. return false;
  571. }
  572. }
  573. bool ScalarReplacementPass::CheckTypeAnnotations(
  574. const Instruction* typeInst) const {
  575. for (auto inst :
  576. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  577. uint32_t decoration;
  578. if (inst->opcode() == SpvOpDecorate) {
  579. decoration = inst->GetSingleWordInOperand(1u);
  580. } else {
  581. assert(inst->opcode() == SpvOpMemberDecorate);
  582. decoration = inst->GetSingleWordInOperand(2u);
  583. }
  584. switch (decoration) {
  585. case SpvDecorationRowMajor:
  586. case SpvDecorationColMajor:
  587. case SpvDecorationArrayStride:
  588. case SpvDecorationMatrixStride:
  589. case SpvDecorationCPacked:
  590. case SpvDecorationInvariant:
  591. case SpvDecorationRestrict:
  592. case SpvDecorationOffset:
  593. case SpvDecorationAlignment:
  594. case SpvDecorationAlignmentId:
  595. case SpvDecorationMaxByteOffset:
  596. case SpvDecorationRelaxedPrecision:
  597. break;
  598. default:
  599. return false;
  600. }
  601. }
  602. return true;
  603. }
  604. bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
  605. for (auto inst :
  606. get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
  607. assert(inst->opcode() == SpvOpDecorate);
  608. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  609. switch (decoration) {
  610. case SpvDecorationInvariant:
  611. case SpvDecorationRestrict:
  612. case SpvDecorationAlignment:
  613. case SpvDecorationAlignmentId:
  614. case SpvDecorationMaxByteOffset:
  615. break;
  616. default:
  617. return false;
  618. }
  619. }
  620. return true;
  621. }
  622. bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
  623. VariableStats stats = {0, 0};
  624. bool ok = CheckUses(inst, &stats);
  625. // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
  626. // SRoA is costly, such as when the structure has many (unaccessed?)
  627. // members.
  628. return ok;
  629. }
  630. bool ScalarReplacementPass::CheckUses(const Instruction* inst,
  631. VariableStats* stats) const {
  632. uint64_t max_legal_index = GetMaxLegalIndex(inst);
  633. bool ok = true;
  634. get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
  635. const Instruction* user,
  636. uint32_t index) {
  637. // Annotations are check as a group separately.
  638. if (!IsAnnotationInst(user->opcode())) {
  639. switch (user->opcode()) {
  640. case SpvOpAccessChain:
  641. case SpvOpInBoundsAccessChain:
  642. if (index == 2u && user->NumInOperands() > 1) {
  643. uint32_t id = user->GetSingleWordInOperand(1u);
  644. const Instruction* opInst = get_def_use_mgr()->GetDef(id);
  645. const auto* constant =
  646. context()->get_constant_mgr()->GetConstantFromInst(opInst);
  647. if (!constant) {
  648. ok = false;
  649. } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
  650. ok = false;
  651. } else {
  652. if (!CheckUsesRelaxed(user)) ok = false;
  653. }
  654. stats->num_partial_accesses++;
  655. } else {
  656. ok = false;
  657. }
  658. break;
  659. case SpvOpLoad:
  660. if (!CheckLoad(user, index)) ok = false;
  661. stats->num_full_accesses++;
  662. break;
  663. case SpvOpStore:
  664. if (!CheckStore(user, index)) ok = false;
  665. stats->num_full_accesses++;
  666. break;
  667. case SpvOpName:
  668. case SpvOpMemberName:
  669. break;
  670. default:
  671. ok = false;
  672. break;
  673. }
  674. }
  675. });
  676. return ok;
  677. }
  678. bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
  679. bool ok = true;
  680. get_def_use_mgr()->ForEachUse(
  681. inst, [this, &ok](const Instruction* user, uint32_t index) {
  682. switch (user->opcode()) {
  683. case SpvOpAccessChain:
  684. case SpvOpInBoundsAccessChain:
  685. if (index != 2u) {
  686. ok = false;
  687. } else {
  688. if (!CheckUsesRelaxed(user)) ok = false;
  689. }
  690. break;
  691. case SpvOpLoad:
  692. if (!CheckLoad(user, index)) ok = false;
  693. break;
  694. case SpvOpStore:
  695. if (!CheckStore(user, index)) ok = false;
  696. break;
  697. default:
  698. ok = false;
  699. break;
  700. }
  701. });
  702. return ok;
  703. }
  704. bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
  705. uint32_t index) const {
  706. if (index != 2u) return false;
  707. if (inst->NumInOperands() >= 2 &&
  708. inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
  709. return false;
  710. return true;
  711. }
  712. bool ScalarReplacementPass::CheckStore(const Instruction* inst,
  713. uint32_t index) const {
  714. if (index != 0u) return false;
  715. if (inst->NumInOperands() >= 3 &&
  716. inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
  717. return false;
  718. return true;
  719. }
  720. bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
  721. if (max_num_elements_ == 0) {
  722. return false;
  723. }
  724. return length > max_num_elements_;
  725. }
  726. std::unique_ptr<std::unordered_set<int64_t>>
  727. ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
  728. std::unique_ptr<std::unordered_set<int64_t>> result(
  729. new std::unordered_set<int64_t>());
  730. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  731. def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
  732. this](Instruction* use) {
  733. switch (use->opcode()) {
  734. case SpvOpLoad: {
  735. // Look for extract from the load.
  736. std::vector<uint32_t> t;
  737. if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
  738. if (use2->opcode() != SpvOpCompositeExtract ||
  739. use2->NumInOperands() <= 1) {
  740. return false;
  741. }
  742. t.push_back(use2->GetSingleWordInOperand(1));
  743. return true;
  744. })) {
  745. result->insert(t.begin(), t.end());
  746. return true;
  747. } else {
  748. result.reset(nullptr);
  749. return false;
  750. }
  751. }
  752. case SpvOpName:
  753. case SpvOpMemberName:
  754. case SpvOpStore:
  755. // No components are used.
  756. return true;
  757. case SpvOpAccessChain:
  758. case SpvOpInBoundsAccessChain: {
  759. // Add the first index it if is a constant.
  760. // TODO: Could be improved by checking if the address is used in a load.
  761. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  762. uint32_t index_id = use->GetSingleWordInOperand(1);
  763. const analysis::Constant* index_const =
  764. const_mgr->FindDeclaredConstant(index_id);
  765. if (index_const) {
  766. result->insert(index_const->GetSignExtendedValue());
  767. return true;
  768. } else {
  769. // Could be any element. Assuming all are used.
  770. result.reset(nullptr);
  771. return false;
  772. }
  773. }
  774. default:
  775. // We do not know what is happening. Have to assume the worst.
  776. result.reset(nullptr);
  777. return false;
  778. }
  779. });
  780. return result;
  781. }
  782. Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
  783. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  784. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  785. const analysis::Type* type = type_mgr->GetType(type_id);
  786. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  787. Instruction* null_inst =
  788. const_mgr->GetDefiningInstruction(null_const, type_id);
  789. if (null_inst != nullptr) {
  790. context()->UpdateDefUse(null_inst);
  791. }
  792. return null_inst;
  793. }
  794. uint64_t ScalarReplacementPass::GetMaxLegalIndex(
  795. const Instruction* var_inst) const {
  796. assert(var_inst->opcode() == SpvOpVariable &&
  797. "|var_inst| must be a variable instruction.");
  798. Instruction* type = GetStorageType(var_inst);
  799. switch (type->opcode()) {
  800. case SpvOpTypeStruct:
  801. return type->NumInOperands();
  802. case SpvOpTypeArray:
  803. return GetArrayLength(type);
  804. case SpvOpTypeMatrix:
  805. case SpvOpTypeVector:
  806. return GetNumElements(type);
  807. default:
  808. return 0;
  809. }
  810. return 0;
  811. }
  812. } // namespace opt
  813. } // namespace spvtools