scalar_replacement_pass.cpp 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/scalar_replacement_pass.h"
  15. #include <algorithm>
  16. #include <queue>
  17. #include <tuple>
  18. #include <utility>
  19. #include "source/enum_string_mapping.h"
  20. #include "source/extensions.h"
  21. #include "source/opt/reflect.h"
  22. #include "source/opt/types.h"
  23. #include "source/util/make_unique.h"
  // Operand positions used when rewriting debug instructions in this file.
  // |kDebugValueOperandValueIndex| is the operand that is overwritten with a
  // replacement id when a DebugValue is cloned; the expression index is the
  // operand holding the id of the debug expression.
  static constexpr uint32_t kDebugValueOperandValueIndex = 5;
  static constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
  26. namespace spvtools {
  27. namespace opt {
  28. Pass::Status ScalarReplacementPass::Process() {
  29. Status status = Status::SuccessWithoutChange;
  30. for (auto& f : *get_module()) {
  31. Status functionStatus = ProcessFunction(&f);
  32. if (functionStatus == Status::Failure)
  33. return functionStatus;
  34. else if (functionStatus == Status::SuccessWithChange)
  35. status = functionStatus;
  36. }
  37. return status;
  38. }
  39. Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  40. std::queue<Instruction*> worklist;
  41. BasicBlock& entry = *function->begin();
  42. for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
  43. // Function storage class OpVariables must appear as the first instructions
  44. // of the entry block.
  45. if (iter->opcode() != SpvOpVariable) break;
  46. Instruction* varInst = &*iter;
  47. if (CanReplaceVariable(varInst)) {
  48. worklist.push(varInst);
  49. }
  50. }
  51. Status status = Status::SuccessWithoutChange;
  52. while (!worklist.empty()) {
  53. Instruction* varInst = worklist.front();
  54. worklist.pop();
  55. Status var_status = ReplaceVariable(varInst, &worklist);
  56. if (var_status == Status::Failure)
  57. return var_status;
  58. else if (var_status == Status::SuccessWithChange)
  59. status = var_status;
  60. }
  61. return status;
  62. }
  63. Pass::Status ScalarReplacementPass::ReplaceVariable(
  64. Instruction* inst, std::queue<Instruction*>* worklist) {
  65. std::vector<Instruction*> replacements;
  66. if (!CreateReplacementVariables(inst, &replacements)) {
  67. return Status::Failure;
  68. }
  69. std::vector<Instruction*> dead;
  70. bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
  71. inst, [this, &replacements, &dead](Instruction* user) {
  72. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
  73. if (ReplaceWholeDebugDeclare(user, replacements)) {
  74. dead.push_back(user);
  75. return true;
  76. }
  77. return false;
  78. }
  79. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
  80. if (ReplaceWholeDebugValue(user, replacements)) {
  81. dead.push_back(user);
  82. return true;
  83. }
  84. return false;
  85. }
  86. if (!IsAnnotationInst(user->opcode())) {
  87. switch (user->opcode()) {
  88. case SpvOpLoad:
  89. if (ReplaceWholeLoad(user, replacements)) {
  90. dead.push_back(user);
  91. } else {
  92. return false;
  93. }
  94. break;
  95. case SpvOpStore:
  96. if (ReplaceWholeStore(user, replacements)) {
  97. dead.push_back(user);
  98. } else {
  99. return false;
  100. }
  101. break;
  102. case SpvOpAccessChain:
  103. case SpvOpInBoundsAccessChain:
  104. if (ReplaceAccessChain(user, replacements))
  105. dead.push_back(user);
  106. else
  107. return false;
  108. break;
  109. case SpvOpName:
  110. case SpvOpMemberName:
  111. break;
  112. default:
  113. assert(false && "Unexpected opcode");
  114. break;
  115. }
  116. }
  117. return true;
  118. });
  119. if (replaced_all_uses) {
  120. dead.push_back(inst);
  121. } else {
  122. return Status::Failure;
  123. }
  124. // If there are no dead instructions to clean up, return with no changes.
  125. if (dead.empty()) return Status::SuccessWithoutChange;
  126. // Clean up some dead code.
  127. while (!dead.empty()) {
  128. Instruction* toKill = dead.back();
  129. dead.pop_back();
  130. context()->KillInst(toKill);
  131. }
  132. // Attempt to further scalarize.
  133. for (auto var : replacements) {
  134. if (var->opcode() == SpvOpVariable) {
  135. if (get_def_use_mgr()->NumUsers(var) == 0) {
  136. context()->KillInst(var);
  137. } else if (CanReplaceVariable(var)) {
  138. worklist->push(var);
  139. }
  140. }
  141. }
  142. return Status::SuccessWithChange;
  143. }
  144. bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
  145. Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
  146. // Insert Deref operation to the front of the operation list of |dbg_decl|.
  147. Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
  148. dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
  149. auto* deref_expr =
  150. context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
  151. // Add DebugValue instruction with Indexes operand and Deref operation.
  152. int32_t idx = 0;
  153. for (const auto* var : replacements) {
  154. Instruction* added_dbg_value =
  155. context()->get_debug_info_mgr()->AddDebugValueForDecl(
  156. dbg_decl, /*value_id=*/var->result_id(),
  157. /*insert_before=*/var->NextNode(), /*scope_and_line=*/dbg_decl);
  158. if (added_dbg_value == nullptr) return false;
  159. added_dbg_value->AddOperand(
  160. {SPV_OPERAND_TYPE_ID,
  161. {context()->get_constant_mgr()->GetSIntConst(idx)}});
  162. added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
  163. {deref_expr->result_id()});
  164. if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
  165. context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
  166. }
  167. ++idx;
  168. }
  169. return true;
  170. }
  171. bool ScalarReplacementPass::ReplaceWholeDebugValue(
  172. Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
  173. int32_t idx = 0;
  174. BasicBlock* block = context()->get_instr_block(dbg_value);
  175. for (auto var : replacements) {
  176. // Clone the DebugValue.
  177. std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
  178. uint32_t new_id = TakeNextId();
  179. if (new_id == 0) return false;
  180. new_dbg_value->SetResultId(new_id);
  181. // Update 'Value' operand to the |replacements|.
  182. new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
  183. // Append 'Indexes' operand.
  184. new_dbg_value->AddOperand(
  185. {SPV_OPERAND_TYPE_ID,
  186. {context()->get_constant_mgr()->GetSIntConst(idx)}});
  187. // Insert the new DebugValue to the basic block.
  188. auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
  189. get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
  190. context()->set_instr_block(added_instr, block);
  191. ++idx;
  192. }
  193. return true;
  194. }
  195. bool ScalarReplacementPass::ReplaceWholeLoad(
  196. Instruction* load, const std::vector<Instruction*>& replacements) {
  197. // Replaces the load of the entire composite with a load from each replacement
  198. // variable followed by a composite construction.
  199. BasicBlock* block = context()->get_instr_block(load);
  200. std::vector<Instruction*> loads;
  201. loads.reserve(replacements.size());
  202. BasicBlock::iterator where(load);
  203. for (auto var : replacements) {
  204. // Create a load of each replacement variable.
  205. if (var->opcode() != SpvOpVariable) {
  206. loads.push_back(var);
  207. continue;
  208. }
  209. Instruction* type = GetStorageType(var);
  210. uint32_t loadId = TakeNextId();
  211. if (loadId == 0) {
  212. return false;
  213. }
  214. std::unique_ptr<Instruction> newLoad(
  215. new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
  216. std::initializer_list<Operand>{
  217. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  218. // Copy memory access attributes which start at index 1. Index 0 is the
  219. // pointer to load.
  220. for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
  221. Operand copy(load->GetInOperand(i));
  222. newLoad->AddOperand(std::move(copy));
  223. }
  224. where = where.InsertBefore(std::move(newLoad));
  225. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  226. context()->set_instr_block(&*where, block);
  227. where->UpdateDebugInfoFrom(load);
  228. loads.push_back(&*where);
  229. }
  230. // Construct a new composite.
  231. uint32_t compositeId = TakeNextId();
  232. if (compositeId == 0) {
  233. return false;
  234. }
  235. where = load;
  236. std::unique_ptr<Instruction> compositeConstruct(new Instruction(
  237. context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
  238. for (auto l : loads) {
  239. Operand op(SPV_OPERAND_TYPE_ID,
  240. std::initializer_list<uint32_t>{l->result_id()});
  241. compositeConstruct->AddOperand(std::move(op));
  242. }
  243. where = where.InsertBefore(std::move(compositeConstruct));
  244. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  245. where->UpdateDebugInfoFrom(load);
  246. context()->set_instr_block(&*where, block);
  247. context()->ReplaceAllUsesWith(load->result_id(), compositeId);
  248. return true;
  249. }
  250. bool ScalarReplacementPass::ReplaceWholeStore(
  251. Instruction* store, const std::vector<Instruction*>& replacements) {
  252. // Replaces a store to the whole composite with a series of extract and stores
  253. // to each element.
  254. uint32_t storeInput = store->GetSingleWordInOperand(1u);
  255. BasicBlock* block = context()->get_instr_block(store);
  256. BasicBlock::iterator where(store);
  257. uint32_t elementIndex = 0;
  258. for (auto var : replacements) {
  259. // Create the extract.
  260. if (var->opcode() != SpvOpVariable) {
  261. elementIndex++;
  262. continue;
  263. }
  264. Instruction* type = GetStorageType(var);
  265. uint32_t extractId = TakeNextId();
  266. if (extractId == 0) {
  267. return false;
  268. }
  269. std::unique_ptr<Instruction> extract(new Instruction(
  270. context(), SpvOpCompositeExtract, type->result_id(), extractId,
  271. std::initializer_list<Operand>{
  272. {SPV_OPERAND_TYPE_ID, {storeInput}},
  273. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
  274. auto iter = where.InsertBefore(std::move(extract));
  275. iter->UpdateDebugInfoFrom(store);
  276. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  277. context()->set_instr_block(&*iter, block);
  278. // Create the store.
  279. std::unique_ptr<Instruction> newStore(
  280. new Instruction(context(), SpvOpStore, 0, 0,
  281. std::initializer_list<Operand>{
  282. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  283. {SPV_OPERAND_TYPE_ID, {extractId}}}));
  284. // Copy memory access attributes which start at index 2. Index 0 is the
  285. // pointer and index 1 is the data.
  286. for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
  287. Operand copy(store->GetInOperand(i));
  288. newStore->AddOperand(std::move(copy));
  289. }
  290. iter = where.InsertBefore(std::move(newStore));
  291. iter->UpdateDebugInfoFrom(store);
  292. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  293. context()->set_instr_block(&*iter, block);
  294. }
  295. return true;
  296. }
  297. bool ScalarReplacementPass::ReplaceAccessChain(
  298. Instruction* chain, const std::vector<Instruction*>& replacements) {
  299. // Replaces the access chain with either another access chain (with one fewer
  300. // indexes) or a direct use of the replacement variable.
  301. uint32_t indexId = chain->GetSingleWordInOperand(1u);
  302. const Instruction* index = get_def_use_mgr()->GetDef(indexId);
  303. int64_t indexValue = context()
  304. ->get_constant_mgr()
  305. ->GetConstantFromInst(index)
  306. ->GetSignExtendedValue();
  307. if (indexValue < 0 ||
  308. indexValue >= static_cast<int64_t>(replacements.size())) {
  309. // Out of bounds access, this is illegal IR. Notice that OpAccessChain
  310. // indexing is 0-based, so we should also reject index == size-of-array.
  311. return false;
  312. } else {
  313. const Instruction* var = replacements[static_cast<size_t>(indexValue)];
  314. if (chain->NumInOperands() > 2) {
  315. // Replace input access chain with another access chain.
  316. BasicBlock::iterator chainIter(chain);
  317. uint32_t replacementId = TakeNextId();
  318. if (replacementId == 0) {
  319. return false;
  320. }
  321. std::unique_ptr<Instruction> replacementChain(new Instruction(
  322. context(), chain->opcode(), chain->type_id(), replacementId,
  323. std::initializer_list<Operand>{
  324. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  325. // Add the remaining indexes.
  326. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
  327. Operand copy(chain->GetInOperand(i));
  328. replacementChain->AddOperand(std::move(copy));
  329. }
  330. replacementChain->UpdateDebugInfoFrom(chain);
  331. auto iter = chainIter.InsertBefore(std::move(replacementChain));
  332. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  333. context()->set_instr_block(&*iter, context()->get_instr_block(chain));
  334. context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  335. } else {
  336. // Replace with a use of the variable.
  337. context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  338. }
  339. }
  340. return true;
  341. }
  342. bool ScalarReplacementPass::CreateReplacementVariables(
  343. Instruction* inst, std::vector<Instruction*>* replacements) {
  344. Instruction* type = GetStorageType(inst);
  345. std::unique_ptr<std::unordered_set<int64_t>> components_used =
  346. GetUsedComponents(inst);
  347. uint32_t elem = 0;
  348. switch (type->opcode()) {
  349. case SpvOpTypeStruct:
  350. type->ForEachInOperand(
  351. [this, inst, &elem, replacements, &components_used](uint32_t* id) {
  352. if (!components_used || components_used->count(elem)) {
  353. CreateVariable(*id, inst, elem, replacements);
  354. } else {
  355. replacements->push_back(CreateNullConstant(*id));
  356. }
  357. elem++;
  358. });
  359. break;
  360. case SpvOpTypeArray:
  361. for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
  362. if (!components_used || components_used->count(i)) {
  363. CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
  364. replacements);
  365. } else {
  366. replacements->push_back(
  367. CreateNullConstant(type->GetSingleWordInOperand(0u)));
  368. }
  369. }
  370. break;
  371. case SpvOpTypeMatrix:
  372. case SpvOpTypeVector:
  373. for (uint32_t i = 0; i != GetNumElements(type); ++i) {
  374. CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
  375. }
  376. break;
  377. default:
  378. assert(false && "Unexpected type.");
  379. break;
  380. }
  381. TransferAnnotations(inst, replacements);
  382. return std::find(replacements->begin(), replacements->end(), nullptr) ==
  383. replacements->end();
  384. }
  385. void ScalarReplacementPass::TransferAnnotations(
  386. const Instruction* source, std::vector<Instruction*>* replacements) {
  387. // Only transfer invariant and restrict decorations on the variable. There are
  388. // no type or member decorations that are necessary to transfer.
  389. for (auto inst :
  390. get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
  391. assert(inst->opcode() == SpvOpDecorate);
  392. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  393. if (decoration == SpvDecorationInvariant ||
  394. decoration == SpvDecorationRestrict) {
  395. for (auto var : *replacements) {
  396. if (var == nullptr) {
  397. continue;
  398. }
  399. std::unique_ptr<Instruction> annotation(
  400. new Instruction(context(), SpvOpDecorate, 0, 0,
  401. std::initializer_list<Operand>{
  402. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  403. {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
  404. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  405. Operand copy(inst->GetInOperand(i));
  406. annotation->AddOperand(std::move(copy));
  407. }
  408. context()->AddAnnotationInst(std::move(annotation));
  409. get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
  410. }
  411. }
  412. }
  413. }
  414. void ScalarReplacementPass::CreateVariable(
  415. uint32_t typeId, Instruction* varInst, uint32_t index,
  416. std::vector<Instruction*>* replacements) {
  417. uint32_t ptrId = GetOrCreatePointerType(typeId);
  418. uint32_t id = TakeNextId();
  419. if (id == 0) {
  420. replacements->push_back(nullptr);
  421. }
  422. std::unique_ptr<Instruction> variable(new Instruction(
  423. context(), SpvOpVariable, ptrId, id,
  424. std::initializer_list<Operand>{
  425. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
  426. BasicBlock* block = context()->get_instr_block(varInst);
  427. block->begin().InsertBefore(std::move(variable));
  428. Instruction* inst = &*block->begin();
  429. // If varInst was initialized, make sure to initialize its replacement.
  430. GetOrCreateInitialValue(varInst, index, inst);
  431. get_def_use_mgr()->AnalyzeInstDefUse(inst);
  432. context()->set_instr_block(inst, block);
  433. // Copy decorations from the member to the new variable.
  434. Instruction* typeInst = GetStorageType(varInst);
  435. for (auto dec_inst :
  436. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  437. uint32_t decoration;
  438. if (dec_inst->opcode() != SpvOpMemberDecorate) {
  439. continue;
  440. }
  441. if (dec_inst->GetSingleWordInOperand(1) != index) {
  442. continue;
  443. }
  444. decoration = dec_inst->GetSingleWordInOperand(2u);
  445. switch (decoration) {
  446. case SpvDecorationRelaxedPrecision: {
  447. std::unique_ptr<Instruction> new_dec_inst(
  448. new Instruction(context(), SpvOpDecorate, 0, 0, {}));
  449. new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
  450. for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
  451. new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
  452. }
  453. context()->AddAnnotationInst(std::move(new_dec_inst));
  454. } break;
  455. default:
  456. break;
  457. }
  458. }
  459. // Update the DebugInfo debug information.
  460. inst->UpdateDebugInfoFrom(varInst);
  461. replacements->push_back(inst);
  462. }
  463. uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
  464. auto iter = pointee_to_pointer_.find(id);
  465. if (iter != pointee_to_pointer_.end()) return iter->second;
  466. analysis::Type* pointeeTy;
  467. std::unique_ptr<analysis::Pointer> pointerTy;
  468. std::tie(pointeeTy, pointerTy) =
  469. context()->get_type_mgr()->GetTypeAndPointerType(id,
  470. SpvStorageClassFunction);
  471. uint32_t ptrId = 0;
  472. if (pointeeTy->IsUniqueType()) {
  473. // Non-ambiguous type, just ask the type manager for an id.
  474. ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
  475. pointee_to_pointer_[id] = ptrId;
  476. return ptrId;
  477. }
  478. // Ambiguous type. We must perform a linear search to try and find the right
  479. // type.
  480. for (auto global : context()->types_values()) {
  481. if (global.opcode() == SpvOpTypePointer &&
  482. global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
  483. global.GetSingleWordInOperand(1u) == id) {
  484. if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
  485. // Only reuse a decoration-less pointer of the correct type.
  486. ptrId = global.result_id();
  487. break;
  488. }
  489. }
  490. }
  491. if (ptrId != 0) {
  492. pointee_to_pointer_[id] = ptrId;
  493. return ptrId;
  494. }
  495. ptrId = TakeNextId();
  496. context()->AddType(MakeUnique<Instruction>(
  497. context(), SpvOpTypePointer, 0, ptrId,
  498. std::initializer_list<Operand>{
  499. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  500. {SPV_OPERAND_TYPE_ID, {id}}}));
  501. Instruction* ptr = &*--context()->types_values_end();
  502. get_def_use_mgr()->AnalyzeInstDefUse(ptr);
  503. pointee_to_pointer_[id] = ptrId;
  504. // Register with the type manager if necessary.
  505. context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
  506. return ptrId;
  507. }
  508. void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
  509. uint32_t index,
  510. Instruction* newVar) {
  511. assert(source->opcode() == SpvOpVariable);
  512. if (source->NumInOperands() < 2) return;
  513. uint32_t initId = source->GetSingleWordInOperand(1u);
  514. uint32_t storageId = GetStorageType(newVar)->result_id();
  515. Instruction* init = get_def_use_mgr()->GetDef(initId);
  516. uint32_t newInitId = 0;
  517. // TODO(dnovillo): Refactor this with constant propagation.
  518. if (init->opcode() == SpvOpConstantNull) {
  519. // Initialize to appropriate NULL.
  520. auto iter = type_to_null_.find(storageId);
  521. if (iter == type_to_null_.end()) {
  522. newInitId = TakeNextId();
  523. type_to_null_[storageId] = newInitId;
  524. context()->AddGlobalValue(
  525. MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
  526. newInitId, std::initializer_list<Operand>{}));
  527. Instruction* newNull = &*--context()->types_values_end();
  528. get_def_use_mgr()->AnalyzeInstDefUse(newNull);
  529. } else {
  530. newInitId = iter->second;
  531. }
  532. } else if (IsSpecConstantInst(init->opcode())) {
  533. // Create a new constant extract.
  534. newInitId = TakeNextId();
  535. context()->AddGlobalValue(MakeUnique<Instruction>(
  536. context(), SpvOpSpecConstantOp, storageId, newInitId,
  537. std::initializer_list<Operand>{
  538. {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
  539. {SPV_OPERAND_TYPE_ID, {init->result_id()}},
  540. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
  541. Instruction* newSpecConst = &*--context()->types_values_end();
  542. get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
  543. } else if (init->opcode() == SpvOpConstantComposite) {
  544. // Get the appropriate index constant.
  545. newInitId = init->GetSingleWordInOperand(index);
  546. Instruction* element = get_def_use_mgr()->GetDef(newInitId);
  547. if (element->opcode() == SpvOpUndef) {
  548. // Undef is not a valid initializer for a variable.
  549. newInitId = 0;
  550. }
  551. } else {
  552. assert(false);
  553. }
  554. if (newInitId != 0) {
  555. newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
  556. }
  557. }
  558. uint64_t ScalarReplacementPass::GetArrayLength(
  559. const Instruction* arrayType) const {
  560. assert(arrayType->opcode() == SpvOpTypeArray);
  561. const Instruction* length =
  562. get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
  563. return context()
  564. ->get_constant_mgr()
  565. ->GetConstantFromInst(length)
  566. ->GetZeroExtendedValue();
  567. }
  568. uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
  569. assert(type->opcode() == SpvOpTypeVector ||
  570. type->opcode() == SpvOpTypeMatrix);
  571. const Operand& op = type->GetInOperand(1u);
  572. assert(op.words.size() <= 2);
  573. uint64_t len = 0;
  574. for (size_t i = 0; i != op.words.size(); ++i) {
  575. len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
  576. }
  577. return len;
  578. }
  579. bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
  580. const Instruction* inst = get_def_use_mgr()->GetDef(id);
  581. assert(inst);
  582. return spvOpcodeIsSpecConstant(inst->opcode());
  583. }
  584. Instruction* ScalarReplacementPass::GetStorageType(
  585. const Instruction* inst) const {
  586. assert(inst->opcode() == SpvOpVariable);
  587. uint32_t ptrTypeId = inst->type_id();
  588. uint32_t typeId =
  589. get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u);
  590. return get_def_use_mgr()->GetDef(typeId);
  591. }
  592. bool ScalarReplacementPass::CanReplaceVariable(
  593. const Instruction* varInst) const {
  594. assert(varInst->opcode() == SpvOpVariable);
  595. // Can only replace function scope variables.
  596. if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
  597. return false;
  598. }
  599. if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
  600. return false;
  601. }
  602. const Instruction* typeInst = GetStorageType(varInst);
  603. if (!CheckType(typeInst)) {
  604. return false;
  605. }
  606. if (!CheckAnnotations(varInst)) {
  607. return false;
  608. }
  609. if (!CheckUses(varInst)) {
  610. return false;
  611. }
  612. return true;
  613. }
  614. bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
  615. if (!CheckTypeAnnotations(typeInst)) {
  616. return false;
  617. }
  618. switch (typeInst->opcode()) {
  619. case SpvOpTypeStruct:
  620. // Don't bother with empty structs or very large structs.
  621. if (typeInst->NumInOperands() == 0 ||
  622. IsLargerThanSizeLimit(typeInst->NumInOperands())) {
  623. return false;
  624. }
  625. return true;
  626. case SpvOpTypeArray:
  627. if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
  628. return false;
  629. }
  630. if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
  631. return false;
  632. }
  633. return true;
  634. // TODO(alanbaker): Develop some heuristics for when this should be
  635. // re-enabled.
  636. //// Specifically including matrix and vector in an attempt to reduce the
  637. //// number of vector registers required.
  638. // case SpvOpTypeMatrix:
  639. // case SpvOpTypeVector:
  640. // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
  641. // return true;
  642. case SpvOpTypeRuntimeArray:
  643. default:
  644. return false;
  645. }
  646. }
  647. bool ScalarReplacementPass::CheckTypeAnnotations(
  648. const Instruction* typeInst) const {
  649. for (auto inst :
  650. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  651. uint32_t decoration;
  652. if (inst->opcode() == SpvOpDecorate) {
  653. decoration = inst->GetSingleWordInOperand(1u);
  654. } else {
  655. assert(inst->opcode() == SpvOpMemberDecorate);
  656. decoration = inst->GetSingleWordInOperand(2u);
  657. }
  658. switch (decoration) {
  659. case SpvDecorationRowMajor:
  660. case SpvDecorationColMajor:
  661. case SpvDecorationArrayStride:
  662. case SpvDecorationMatrixStride:
  663. case SpvDecorationCPacked:
  664. case SpvDecorationInvariant:
  665. case SpvDecorationRestrict:
  666. case SpvDecorationOffset:
  667. case SpvDecorationAlignment:
  668. case SpvDecorationAlignmentId:
  669. case SpvDecorationMaxByteOffset:
  670. case SpvDecorationRelaxedPrecision:
  671. break;
  672. default:
  673. return false;
  674. }
  675. }
  676. return true;
  677. }
  678. bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
  679. for (auto inst :
  680. get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
  681. assert(inst->opcode() == SpvOpDecorate);
  682. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  683. switch (decoration) {
  684. case SpvDecorationInvariant:
  685. case SpvDecorationRestrict:
  686. case SpvDecorationAlignment:
  687. case SpvDecorationAlignmentId:
  688. case SpvDecorationMaxByteOffset:
  689. break;
  690. default:
  691. return false;
  692. }
  693. }
  694. return true;
  695. }
  696. bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
  697. VariableStats stats = {0, 0};
  698. bool ok = CheckUses(inst, &stats);
  699. // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
  700. // SRoA is costly, such as when the structure has many (unaccessed?)
  701. // members.
  702. return ok;
  703. }
  704. bool ScalarReplacementPass::CheckUses(const Instruction* inst,
  705. VariableStats* stats) const {
  706. uint64_t max_legal_index = GetMaxLegalIndex(inst);
  707. bool ok = true;
  708. get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
  709. const Instruction* user,
  710. uint32_t index) {
  711. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
  712. user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
  713. // TODO: include num_partial_accesses if it uses Fragment operation or
  714. // DebugValue has Indexes operand.
  715. stats->num_full_accesses++;
  716. return;
  717. }
  718. // Annotations are check as a group separately.
  719. if (!IsAnnotationInst(user->opcode())) {
  720. switch (user->opcode()) {
  721. case SpvOpAccessChain:
  722. case SpvOpInBoundsAccessChain:
  723. if (index == 2u && user->NumInOperands() > 1) {
  724. uint32_t id = user->GetSingleWordInOperand(1u);
  725. const Instruction* opInst = get_def_use_mgr()->GetDef(id);
  726. const auto* constant =
  727. context()->get_constant_mgr()->GetConstantFromInst(opInst);
  728. if (!constant) {
  729. ok = false;
  730. } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
  731. ok = false;
  732. } else {
  733. if (!CheckUsesRelaxed(user)) ok = false;
  734. }
  735. stats->num_partial_accesses++;
  736. } else {
  737. ok = false;
  738. }
  739. break;
  740. case SpvOpLoad:
  741. if (!CheckLoad(user, index)) ok = false;
  742. stats->num_full_accesses++;
  743. break;
  744. case SpvOpStore:
  745. if (!CheckStore(user, index)) ok = false;
  746. stats->num_full_accesses++;
  747. break;
  748. case SpvOpName:
  749. case SpvOpMemberName:
  750. break;
  751. default:
  752. ok = false;
  753. break;
  754. }
  755. }
  756. });
  757. return ok;
  758. }
  759. bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
  760. bool ok = true;
  761. get_def_use_mgr()->ForEachUse(
  762. inst, [this, &ok](const Instruction* user, uint32_t index) {
  763. switch (user->opcode()) {
  764. case SpvOpAccessChain:
  765. case SpvOpInBoundsAccessChain:
  766. if (index != 2u) {
  767. ok = false;
  768. } else {
  769. if (!CheckUsesRelaxed(user)) ok = false;
  770. }
  771. break;
  772. case SpvOpLoad:
  773. if (!CheckLoad(user, index)) ok = false;
  774. break;
  775. case SpvOpStore:
  776. if (!CheckStore(user, index)) ok = false;
  777. break;
  778. case SpvOpImageTexelPointer:
  779. if (!CheckImageTexelPointer(index)) ok = false;
  780. break;
  781. default:
  782. ok = false;
  783. break;
  784. }
  785. });
  786. return ok;
  787. }
  788. bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
  789. return index == 2u;
  790. }
  791. bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
  792. uint32_t index) const {
  793. if (index != 2u) return false;
  794. if (inst->NumInOperands() >= 2 &&
  795. inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
  796. return false;
  797. return true;
  798. }
  799. bool ScalarReplacementPass::CheckStore(const Instruction* inst,
  800. uint32_t index) const {
  801. if (index != 0u) return false;
  802. if (inst->NumInOperands() >= 3 &&
  803. inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
  804. return false;
  805. return true;
  806. }
  807. bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
  808. if (max_num_elements_ == 0) {
  809. return false;
  810. }
  811. return length > max_num_elements_;
  812. }
  813. std::unique_ptr<std::unordered_set<int64_t>>
  814. ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
  815. std::unique_ptr<std::unordered_set<int64_t>> result(
  816. new std::unordered_set<int64_t>());
  817. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  818. def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
  819. this](Instruction* use) {
  820. switch (use->opcode()) {
  821. case SpvOpLoad: {
  822. // Look for extract from the load.
  823. std::vector<uint32_t> t;
  824. if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
  825. if (use2->opcode() != SpvOpCompositeExtract ||
  826. use2->NumInOperands() <= 1) {
  827. return false;
  828. }
  829. t.push_back(use2->GetSingleWordInOperand(1));
  830. return true;
  831. })) {
  832. result->insert(t.begin(), t.end());
  833. return true;
  834. } else {
  835. result.reset(nullptr);
  836. return false;
  837. }
  838. }
  839. case SpvOpName:
  840. case SpvOpMemberName:
  841. case SpvOpStore:
  842. // No components are used.
  843. return true;
  844. case SpvOpAccessChain:
  845. case SpvOpInBoundsAccessChain: {
  846. // Add the first index it if is a constant.
  847. // TODO: Could be improved by checking if the address is used in a load.
  848. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  849. uint32_t index_id = use->GetSingleWordInOperand(1);
  850. const analysis::Constant* index_const =
  851. const_mgr->FindDeclaredConstant(index_id);
  852. if (index_const) {
  853. result->insert(index_const->GetSignExtendedValue());
  854. return true;
  855. } else {
  856. // Could be any element. Assuming all are used.
  857. result.reset(nullptr);
  858. return false;
  859. }
  860. }
  861. default:
  862. // We do not know what is happening. Have to assume the worst.
  863. result.reset(nullptr);
  864. return false;
  865. }
  866. });
  867. return result;
  868. }
  869. Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
  870. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  871. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  872. const analysis::Type* type = type_mgr->GetType(type_id);
  873. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  874. Instruction* null_inst =
  875. const_mgr->GetDefiningInstruction(null_const, type_id);
  876. if (null_inst != nullptr) {
  877. context()->UpdateDefUse(null_inst);
  878. }
  879. return null_inst;
  880. }
  881. uint64_t ScalarReplacementPass::GetMaxLegalIndex(
  882. const Instruction* var_inst) const {
  883. assert(var_inst->opcode() == SpvOpVariable &&
  884. "|var_inst| must be a variable instruction.");
  885. Instruction* type = GetStorageType(var_inst);
  886. switch (type->opcode()) {
  887. case SpvOpTypeStruct:
  888. return type->NumInOperands();
  889. case SpvOpTypeArray:
  890. return GetArrayLength(type);
  891. case SpvOpTypeMatrix:
  892. case SpvOpTypeVector:
  893. return GetNumElements(type);
  894. default:
  895. return 0;
  896. }
  897. return 0;
  898. }
  899. } // namespace opt
  900. } // namespace spvtools