scalar_replacement_pass.cpp 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/scalar_replacement_pass.h"
  15. #include <algorithm>
  16. #include <queue>
  17. #include <tuple>
  18. #include <utility>
  19. #include "source/enum_string_mapping.h"
  20. #include "source/extensions.h"
  21. #include "source/opt/reflect.h"
  22. #include "source/opt/types.h"
  23. #include "source/util/make_unique.h"
  // Absolute operand positions inside debug-info extended instructions, as
  // read with Instruction::GetSingleWordOperand (operands 0-3 are the result
  // type, result id, extended instruction set and instruction number).
  //
  // DebugValue: the 'Value' operand.
  static constexpr uint32_t kDebugValueOperandValueIndex = 5;
  // DebugValue: the 'Expression' operand. ReplaceWholeDebugDeclare also reads
  // a DebugDeclare's expression through this index.
  static constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
  // DebugDeclare: the 'Variable' operand.
  static constexpr uint32_t kDebugDeclareOperandVariableIndex = 5;
  27. namespace spvtools {
  28. namespace opt {
  29. Pass::Status ScalarReplacementPass::Process() {
  30. Status status = Status::SuccessWithoutChange;
  31. for (auto& f : *get_module()) {
  32. if (f.IsDeclaration()) {
  33. continue;
  34. }
  35. Status functionStatus = ProcessFunction(&f);
  36. if (functionStatus == Status::Failure)
  37. return functionStatus;
  38. else if (functionStatus == Status::SuccessWithChange)
  39. status = functionStatus;
  40. }
  41. return status;
  42. }
  43. Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  44. std::queue<Instruction*> worklist;
  45. BasicBlock& entry = *function->begin();
  46. for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
  47. // Function storage class OpVariables must appear as the first instructions
  48. // of the entry block.
  49. if (iter->opcode() != SpvOpVariable) break;
  50. Instruction* varInst = &*iter;
  51. if (CanReplaceVariable(varInst)) {
  52. worklist.push(varInst);
  53. }
  54. }
  55. Status status = Status::SuccessWithoutChange;
  56. while (!worklist.empty()) {
  57. Instruction* varInst = worklist.front();
  58. worklist.pop();
  59. Status var_status = ReplaceVariable(varInst, &worklist);
  60. if (var_status == Status::Failure)
  61. return var_status;
  62. else if (var_status == Status::SuccessWithChange)
  63. status = var_status;
  64. }
  65. return status;
  66. }
  67. Pass::Status ScalarReplacementPass::ReplaceVariable(
  68. Instruction* inst, std::queue<Instruction*>* worklist) {
  69. std::vector<Instruction*> replacements;
  70. if (!CreateReplacementVariables(inst, &replacements)) {
  71. return Status::Failure;
  72. }
  73. std::vector<Instruction*> dead;
  74. bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
  75. inst, [this, &replacements, &dead](Instruction* user) {
  76. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
  77. if (ReplaceWholeDebugDeclare(user, replacements)) {
  78. dead.push_back(user);
  79. return true;
  80. }
  81. return false;
  82. }
  83. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
  84. if (ReplaceWholeDebugValue(user, replacements)) {
  85. dead.push_back(user);
  86. return true;
  87. }
  88. return false;
  89. }
  90. if (!IsAnnotationInst(user->opcode())) {
  91. switch (user->opcode()) {
  92. case SpvOpLoad:
  93. if (ReplaceWholeLoad(user, replacements)) {
  94. dead.push_back(user);
  95. } else {
  96. return false;
  97. }
  98. break;
  99. case SpvOpStore:
  100. if (ReplaceWholeStore(user, replacements)) {
  101. dead.push_back(user);
  102. } else {
  103. return false;
  104. }
  105. break;
  106. case SpvOpAccessChain:
  107. case SpvOpInBoundsAccessChain:
  108. if (ReplaceAccessChain(user, replacements))
  109. dead.push_back(user);
  110. else
  111. return false;
  112. break;
  113. case SpvOpName:
  114. case SpvOpMemberName:
  115. break;
  116. default:
  117. assert(false && "Unexpected opcode");
  118. break;
  119. }
  120. }
  121. return true;
  122. });
  123. if (replaced_all_uses) {
  124. dead.push_back(inst);
  125. } else {
  126. return Status::Failure;
  127. }
  128. // If there are no dead instructions to clean up, return with no changes.
  129. if (dead.empty()) return Status::SuccessWithoutChange;
  130. // Clean up some dead code.
  131. while (!dead.empty()) {
  132. Instruction* toKill = dead.back();
  133. dead.pop_back();
  134. context()->KillInst(toKill);
  135. }
  136. // Attempt to further scalarize.
  137. for (auto var : replacements) {
  138. if (var->opcode() == SpvOpVariable) {
  139. if (get_def_use_mgr()->NumUsers(var) == 0) {
  140. context()->KillInst(var);
  141. } else if (CanReplaceVariable(var)) {
  142. worklist->push(var);
  143. }
  144. }
  145. }
  146. return Status::SuccessWithChange;
  147. }
  148. bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
  149. Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
  150. // Insert Deref operation to the front of the operation list of |dbg_decl|.
  151. Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
  152. dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
  153. auto* deref_expr =
  154. context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
  155. // Add DebugValue instruction with Indexes operand and Deref operation.
  156. int32_t idx = 0;
  157. for (const auto* var : replacements) {
  158. Instruction* insert_before = var->NextNode();
  159. while (insert_before->opcode() == SpvOpVariable)
  160. insert_before = insert_before->NextNode();
  161. assert(insert_before != nullptr && "unexpected end of list");
  162. Instruction* added_dbg_value =
  163. context()->get_debug_info_mgr()->AddDebugValueForDecl(
  164. dbg_decl, /*value_id=*/var->result_id(),
  165. /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
  166. if (added_dbg_value == nullptr) return false;
  167. added_dbg_value->AddOperand(
  168. {SPV_OPERAND_TYPE_ID,
  169. {context()->get_constant_mgr()->GetSIntConst(idx)}});
  170. added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
  171. {deref_expr->result_id()});
  172. if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
  173. context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
  174. }
  175. ++idx;
  176. }
  177. return true;
  178. }
  179. bool ScalarReplacementPass::ReplaceWholeDebugValue(
  180. Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
  181. int32_t idx = 0;
  182. BasicBlock* block = context()->get_instr_block(dbg_value);
  183. for (auto var : replacements) {
  184. // Clone the DebugValue.
  185. std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
  186. uint32_t new_id = TakeNextId();
  187. if (new_id == 0) return false;
  188. new_dbg_value->SetResultId(new_id);
  189. // Update 'Value' operand to the |replacements|.
  190. new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
  191. // Append 'Indexes' operand.
  192. new_dbg_value->AddOperand(
  193. {SPV_OPERAND_TYPE_ID,
  194. {context()->get_constant_mgr()->GetSIntConst(idx)}});
  195. // Insert the new DebugValue to the basic block.
  196. auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
  197. get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
  198. context()->set_instr_block(added_instr, block);
  199. ++idx;
  200. }
  201. return true;
  202. }
  203. bool ScalarReplacementPass::ReplaceWholeLoad(
  204. Instruction* load, const std::vector<Instruction*>& replacements) {
  205. // Replaces the load of the entire composite with a load from each replacement
  206. // variable followed by a composite construction.
  207. BasicBlock* block = context()->get_instr_block(load);
  208. std::vector<Instruction*> loads;
  209. loads.reserve(replacements.size());
  210. BasicBlock::iterator where(load);
  211. for (auto var : replacements) {
  212. // Create a load of each replacement variable.
  213. if (var->opcode() != SpvOpVariable) {
  214. loads.push_back(var);
  215. continue;
  216. }
  217. Instruction* type = GetStorageType(var);
  218. uint32_t loadId = TakeNextId();
  219. if (loadId == 0) {
  220. return false;
  221. }
  222. std::unique_ptr<Instruction> newLoad(
  223. new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
  224. std::initializer_list<Operand>{
  225. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  226. // Copy memory access attributes which start at index 1. Index 0 is the
  227. // pointer to load.
  228. for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
  229. Operand copy(load->GetInOperand(i));
  230. newLoad->AddOperand(std::move(copy));
  231. }
  232. where = where.InsertBefore(std::move(newLoad));
  233. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  234. context()->set_instr_block(&*where, block);
  235. where->UpdateDebugInfoFrom(load);
  236. loads.push_back(&*where);
  237. }
  238. // Construct a new composite.
  239. uint32_t compositeId = TakeNextId();
  240. if (compositeId == 0) {
  241. return false;
  242. }
  243. where = load;
  244. std::unique_ptr<Instruction> compositeConstruct(new Instruction(
  245. context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
  246. for (auto l : loads) {
  247. Operand op(SPV_OPERAND_TYPE_ID,
  248. std::initializer_list<uint32_t>{l->result_id()});
  249. compositeConstruct->AddOperand(std::move(op));
  250. }
  251. where = where.InsertBefore(std::move(compositeConstruct));
  252. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  253. where->UpdateDebugInfoFrom(load);
  254. context()->set_instr_block(&*where, block);
  255. context()->ReplaceAllUsesWith(load->result_id(), compositeId);
  256. return true;
  257. }
  258. bool ScalarReplacementPass::ReplaceWholeStore(
  259. Instruction* store, const std::vector<Instruction*>& replacements) {
  260. // Replaces a store to the whole composite with a series of extract and stores
  261. // to each element.
  262. uint32_t storeInput = store->GetSingleWordInOperand(1u);
  263. BasicBlock* block = context()->get_instr_block(store);
  264. BasicBlock::iterator where(store);
  265. uint32_t elementIndex = 0;
  266. for (auto var : replacements) {
  267. // Create the extract.
  268. if (var->opcode() != SpvOpVariable) {
  269. elementIndex++;
  270. continue;
  271. }
  272. Instruction* type = GetStorageType(var);
  273. uint32_t extractId = TakeNextId();
  274. if (extractId == 0) {
  275. return false;
  276. }
  277. std::unique_ptr<Instruction> extract(new Instruction(
  278. context(), SpvOpCompositeExtract, type->result_id(), extractId,
  279. std::initializer_list<Operand>{
  280. {SPV_OPERAND_TYPE_ID, {storeInput}},
  281. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
  282. auto iter = where.InsertBefore(std::move(extract));
  283. iter->UpdateDebugInfoFrom(store);
  284. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  285. context()->set_instr_block(&*iter, block);
  286. // Create the store.
  287. std::unique_ptr<Instruction> newStore(
  288. new Instruction(context(), SpvOpStore, 0, 0,
  289. std::initializer_list<Operand>{
  290. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  291. {SPV_OPERAND_TYPE_ID, {extractId}}}));
  292. // Copy memory access attributes which start at index 2. Index 0 is the
  293. // pointer and index 1 is the data.
  294. for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
  295. Operand copy(store->GetInOperand(i));
  296. newStore->AddOperand(std::move(copy));
  297. }
  298. iter = where.InsertBefore(std::move(newStore));
  299. iter->UpdateDebugInfoFrom(store);
  300. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  301. context()->set_instr_block(&*iter, block);
  302. }
  303. return true;
  304. }
  305. bool ScalarReplacementPass::ReplaceAccessChain(
  306. Instruction* chain, const std::vector<Instruction*>& replacements) {
  307. // Replaces the access chain with either another access chain (with one fewer
  308. // indexes) or a direct use of the replacement variable.
  309. uint32_t indexId = chain->GetSingleWordInOperand(1u);
  310. const Instruction* index = get_def_use_mgr()->GetDef(indexId);
  311. int64_t indexValue = context()
  312. ->get_constant_mgr()
  313. ->GetConstantFromInst(index)
  314. ->GetSignExtendedValue();
  315. if (indexValue < 0 ||
  316. indexValue >= static_cast<int64_t>(replacements.size())) {
  317. // Out of bounds access, this is illegal IR. Notice that OpAccessChain
  318. // indexing is 0-based, so we should also reject index == size-of-array.
  319. return false;
  320. } else {
  321. const Instruction* var = replacements[static_cast<size_t>(indexValue)];
  322. if (chain->NumInOperands() > 2) {
  323. // Replace input access chain with another access chain.
  324. BasicBlock::iterator chainIter(chain);
  325. uint32_t replacementId = TakeNextId();
  326. if (replacementId == 0) {
  327. return false;
  328. }
  329. std::unique_ptr<Instruction> replacementChain(new Instruction(
  330. context(), chain->opcode(), chain->type_id(), replacementId,
  331. std::initializer_list<Operand>{
  332. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  333. // Add the remaining indexes.
  334. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
  335. Operand copy(chain->GetInOperand(i));
  336. replacementChain->AddOperand(std::move(copy));
  337. }
  338. replacementChain->UpdateDebugInfoFrom(chain);
  339. auto iter = chainIter.InsertBefore(std::move(replacementChain));
  340. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  341. context()->set_instr_block(&*iter, context()->get_instr_block(chain));
  342. context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  343. } else {
  344. // Replace with a use of the variable.
  345. context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  346. }
  347. }
  348. return true;
  349. }
  350. bool ScalarReplacementPass::CreateReplacementVariables(
  351. Instruction* inst, std::vector<Instruction*>* replacements) {
  352. Instruction* type = GetStorageType(inst);
  353. std::unique_ptr<std::unordered_set<int64_t>> components_used =
  354. GetUsedComponents(inst);
  355. uint32_t elem = 0;
  356. switch (type->opcode()) {
  357. case SpvOpTypeStruct:
  358. type->ForEachInOperand(
  359. [this, inst, &elem, replacements, &components_used](uint32_t* id) {
  360. if (!components_used || components_used->count(elem)) {
  361. CreateVariable(*id, inst, elem, replacements);
  362. } else {
  363. replacements->push_back(CreateNullConstant(*id));
  364. }
  365. elem++;
  366. });
  367. break;
  368. case SpvOpTypeArray:
  369. for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
  370. if (!components_used || components_used->count(i)) {
  371. CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
  372. replacements);
  373. } else {
  374. replacements->push_back(
  375. CreateNullConstant(type->GetSingleWordInOperand(0u)));
  376. }
  377. }
  378. break;
  379. case SpvOpTypeMatrix:
  380. case SpvOpTypeVector:
  381. for (uint32_t i = 0; i != GetNumElements(type); ++i) {
  382. CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
  383. }
  384. break;
  385. default:
  386. assert(false && "Unexpected type.");
  387. break;
  388. }
  389. TransferAnnotations(inst, replacements);
  390. return std::find(replacements->begin(), replacements->end(), nullptr) ==
  391. replacements->end();
  392. }
  393. void ScalarReplacementPass::TransferAnnotations(
  394. const Instruction* source, std::vector<Instruction*>* replacements) {
  395. // Only transfer invariant and restrict decorations on the variable. There are
  396. // no type or member decorations that are necessary to transfer.
  397. for (auto inst :
  398. get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
  399. assert(inst->opcode() == SpvOpDecorate);
  400. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  401. if (decoration == SpvDecorationInvariant ||
  402. decoration == SpvDecorationRestrict) {
  403. for (auto var : *replacements) {
  404. if (var == nullptr) {
  405. continue;
  406. }
  407. std::unique_ptr<Instruction> annotation(
  408. new Instruction(context(), SpvOpDecorate, 0, 0,
  409. std::initializer_list<Operand>{
  410. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  411. {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
  412. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  413. Operand copy(inst->GetInOperand(i));
  414. annotation->AddOperand(std::move(copy));
  415. }
  416. context()->AddAnnotationInst(std::move(annotation));
  417. get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
  418. }
  419. }
  420. }
  421. }
  422. void ScalarReplacementPass::CreateVariable(
  423. uint32_t typeId, Instruction* varInst, uint32_t index,
  424. std::vector<Instruction*>* replacements) {
  425. uint32_t ptrId = GetOrCreatePointerType(typeId);
  426. uint32_t id = TakeNextId();
  427. if (id == 0) {
  428. replacements->push_back(nullptr);
  429. }
  430. std::unique_ptr<Instruction> variable(new Instruction(
  431. context(), SpvOpVariable, ptrId, id,
  432. std::initializer_list<Operand>{
  433. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
  434. BasicBlock* block = context()->get_instr_block(varInst);
  435. block->begin().InsertBefore(std::move(variable));
  436. Instruction* inst = &*block->begin();
  437. // If varInst was initialized, make sure to initialize its replacement.
  438. GetOrCreateInitialValue(varInst, index, inst);
  439. get_def_use_mgr()->AnalyzeInstDefUse(inst);
  440. context()->set_instr_block(inst, block);
  441. // Copy decorations from the member to the new variable.
  442. Instruction* typeInst = GetStorageType(varInst);
  443. for (auto dec_inst :
  444. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  445. uint32_t decoration;
  446. if (dec_inst->opcode() != SpvOpMemberDecorate) {
  447. continue;
  448. }
  449. if (dec_inst->GetSingleWordInOperand(1) != index) {
  450. continue;
  451. }
  452. decoration = dec_inst->GetSingleWordInOperand(2u);
  453. switch (decoration) {
  454. case SpvDecorationRelaxedPrecision: {
  455. std::unique_ptr<Instruction> new_dec_inst(
  456. new Instruction(context(), SpvOpDecorate, 0, 0, {}));
  457. new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
  458. for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
  459. new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
  460. }
  461. context()->AddAnnotationInst(std::move(new_dec_inst));
  462. } break;
  463. default:
  464. break;
  465. }
  466. }
  467. // Update the DebugInfo debug information.
  468. inst->UpdateDebugInfoFrom(varInst);
  469. replacements->push_back(inst);
  470. }
  471. uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
  472. auto iter = pointee_to_pointer_.find(id);
  473. if (iter != pointee_to_pointer_.end()) return iter->second;
  474. analysis::Type* pointeeTy;
  475. std::unique_ptr<analysis::Pointer> pointerTy;
  476. std::tie(pointeeTy, pointerTy) =
  477. context()->get_type_mgr()->GetTypeAndPointerType(id,
  478. SpvStorageClassFunction);
  479. uint32_t ptrId = 0;
  480. if (pointeeTy->IsUniqueType()) {
  481. // Non-ambiguous type, just ask the type manager for an id.
  482. ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
  483. pointee_to_pointer_[id] = ptrId;
  484. return ptrId;
  485. }
  486. // Ambiguous type. We must perform a linear search to try and find the right
  487. // type.
  488. for (auto global : context()->types_values()) {
  489. if (global.opcode() == SpvOpTypePointer &&
  490. global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
  491. global.GetSingleWordInOperand(1u) == id) {
  492. if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
  493. // Only reuse a decoration-less pointer of the correct type.
  494. ptrId = global.result_id();
  495. break;
  496. }
  497. }
  498. }
  499. if (ptrId != 0) {
  500. pointee_to_pointer_[id] = ptrId;
  501. return ptrId;
  502. }
  503. ptrId = TakeNextId();
  504. context()->AddType(MakeUnique<Instruction>(
  505. context(), SpvOpTypePointer, 0, ptrId,
  506. std::initializer_list<Operand>{
  507. {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
  508. {SPV_OPERAND_TYPE_ID, {id}}}));
  509. Instruction* ptr = &*--context()->types_values_end();
  510. get_def_use_mgr()->AnalyzeInstDefUse(ptr);
  511. pointee_to_pointer_[id] = ptrId;
  512. // Register with the type manager if necessary.
  513. context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
  514. return ptrId;
  515. }
  516. void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
  517. uint32_t index,
  518. Instruction* newVar) {
  519. assert(source->opcode() == SpvOpVariable);
  520. if (source->NumInOperands() < 2) return;
  521. uint32_t initId = source->GetSingleWordInOperand(1u);
  522. uint32_t storageId = GetStorageType(newVar)->result_id();
  523. Instruction* init = get_def_use_mgr()->GetDef(initId);
  524. uint32_t newInitId = 0;
  525. // TODO(dnovillo): Refactor this with constant propagation.
  526. if (init->opcode() == SpvOpConstantNull) {
  527. // Initialize to appropriate NULL.
  528. auto iter = type_to_null_.find(storageId);
  529. if (iter == type_to_null_.end()) {
  530. newInitId = TakeNextId();
  531. type_to_null_[storageId] = newInitId;
  532. context()->AddGlobalValue(
  533. MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
  534. newInitId, std::initializer_list<Operand>{}));
  535. Instruction* newNull = &*--context()->types_values_end();
  536. get_def_use_mgr()->AnalyzeInstDefUse(newNull);
  537. } else {
  538. newInitId = iter->second;
  539. }
  540. } else if (IsSpecConstantInst(init->opcode())) {
  541. // Create a new constant extract.
  542. newInitId = TakeNextId();
  543. context()->AddGlobalValue(MakeUnique<Instruction>(
  544. context(), SpvOpSpecConstantOp, storageId, newInitId,
  545. std::initializer_list<Operand>{
  546. {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
  547. {SPV_OPERAND_TYPE_ID, {init->result_id()}},
  548. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
  549. Instruction* newSpecConst = &*--context()->types_values_end();
  550. get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
  551. } else if (init->opcode() == SpvOpConstantComposite) {
  552. // Get the appropriate index constant.
  553. newInitId = init->GetSingleWordInOperand(index);
  554. Instruction* element = get_def_use_mgr()->GetDef(newInitId);
  555. if (element->opcode() == SpvOpUndef) {
  556. // Undef is not a valid initializer for a variable.
  557. newInitId = 0;
  558. }
  559. } else {
  560. assert(false);
  561. }
  562. if (newInitId != 0) {
  563. newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
  564. }
  565. }
  566. uint64_t ScalarReplacementPass::GetArrayLength(
  567. const Instruction* arrayType) const {
  568. assert(arrayType->opcode() == SpvOpTypeArray);
  569. const Instruction* length =
  570. get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
  571. return context()
  572. ->get_constant_mgr()
  573. ->GetConstantFromInst(length)
  574. ->GetZeroExtendedValue();
  575. }
  576. uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
  577. assert(type->opcode() == SpvOpTypeVector ||
  578. type->opcode() == SpvOpTypeMatrix);
  579. const Operand& op = type->GetInOperand(1u);
  580. assert(op.words.size() <= 2);
  581. uint64_t len = 0;
  582. for (size_t i = 0; i != op.words.size(); ++i) {
  583. len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
  584. }
  585. return len;
  586. }
  587. bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
  588. const Instruction* inst = get_def_use_mgr()->GetDef(id);
  589. assert(inst);
  590. return spvOpcodeIsSpecConstant(inst->opcode());
  591. }
  592. Instruction* ScalarReplacementPass::GetStorageType(
  593. const Instruction* inst) const {
  594. assert(inst->opcode() == SpvOpVariable);
  595. uint32_t ptrTypeId = inst->type_id();
  596. uint32_t typeId =
  597. get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u);
  598. return get_def_use_mgr()->GetDef(typeId);
  599. }
  600. bool ScalarReplacementPass::CanReplaceVariable(
  601. const Instruction* varInst) const {
  602. assert(varInst->opcode() == SpvOpVariable);
  603. // Can only replace function scope variables.
  604. if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
  605. return false;
  606. }
  607. if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
  608. return false;
  609. }
  610. const Instruction* typeInst = GetStorageType(varInst);
  611. if (!CheckType(typeInst)) {
  612. return false;
  613. }
  614. if (!CheckAnnotations(varInst)) {
  615. return false;
  616. }
  617. if (!CheckUses(varInst)) {
  618. return false;
  619. }
  620. return true;
  621. }
  622. bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
  623. if (!CheckTypeAnnotations(typeInst)) {
  624. return false;
  625. }
  626. switch (typeInst->opcode()) {
  627. case SpvOpTypeStruct:
  628. // Don't bother with empty structs or very large structs.
  629. if (typeInst->NumInOperands() == 0 ||
  630. IsLargerThanSizeLimit(typeInst->NumInOperands())) {
  631. return false;
  632. }
  633. return true;
  634. case SpvOpTypeArray:
  635. if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
  636. return false;
  637. }
  638. if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
  639. return false;
  640. }
  641. return true;
  642. // TODO(alanbaker): Develop some heuristics for when this should be
  643. // re-enabled.
  644. //// Specifically including matrix and vector in an attempt to reduce the
  645. //// number of vector registers required.
  646. // case SpvOpTypeMatrix:
  647. // case SpvOpTypeVector:
  648. // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
  649. // return true;
  650. case SpvOpTypeRuntimeArray:
  651. default:
  652. return false;
  653. }
  654. }
  655. bool ScalarReplacementPass::CheckTypeAnnotations(
  656. const Instruction* typeInst) const {
  657. for (auto inst :
  658. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  659. uint32_t decoration;
  660. if (inst->opcode() == SpvOpDecorate) {
  661. decoration = inst->GetSingleWordInOperand(1u);
  662. } else {
  663. assert(inst->opcode() == SpvOpMemberDecorate);
  664. decoration = inst->GetSingleWordInOperand(2u);
  665. }
  666. switch (decoration) {
  667. case SpvDecorationRowMajor:
  668. case SpvDecorationColMajor:
  669. case SpvDecorationArrayStride:
  670. case SpvDecorationMatrixStride:
  671. case SpvDecorationCPacked:
  672. case SpvDecorationInvariant:
  673. case SpvDecorationRestrict:
  674. case SpvDecorationOffset:
  675. case SpvDecorationAlignment:
  676. case SpvDecorationAlignmentId:
  677. case SpvDecorationMaxByteOffset:
  678. case SpvDecorationRelaxedPrecision:
  679. break;
  680. default:
  681. return false;
  682. }
  683. }
  684. return true;
  685. }
  686. bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
  687. for (auto inst :
  688. get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
  689. assert(inst->opcode() == SpvOpDecorate);
  690. uint32_t decoration = inst->GetSingleWordInOperand(1u);
  691. switch (decoration) {
  692. case SpvDecorationInvariant:
  693. case SpvDecorationRestrict:
  694. case SpvDecorationAlignment:
  695. case SpvDecorationAlignmentId:
  696. case SpvDecorationMaxByteOffset:
  697. break;
  698. default:
  699. return false;
  700. }
  701. }
  702. return true;
  703. }
  704. bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
  705. VariableStats stats = {0, 0};
  706. bool ok = CheckUses(inst, &stats);
  707. // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
  708. // SRoA is costly, such as when the structure has many (unaccessed?)
  709. // members.
  710. return ok;
  711. }
  712. bool ScalarReplacementPass::CheckUses(const Instruction* inst,
  713. VariableStats* stats) const {
  714. uint64_t max_legal_index = GetMaxLegalIndex(inst);
  715. bool ok = true;
  716. get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
  717. const Instruction* user,
  718. uint32_t index) {
  719. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
  720. user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
  721. // TODO: include num_partial_accesses if it uses Fragment operation or
  722. // DebugValue has Indexes operand.
  723. stats->num_full_accesses++;
  724. return;
  725. }
  726. // Annotations are check as a group separately.
  727. if (!IsAnnotationInst(user->opcode())) {
  728. switch (user->opcode()) {
  729. case SpvOpAccessChain:
  730. case SpvOpInBoundsAccessChain:
  731. if (index == 2u && user->NumInOperands() > 1) {
  732. uint32_t id = user->GetSingleWordInOperand(1u);
  733. const Instruction* opInst = get_def_use_mgr()->GetDef(id);
  734. const auto* constant =
  735. context()->get_constant_mgr()->GetConstantFromInst(opInst);
  736. if (!constant) {
  737. ok = false;
  738. } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
  739. ok = false;
  740. } else {
  741. if (!CheckUsesRelaxed(user)) ok = false;
  742. }
  743. stats->num_partial_accesses++;
  744. } else {
  745. ok = false;
  746. }
  747. break;
  748. case SpvOpLoad:
  749. if (!CheckLoad(user, index)) ok = false;
  750. stats->num_full_accesses++;
  751. break;
  752. case SpvOpStore:
  753. if (!CheckStore(user, index)) ok = false;
  754. stats->num_full_accesses++;
  755. break;
  756. case SpvOpName:
  757. case SpvOpMemberName:
  758. break;
  759. default:
  760. ok = false;
  761. break;
  762. }
  763. }
  764. });
  765. return ok;
  766. }
  767. bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
  768. bool ok = true;
  769. get_def_use_mgr()->ForEachUse(
  770. inst, [this, &ok](const Instruction* user, uint32_t index) {
  771. switch (user->opcode()) {
  772. case SpvOpAccessChain:
  773. case SpvOpInBoundsAccessChain:
  774. if (index != 2u) {
  775. ok = false;
  776. } else {
  777. if (!CheckUsesRelaxed(user)) ok = false;
  778. }
  779. break;
  780. case SpvOpLoad:
  781. if (!CheckLoad(user, index)) ok = false;
  782. break;
  783. case SpvOpStore:
  784. if (!CheckStore(user, index)) ok = false;
  785. break;
  786. case SpvOpImageTexelPointer:
  787. if (!CheckImageTexelPointer(index)) ok = false;
  788. break;
  789. case SpvOpExtInst:
  790. if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare ||
  791. !CheckDebugDeclare(index))
  792. ok = false;
  793. break;
  794. default:
  795. ok = false;
  796. break;
  797. }
  798. });
  799. return ok;
  800. }
  801. bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
  802. return index == 2u;
  803. }
  804. bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
  805. uint32_t index) const {
  806. if (index != 2u) return false;
  807. if (inst->NumInOperands() >= 2 &&
  808. inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
  809. return false;
  810. return true;
  811. }
  812. bool ScalarReplacementPass::CheckStore(const Instruction* inst,
  813. uint32_t index) const {
  814. if (index != 0u) return false;
  815. if (inst->NumInOperands() >= 3 &&
  816. inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
  817. return false;
  818. return true;
  819. }
  820. bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const {
  821. if (index != kDebugDeclareOperandVariableIndex) return false;
  822. return true;
  823. }
  824. bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
  825. if (max_num_elements_ == 0) {
  826. return false;
  827. }
  828. return length > max_num_elements_;
  829. }
  830. std::unique_ptr<std::unordered_set<int64_t>>
  831. ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
  832. std::unique_ptr<std::unordered_set<int64_t>> result(
  833. new std::unordered_set<int64_t>());
  834. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  835. def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
  836. this](Instruction* use) {
  837. switch (use->opcode()) {
  838. case SpvOpLoad: {
  839. // Look for extract from the load.
  840. std::vector<uint32_t> t;
  841. if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
  842. if (use2->opcode() != SpvOpCompositeExtract ||
  843. use2->NumInOperands() <= 1) {
  844. return false;
  845. }
  846. t.push_back(use2->GetSingleWordInOperand(1));
  847. return true;
  848. })) {
  849. result->insert(t.begin(), t.end());
  850. return true;
  851. } else {
  852. result.reset(nullptr);
  853. return false;
  854. }
  855. }
  856. case SpvOpName:
  857. case SpvOpMemberName:
  858. case SpvOpStore:
  859. // No components are used.
  860. return true;
  861. case SpvOpAccessChain:
  862. case SpvOpInBoundsAccessChain: {
  863. // Add the first index it if is a constant.
  864. // TODO: Could be improved by checking if the address is used in a load.
  865. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  866. uint32_t index_id = use->GetSingleWordInOperand(1);
  867. const analysis::Constant* index_const =
  868. const_mgr->FindDeclaredConstant(index_id);
  869. if (index_const) {
  870. result->insert(index_const->GetSignExtendedValue());
  871. return true;
  872. } else {
  873. // Could be any element. Assuming all are used.
  874. result.reset(nullptr);
  875. return false;
  876. }
  877. }
  878. default:
  879. // We do not know what is happening. Have to assume the worst.
  880. result.reset(nullptr);
  881. return false;
  882. }
  883. });
  884. return result;
  885. }
  886. Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
  887. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  888. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  889. const analysis::Type* type = type_mgr->GetType(type_id);
  890. const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
  891. Instruction* null_inst =
  892. const_mgr->GetDefiningInstruction(null_const, type_id);
  893. if (null_inst != nullptr) {
  894. context()->UpdateDefUse(null_inst);
  895. }
  896. return null_inst;
  897. }
  898. uint64_t ScalarReplacementPass::GetMaxLegalIndex(
  899. const Instruction* var_inst) const {
  900. assert(var_inst->opcode() == SpvOpVariable &&
  901. "|var_inst| must be a variable instruction.");
  902. Instruction* type = GetStorageType(var_inst);
  903. switch (type->opcode()) {
  904. case SpvOpTypeStruct:
  905. return type->NumInOperands();
  906. case SpvOpTypeArray:
  907. return GetArrayLength(type);
  908. case SpvOpTypeMatrix:
  909. case SpvOpTypeVector:
  910. return GetNumElements(type);
  911. default:
  912. return 0;
  913. }
  914. return 0;
  915. }
  916. } // namespace opt
  917. } // namespace spvtools