scalar_replacement_pass.cpp 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. // Copyright (c) 2017 Google Inc.
  2. // Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
  3. // reserved.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "source/opt/scalar_replacement_pass.h"
  17. #include <algorithm>
  18. #include <queue>
  19. #include <tuple>
  20. #include <utility>
  21. #include "source/extensions.h"
  22. #include "source/opt/reflect.h"
  23. #include "source/opt/types.h"
  24. #include "source/util/make_unique.h"
  25. namespace spvtools {
  26. namespace opt {
  27. namespace {
  28. constexpr uint32_t kDebugValueOperandValueIndex = 5;
  29. constexpr uint32_t kDebugValueOperandExpressionIndex = 6;
  30. constexpr uint32_t kDebugDeclareOperandVariableIndex = 5;
  31. } // namespace
  32. Pass::Status ScalarReplacementPass::Process() {
  33. Status status = Status::SuccessWithoutChange;
  34. for (auto& f : *get_module()) {
  35. if (f.IsDeclaration()) {
  36. continue;
  37. }
  38. Status functionStatus = ProcessFunction(&f);
  39. if (functionStatus == Status::Failure)
  40. return functionStatus;
  41. else if (functionStatus == Status::SuccessWithChange)
  42. status = functionStatus;
  43. }
  44. return status;
  45. }
  46. Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  47. std::queue<Instruction*> worklist;
  48. BasicBlock& entry = *function->begin();
  49. for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
  50. // Function storage class OpVariables must appear as the first instructions
  51. // of the entry block.
  52. if (iter->opcode() != spv::Op::OpVariable) break;
  53. Instruction* varInst = &*iter;
  54. if (CanReplaceVariable(varInst)) {
  55. worklist.push(varInst);
  56. }
  57. }
  58. Status status = Status::SuccessWithoutChange;
  59. while (!worklist.empty()) {
  60. Instruction* varInst = worklist.front();
  61. worklist.pop();
  62. Status var_status = ReplaceVariable(varInst, &worklist);
  63. if (var_status == Status::Failure)
  64. return var_status;
  65. else if (var_status == Status::SuccessWithChange)
  66. status = var_status;
  67. }
  68. return status;
  69. }
  70. Pass::Status ScalarReplacementPass::ReplaceVariable(
  71. Instruction* inst, std::queue<Instruction*>* worklist) {
  72. std::vector<Instruction*> replacements;
  73. if (!CreateReplacementVariables(inst, &replacements)) {
  74. return Status::Failure;
  75. }
  76. std::vector<Instruction*> dead;
  77. bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
  78. inst, [this, &replacements, &dead](Instruction* user) {
  79. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
  80. if (ReplaceWholeDebugDeclare(user, replacements)) {
  81. dead.push_back(user);
  82. return true;
  83. }
  84. return false;
  85. }
  86. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
  87. if (ReplaceWholeDebugValue(user, replacements)) {
  88. dead.push_back(user);
  89. return true;
  90. }
  91. return false;
  92. }
  93. if (!IsAnnotationInst(user->opcode())) {
  94. switch (user->opcode()) {
  95. case spv::Op::OpLoad:
  96. if (ReplaceWholeLoad(user, replacements)) {
  97. dead.push_back(user);
  98. } else {
  99. return false;
  100. }
  101. break;
  102. case spv::Op::OpStore:
  103. if (ReplaceWholeStore(user, replacements)) {
  104. dead.push_back(user);
  105. } else {
  106. return false;
  107. }
  108. break;
  109. case spv::Op::OpAccessChain:
  110. case spv::Op::OpInBoundsAccessChain:
  111. if (ReplaceAccessChain(user, replacements))
  112. dead.push_back(user);
  113. else
  114. return false;
  115. break;
  116. case spv::Op::OpName:
  117. case spv::Op::OpMemberName:
  118. break;
  119. default:
  120. assert(false && "Unexpected opcode");
  121. break;
  122. }
  123. }
  124. return true;
  125. });
  126. if (replaced_all_uses) {
  127. dead.push_back(inst);
  128. } else {
  129. return Status::Failure;
  130. }
  131. // If there are no dead instructions to clean up, return with no changes.
  132. if (dead.empty()) return Status::SuccessWithoutChange;
  133. // Clean up some dead code.
  134. while (!dead.empty()) {
  135. Instruction* toKill = dead.back();
  136. dead.pop_back();
  137. context()->KillInst(toKill);
  138. }
  139. // Attempt to further scalarize.
  140. for (auto var : replacements) {
  141. if (var->opcode() == spv::Op::OpVariable) {
  142. if (get_def_use_mgr()->NumUsers(var) == 0) {
  143. context()->KillInst(var);
  144. } else if (CanReplaceVariable(var)) {
  145. worklist->push(var);
  146. }
  147. }
  148. }
  149. return Status::SuccessWithChange;
  150. }
  151. bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
  152. Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
  153. // Insert Deref operation to the front of the operation list of |dbg_decl|.
  154. Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
  155. dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
  156. auto* deref_expr =
  157. context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
  158. // Add DebugValue instruction with Indexes operand and Deref operation.
  159. int32_t idx = 0;
  160. for (const auto* var : replacements) {
  161. Instruction* insert_before = var->NextNode();
  162. while (insert_before->opcode() == spv::Op::OpVariable)
  163. insert_before = insert_before->NextNode();
  164. assert(insert_before != nullptr && "unexpected end of list");
  165. Instruction* added_dbg_value =
  166. context()->get_debug_info_mgr()->AddDebugValueForDecl(
  167. dbg_decl, /*value_id=*/var->result_id(),
  168. /*insert_before=*/insert_before, /*scope_and_line=*/dbg_decl);
  169. if (added_dbg_value == nullptr) return false;
  170. added_dbg_value->AddOperand(
  171. {SPV_OPERAND_TYPE_ID,
  172. {context()->get_constant_mgr()->GetSIntConstId(idx)}});
  173. added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
  174. {deref_expr->result_id()});
  175. if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
  176. context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
  177. }
  178. ++idx;
  179. }
  180. return true;
  181. }
  182. bool ScalarReplacementPass::ReplaceWholeDebugValue(
  183. Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
  184. int32_t idx = 0;
  185. BasicBlock* block = context()->get_instr_block(dbg_value);
  186. for (auto var : replacements) {
  187. // Clone the DebugValue.
  188. std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
  189. uint32_t new_id = TakeNextId();
  190. if (new_id == 0) return false;
  191. new_dbg_value->SetResultId(new_id);
  192. // Update 'Value' operand to the |replacements|.
  193. new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
  194. // Append 'Indexes' operand.
  195. new_dbg_value->AddOperand(
  196. {SPV_OPERAND_TYPE_ID,
  197. {context()->get_constant_mgr()->GetSIntConstId(idx)}});
  198. // Insert the new DebugValue to the basic block.
  199. auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
  200. get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
  201. context()->set_instr_block(added_instr, block);
  202. ++idx;
  203. }
  204. return true;
  205. }
  206. bool ScalarReplacementPass::ReplaceWholeLoad(
  207. Instruction* load, const std::vector<Instruction*>& replacements) {
  208. // Replaces the load of the entire composite with a load from each replacement
  209. // variable followed by a composite construction.
  210. BasicBlock* block = context()->get_instr_block(load);
  211. std::vector<Instruction*> loads;
  212. loads.reserve(replacements.size());
  213. BasicBlock::iterator where(load);
  214. for (auto var : replacements) {
  215. // Create a load of each replacement variable.
  216. if (var->opcode() != spv::Op::OpVariable) {
  217. loads.push_back(var);
  218. continue;
  219. }
  220. Instruction* type = GetStorageType(var);
  221. uint32_t loadId = TakeNextId();
  222. if (loadId == 0) {
  223. return false;
  224. }
  225. std::unique_ptr<Instruction> newLoad(
  226. new Instruction(context(), spv::Op::OpLoad, type->result_id(), loadId,
  227. std::initializer_list<Operand>{
  228. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  229. // Copy memory access attributes which start at index 1. Index 0 is the
  230. // pointer to load.
  231. for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
  232. Operand copy(load->GetInOperand(i));
  233. newLoad->AddOperand(std::move(copy));
  234. }
  235. where = where.InsertBefore(std::move(newLoad));
  236. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  237. context()->set_instr_block(&*where, block);
  238. where->UpdateDebugInfoFrom(load);
  239. loads.push_back(&*where);
  240. }
  241. // Construct a new composite.
  242. uint32_t compositeId = TakeNextId();
  243. if (compositeId == 0) {
  244. return false;
  245. }
  246. where = load;
  247. std::unique_ptr<Instruction> compositeConstruct(
  248. new Instruction(context(), spv::Op::OpCompositeConstruct, load->type_id(),
  249. compositeId, {}));
  250. for (auto l : loads) {
  251. Operand op(SPV_OPERAND_TYPE_ID,
  252. std::initializer_list<uint32_t>{l->result_id()});
  253. compositeConstruct->AddOperand(std::move(op));
  254. }
  255. where = where.InsertBefore(std::move(compositeConstruct));
  256. get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  257. where->UpdateDebugInfoFrom(load);
  258. context()->set_instr_block(&*where, block);
  259. context()->ReplaceAllUsesWith(load->result_id(), compositeId);
  260. return true;
  261. }
  262. bool ScalarReplacementPass::ReplaceWholeStore(
  263. Instruction* store, const std::vector<Instruction*>& replacements) {
  264. // Replaces a store to the whole composite with a series of extract and stores
  265. // to each element.
  266. uint32_t storeInput = store->GetSingleWordInOperand(1u);
  267. BasicBlock* block = context()->get_instr_block(store);
  268. BasicBlock::iterator where(store);
  269. uint32_t elementIndex = 0;
  270. for (auto var : replacements) {
  271. // Create the extract.
  272. if (var->opcode() != spv::Op::OpVariable) {
  273. elementIndex++;
  274. continue;
  275. }
  276. Instruction* type = GetStorageType(var);
  277. uint32_t extractId = TakeNextId();
  278. if (extractId == 0) {
  279. return false;
  280. }
  281. std::unique_ptr<Instruction> extract(new Instruction(
  282. context(), spv::Op::OpCompositeExtract, type->result_id(), extractId,
  283. std::initializer_list<Operand>{
  284. {SPV_OPERAND_TYPE_ID, {storeInput}},
  285. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
  286. auto iter = where.InsertBefore(std::move(extract));
  287. iter->UpdateDebugInfoFrom(store);
  288. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  289. context()->set_instr_block(&*iter, block);
  290. // Create the store.
  291. std::unique_ptr<Instruction> newStore(
  292. new Instruction(context(), spv::Op::OpStore, 0, 0,
  293. std::initializer_list<Operand>{
  294. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  295. {SPV_OPERAND_TYPE_ID, {extractId}}}));
  296. // Copy memory access attributes which start at index 2. Index 0 is the
  297. // pointer and index 1 is the data.
  298. for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
  299. Operand copy(store->GetInOperand(i));
  300. newStore->AddOperand(std::move(copy));
  301. }
  302. iter = where.InsertBefore(std::move(newStore));
  303. iter->UpdateDebugInfoFrom(store);
  304. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  305. context()->set_instr_block(&*iter, block);
  306. }
  307. return true;
  308. }
  309. bool ScalarReplacementPass::ReplaceAccessChain(
  310. Instruction* chain, const std::vector<Instruction*>& replacements) {
  311. // Replaces the access chain with either another access chain (with one fewer
  312. // indexes) or a direct use of the replacement variable.
  313. uint32_t indexId = chain->GetSingleWordInOperand(1u);
  314. const Instruction* index = get_def_use_mgr()->GetDef(indexId);
  315. int64_t indexValue = context()
  316. ->get_constant_mgr()
  317. ->GetConstantFromInst(index)
  318. ->GetSignExtendedValue();
  319. if (indexValue < 0 ||
  320. indexValue >= static_cast<int64_t>(replacements.size())) {
  321. // Out of bounds access, this is illegal IR. Notice that OpAccessChain
  322. // indexing is 0-based, so we should also reject index == size-of-array.
  323. return false;
  324. } else {
  325. const Instruction* var = replacements[static_cast<size_t>(indexValue)];
  326. if (chain->NumInOperands() > 2) {
  327. // Replace input access chain with another access chain.
  328. BasicBlock::iterator chainIter(chain);
  329. uint32_t replacementId = TakeNextId();
  330. if (replacementId == 0) {
  331. return false;
  332. }
  333. std::unique_ptr<Instruction> replacementChain(new Instruction(
  334. context(), chain->opcode(), chain->type_id(), replacementId,
  335. std::initializer_list<Operand>{
  336. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  337. // Add the remaining indexes.
  338. for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
  339. Operand copy(chain->GetInOperand(i));
  340. replacementChain->AddOperand(std::move(copy));
  341. }
  342. replacementChain->UpdateDebugInfoFrom(chain);
  343. auto iter = chainIter.InsertBefore(std::move(replacementChain));
  344. get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
  345. context()->set_instr_block(&*iter, context()->get_instr_block(chain));
  346. context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  347. } else {
  348. // Replace with a use of the variable.
  349. context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  350. }
  351. }
  352. return true;
  353. }
  354. bool ScalarReplacementPass::CreateReplacementVariables(
  355. Instruction* inst, std::vector<Instruction*>* replacements) {
  356. Instruction* type = GetStorageType(inst);
  357. std::unique_ptr<std::unordered_set<int64_t>> components_used =
  358. GetUsedComponents(inst);
  359. uint32_t elem = 0;
  360. switch (type->opcode()) {
  361. case spv::Op::OpTypeStruct:
  362. type->ForEachInOperand(
  363. [this, inst, &elem, replacements, &components_used](uint32_t* id) {
  364. if (!components_used || components_used->count(elem)) {
  365. CreateVariable(*id, inst, elem, replacements);
  366. } else {
  367. replacements->push_back(GetUndef(*id));
  368. }
  369. elem++;
  370. });
  371. break;
  372. case spv::Op::OpTypeArray:
  373. for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
  374. if (!components_used || components_used->count(i)) {
  375. CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
  376. replacements);
  377. } else {
  378. uint32_t element_type_id = type->GetSingleWordInOperand(0);
  379. replacements->push_back(GetUndef(element_type_id));
  380. }
  381. }
  382. break;
  383. case spv::Op::OpTypeMatrix:
  384. case spv::Op::OpTypeVector:
  385. for (uint32_t i = 0; i != GetNumElements(type); ++i) {
  386. CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
  387. }
  388. break;
  389. default:
  390. assert(false && "Unexpected type.");
  391. break;
  392. }
  393. TransferAnnotations(inst, replacements);
  394. return std::find(replacements->begin(), replacements->end(), nullptr) ==
  395. replacements->end();
  396. }
  397. Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) {
  398. return get_def_use_mgr()->GetDef(Type2Undef(type_id));
  399. }
  400. void ScalarReplacementPass::TransferAnnotations(
  401. const Instruction* source, std::vector<Instruction*>* replacements) {
  402. // Only transfer invariant and restrict decorations on the variable. There are
  403. // no type or member decorations that are necessary to transfer.
  404. for (auto inst :
  405. get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
  406. assert(inst->opcode() == spv::Op::OpDecorate);
  407. auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
  408. if (decoration == spv::Decoration::Invariant ||
  409. decoration == spv::Decoration::Restrict) {
  410. for (auto var : *replacements) {
  411. if (var == nullptr) {
  412. continue;
  413. }
  414. std::unique_ptr<Instruction> annotation(new Instruction(
  415. context(), spv::Op::OpDecorate, 0, 0,
  416. std::initializer_list<Operand>{
  417. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  418. {SPV_OPERAND_TYPE_DECORATION, {uint32_t(decoration)}}}));
  419. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
  420. Operand copy(inst->GetInOperand(i));
  421. annotation->AddOperand(std::move(copy));
  422. }
  423. context()->AddAnnotationInst(std::move(annotation));
  424. get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
  425. }
  426. }
  427. }
  428. }
  429. void ScalarReplacementPass::CreateVariable(
  430. uint32_t type_id, Instruction* var_inst, uint32_t index,
  431. std::vector<Instruction*>* replacements) {
  432. uint32_t ptr_id = GetOrCreatePointerType(type_id);
  433. uint32_t id = TakeNextId();
  434. if (id == 0) {
  435. replacements->push_back(nullptr);
  436. }
  437. std::unique_ptr<Instruction> variable(
  438. new Instruction(context(), spv::Op::OpVariable, ptr_id, id,
  439. std::initializer_list<Operand>{
  440. {SPV_OPERAND_TYPE_STORAGE_CLASS,
  441. {uint32_t(spv::StorageClass::Function)}}}));
  442. BasicBlock* block = context()->get_instr_block(var_inst);
  443. block->begin().InsertBefore(std::move(variable));
  444. Instruction* inst = &*block->begin();
  445. // If varInst was initialized, make sure to initialize its replacement.
  446. GetOrCreateInitialValue(var_inst, index, inst);
  447. get_def_use_mgr()->AnalyzeInstDefUse(inst);
  448. context()->set_instr_block(inst, block);
  449. CopyDecorationsToVariable(var_inst, inst, index);
  450. inst->UpdateDebugInfoFrom(var_inst);
  451. replacements->push_back(inst);
  452. }
  453. uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
  454. auto iter = pointee_to_pointer_.find(id);
  455. if (iter != pointee_to_pointer_.end()) return iter->second;
  456. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  457. uint32_t ptr_type_id =
  458. type_mgr->FindPointerToType(id, spv::StorageClass::Function);
  459. pointee_to_pointer_[id] = ptr_type_id;
  460. return ptr_type_id;
  461. }
  462. void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
  463. uint32_t index,
  464. Instruction* newVar) {
  465. assert(source->opcode() == spv::Op::OpVariable);
  466. if (source->NumInOperands() < 2) return;
  467. uint32_t initId = source->GetSingleWordInOperand(1u);
  468. uint32_t storageId = GetStorageType(newVar)->result_id();
  469. Instruction* init = get_def_use_mgr()->GetDef(initId);
  470. uint32_t newInitId = 0;
  471. // TODO(dnovillo): Refactor this with constant propagation.
  472. if (init->opcode() == spv::Op::OpConstantNull) {
  473. // Initialize to appropriate NULL.
  474. auto iter = type_to_null_.find(storageId);
  475. if (iter == type_to_null_.end()) {
  476. newInitId = TakeNextId();
  477. type_to_null_[storageId] = newInitId;
  478. context()->AddGlobalValue(
  479. MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, storageId,
  480. newInitId, std::initializer_list<Operand>{}));
  481. Instruction* newNull = &*--context()->types_values_end();
  482. get_def_use_mgr()->AnalyzeInstDefUse(newNull);
  483. } else {
  484. newInitId = iter->second;
  485. }
  486. } else if (IsSpecConstantInst(init->opcode())) {
  487. // Create a new constant extract.
  488. newInitId = TakeNextId();
  489. context()->AddGlobalValue(MakeUnique<Instruction>(
  490. context(), spv::Op::OpSpecConstantOp, storageId, newInitId,
  491. std::initializer_list<Operand>{
  492. {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
  493. {uint32_t(spv::Op::OpCompositeExtract)}},
  494. {SPV_OPERAND_TYPE_ID, {init->result_id()}},
  495. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
  496. Instruction* newSpecConst = &*--context()->types_values_end();
  497. get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
  498. } else if (init->opcode() == spv::Op::OpConstantComposite) {
  499. // Get the appropriate index constant.
  500. newInitId = init->GetSingleWordInOperand(index);
  501. Instruction* element = get_def_use_mgr()->GetDef(newInitId);
  502. if (element->opcode() == spv::Op::OpUndef) {
  503. // Undef is not a valid initializer for a variable.
  504. newInitId = 0;
  505. }
  506. } else {
  507. assert(false);
  508. }
  509. if (newInitId != 0) {
  510. newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
  511. }
  512. }
  513. uint64_t ScalarReplacementPass::GetArrayLength(
  514. const Instruction* arrayType) const {
  515. assert(arrayType->opcode() == spv::Op::OpTypeArray);
  516. const Instruction* length =
  517. get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
  518. return context()
  519. ->get_constant_mgr()
  520. ->GetConstantFromInst(length)
  521. ->GetZeroExtendedValue();
  522. }
  523. uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
  524. assert(type->opcode() == spv::Op::OpTypeVector ||
  525. type->opcode() == spv::Op::OpTypeMatrix);
  526. const Operand& op = type->GetInOperand(1u);
  527. assert(op.words.size() <= 2);
  528. uint64_t len = 0;
  529. for (size_t i = 0; i != op.words.size(); ++i) {
  530. len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
  531. }
  532. return len;
  533. }
  534. bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
  535. const Instruction* inst = get_def_use_mgr()->GetDef(id);
  536. assert(inst);
  537. return spvOpcodeIsSpecConstant(inst->opcode());
  538. }
  539. Instruction* ScalarReplacementPass::GetStorageType(
  540. const Instruction* inst) const {
  541. assert(inst->opcode() == spv::Op::OpVariable);
  542. uint32_t ptrTypeId = inst->type_id();
  543. uint32_t typeId =
  544. get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u);
  545. return get_def_use_mgr()->GetDef(typeId);
  546. }
  547. bool ScalarReplacementPass::CanReplaceVariable(
  548. const Instruction* varInst) const {
  549. assert(varInst->opcode() == spv::Op::OpVariable);
  550. // Can only replace function scope variables.
  551. if (spv::StorageClass(varInst->GetSingleWordInOperand(0u)) !=
  552. spv::StorageClass::Function) {
  553. return false;
  554. }
  555. if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
  556. return false;
  557. }
  558. const Instruction* typeInst = GetStorageType(varInst);
  559. if (!CheckType(typeInst)) {
  560. return false;
  561. }
  562. if (!CheckAnnotations(varInst)) {
  563. return false;
  564. }
  565. if (!CheckUses(varInst)) {
  566. return false;
  567. }
  568. return true;
  569. }
  570. bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
  571. if (!CheckTypeAnnotations(typeInst)) {
  572. return false;
  573. }
  574. switch (typeInst->opcode()) {
  575. case spv::Op::OpTypeStruct:
  576. // Don't bother with empty structs or very large structs.
  577. if (typeInst->NumInOperands() == 0 ||
  578. IsLargerThanSizeLimit(typeInst->NumInOperands())) {
  579. return false;
  580. }
  581. return true;
  582. case spv::Op::OpTypeArray:
  583. if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
  584. return false;
  585. }
  586. if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
  587. return false;
  588. }
  589. return true;
  590. // TODO(alanbaker): Develop some heuristics for when this should be
  591. // re-enabled.
  592. //// Specifically including matrix and vector in an attempt to reduce the
  593. //// number of vector registers required.
  594. // case spv::Op::OpTypeMatrix:
  595. // case spv::Op::OpTypeVector:
  596. // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
  597. // return true;
  598. case spv::Op::OpTypeRuntimeArray:
  599. default:
  600. return false;
  601. }
  602. }
  603. bool ScalarReplacementPass::CheckTypeAnnotations(
  604. const Instruction* typeInst) const {
  605. for (auto inst :
  606. get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
  607. uint32_t decoration;
  608. if (inst->opcode() == spv::Op::OpDecorate ||
  609. inst->opcode() == spv::Op::OpDecorateId) {
  610. decoration = inst->GetSingleWordInOperand(1u);
  611. } else {
  612. assert(inst->opcode() == spv::Op::OpMemberDecorate);
  613. decoration = inst->GetSingleWordInOperand(2u);
  614. }
  615. switch (spv::Decoration(decoration)) {
  616. case spv::Decoration::RowMajor:
  617. case spv::Decoration::ColMajor:
  618. case spv::Decoration::ArrayStride:
  619. case spv::Decoration::MatrixStride:
  620. case spv::Decoration::CPacked:
  621. case spv::Decoration::Invariant:
  622. case spv::Decoration::Restrict:
  623. case spv::Decoration::Offset:
  624. case spv::Decoration::Alignment:
  625. case spv::Decoration::AlignmentId:
  626. case spv::Decoration::MaxByteOffset:
  627. case spv::Decoration::RelaxedPrecision:
  628. case spv::Decoration::AliasedPointer:
  629. case spv::Decoration::RestrictPointer:
  630. break;
  631. default:
  632. return false;
  633. }
  634. }
  635. return true;
  636. }
  637. bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
  638. for (auto inst :
  639. get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
  640. assert(inst->opcode() == spv::Op::OpDecorate);
  641. auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
  642. switch (decoration) {
  643. case spv::Decoration::Invariant:
  644. case spv::Decoration::Restrict:
  645. case spv::Decoration::Alignment:
  646. case spv::Decoration::AlignmentId:
  647. case spv::Decoration::MaxByteOffset:
  648. case spv::Decoration::AliasedPointer:
  649. case spv::Decoration::RestrictPointer:
  650. break;
  651. default:
  652. return false;
  653. }
  654. }
  655. return true;
  656. }
  657. bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
  658. VariableStats stats = {0, 0};
  659. bool ok = CheckUses(inst, &stats);
  660. // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
  661. // SRoA is costly, such as when the structure has many (unaccessed?)
  662. // members.
  663. return ok;
  664. }
  665. bool ScalarReplacementPass::CheckUses(const Instruction* inst,
  666. VariableStats* stats) const {
  667. uint64_t max_legal_index = GetMaxLegalIndex(inst);
  668. bool ok = true;
  669. get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
  670. const Instruction* user,
  671. uint32_t index) {
  672. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
  673. user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
  674. // TODO: include num_partial_accesses if it uses Fragment operation or
  675. // DebugValue has Indexes operand.
  676. stats->num_full_accesses++;
  677. return;
  678. }
  679. // Annotations are check as a group separately.
  680. if (!IsAnnotationInst(user->opcode())) {
  681. switch (user->opcode()) {
  682. case spv::Op::OpAccessChain:
  683. case spv::Op::OpInBoundsAccessChain:
  684. if (index == 2u && user->NumInOperands() > 1) {
  685. uint32_t id = user->GetSingleWordInOperand(1u);
  686. const Instruction* opInst = get_def_use_mgr()->GetDef(id);
  687. const auto* constant =
  688. context()->get_constant_mgr()->GetConstantFromInst(opInst);
  689. if (!constant) {
  690. ok = false;
  691. } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
  692. ok = false;
  693. } else {
  694. if (!CheckUsesRelaxed(user)) ok = false;
  695. }
  696. stats->num_partial_accesses++;
  697. } else {
  698. ok = false;
  699. }
  700. break;
  701. case spv::Op::OpLoad:
  702. if (!CheckLoad(user, index)) ok = false;
  703. stats->num_full_accesses++;
  704. break;
  705. case spv::Op::OpStore:
  706. if (!CheckStore(user, index)) ok = false;
  707. stats->num_full_accesses++;
  708. break;
  709. case spv::Op::OpName:
  710. case spv::Op::OpMemberName:
  711. break;
  712. default:
  713. ok = false;
  714. break;
  715. }
  716. }
  717. });
  718. return ok;
  719. }
  720. bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
  721. bool ok = true;
  722. get_def_use_mgr()->ForEachUse(
  723. inst, [this, &ok](const Instruction* user, uint32_t index) {
  724. switch (user->opcode()) {
  725. case spv::Op::OpAccessChain:
  726. case spv::Op::OpInBoundsAccessChain:
  727. if (index != 2u) {
  728. ok = false;
  729. } else {
  730. if (!CheckUsesRelaxed(user)) ok = false;
  731. }
  732. break;
  733. case spv::Op::OpLoad:
  734. if (!CheckLoad(user, index)) ok = false;
  735. break;
  736. case spv::Op::OpStore:
  737. if (!CheckStore(user, index)) ok = false;
  738. break;
  739. case spv::Op::OpImageTexelPointer:
  740. if (!CheckImageTexelPointer(index)) ok = false;
  741. break;
  742. case spv::Op::OpExtInst:
  743. if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare ||
  744. !CheckDebugDeclare(index))
  745. ok = false;
  746. break;
  747. default:
  748. ok = false;
  749. break;
  750. }
  751. });
  752. return ok;
  753. }
  754. bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
  755. return index == 2u;
  756. }
  757. bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
  758. uint32_t index) const {
  759. if (index != 2u) return false;
  760. if (inst->NumInOperands() >= 2 &&
  761. inst->GetSingleWordInOperand(1u) &
  762. uint32_t(spv::MemoryAccessMask::Volatile))
  763. return false;
  764. return true;
  765. }
  766. bool ScalarReplacementPass::CheckStore(const Instruction* inst,
  767. uint32_t index) const {
  768. if (index != 0u) return false;
  769. if (inst->NumInOperands() >= 3 &&
  770. inst->GetSingleWordInOperand(2u) &
  771. uint32_t(spv::MemoryAccessMask::Volatile))
  772. return false;
  773. return true;
  774. }
  775. bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const {
  776. if (index != kDebugDeclareOperandVariableIndex) return false;
  777. return true;
  778. }
  779. bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
  780. if (max_num_elements_ == 0) {
  781. return false;
  782. }
  783. return length > max_num_elements_;
  784. }
  785. std::unique_ptr<std::unordered_set<int64_t>>
  786. ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
  787. std::unique_ptr<std::unordered_set<int64_t>> result(
  788. new std::unordered_set<int64_t>());
  789. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  790. def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
  791. this](Instruction* use) {
  792. switch (use->opcode()) {
  793. case spv::Op::OpLoad: {
  794. // Look for extract from the load.
  795. std::vector<uint32_t> t;
  796. if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
  797. if (use2->opcode() != spv::Op::OpCompositeExtract ||
  798. use2->NumInOperands() <= 1) {
  799. return false;
  800. }
  801. t.push_back(use2->GetSingleWordInOperand(1));
  802. return true;
  803. })) {
  804. result->insert(t.begin(), t.end());
  805. return true;
  806. } else {
  807. result.reset(nullptr);
  808. return false;
  809. }
  810. }
  811. case spv::Op::OpName:
  812. case spv::Op::OpMemberName:
  813. case spv::Op::OpStore:
  814. // No components are used.
  815. return true;
  816. case spv::Op::OpAccessChain:
  817. case spv::Op::OpInBoundsAccessChain: {
  818. // Add the first index it if is a constant.
  819. // TODO: Could be improved by checking if the address is used in a load.
  820. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  821. uint32_t index_id = use->GetSingleWordInOperand(1);
  822. const analysis::Constant* index_const =
  823. const_mgr->FindDeclaredConstant(index_id);
  824. if (index_const) {
  825. result->insert(index_const->GetSignExtendedValue());
  826. return true;
  827. } else {
  828. // Could be any element. Assuming all are used.
  829. result.reset(nullptr);
  830. return false;
  831. }
  832. }
  833. default:
  834. // We do not know what is happening. Have to assume the worst.
  835. result.reset(nullptr);
  836. return false;
  837. }
  838. });
  839. return result;
  840. }
  841. uint64_t ScalarReplacementPass::GetMaxLegalIndex(
  842. const Instruction* var_inst) const {
  843. assert(var_inst->opcode() == spv::Op::OpVariable &&
  844. "|var_inst| must be a variable instruction.");
  845. Instruction* type = GetStorageType(var_inst);
  846. switch (type->opcode()) {
  847. case spv::Op::OpTypeStruct:
  848. return type->NumInOperands();
  849. case spv::Op::OpTypeArray:
  850. return GetArrayLength(type);
  851. case spv::Op::OpTypeMatrix:
  852. case spv::Op::OpTypeVector:
  853. return GetNumElements(type);
  854. default:
  855. return 0;
  856. }
  857. return 0;
  858. }
  859. void ScalarReplacementPass::CopyDecorationsToVariable(Instruction* from,
  860. Instruction* to,
  861. uint32_t member_index) {
  862. CopyPointerDecorationsToVariable(from, to);
  863. CopyNecessaryMemberDecorationsToVariable(from, to, member_index);
  864. }
  865. void ScalarReplacementPass::CopyPointerDecorationsToVariable(Instruction* from,
  866. Instruction* to) {
  867. // The RestrictPointer and AliasedPointer decorations are copied to all
  868. // members even if the new variable does not contain a pointer. It does
  869. // not hurt to do so.
  870. for (auto dec_inst :
  871. get_decoration_mgr()->GetDecorationsFor(from->result_id(), false)) {
  872. uint32_t decoration;
  873. decoration = dec_inst->GetSingleWordInOperand(1u);
  874. switch (spv::Decoration(decoration)) {
  875. case spv::Decoration::AliasedPointer:
  876. case spv::Decoration::RestrictPointer: {
  877. std::unique_ptr<Instruction> new_dec_inst(dec_inst->Clone(context()));
  878. new_dec_inst->SetInOperand(0, {to->result_id()});
  879. context()->AddAnnotationInst(std::move(new_dec_inst));
  880. } break;
  881. default:
  882. break;
  883. }
  884. }
  885. }
  886. void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable(
  887. Instruction* from, Instruction* to, uint32_t member_index) {
  888. Instruction* type_inst = GetStorageType(from);
  889. for (auto dec_inst :
  890. get_decoration_mgr()->GetDecorationsFor(type_inst->result_id(), false)) {
  891. uint32_t decoration;
  892. if (dec_inst->opcode() == spv::Op::OpMemberDecorate) {
  893. if (dec_inst->GetSingleWordInOperand(1) != member_index) {
  894. continue;
  895. }
  896. decoration = dec_inst->GetSingleWordInOperand(2u);
  897. switch (spv::Decoration(decoration)) {
  898. case spv::Decoration::ArrayStride:
  899. case spv::Decoration::Alignment:
  900. case spv::Decoration::AlignmentId:
  901. case spv::Decoration::MaxByteOffset:
  902. case spv::Decoration::MaxByteOffsetId:
  903. case spv::Decoration::RelaxedPrecision: {
  904. std::unique_ptr<Instruction> new_dec_inst(
  905. new Instruction(context(), spv::Op::OpDecorate, 0, 0, {}));
  906. new_dec_inst->AddOperand(
  907. Operand(SPV_OPERAND_TYPE_ID, {to->result_id()}));
  908. for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
  909. new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
  910. }
  911. context()->AddAnnotationInst(std::move(new_dec_inst));
  912. } break;
  913. default:
  914. break;
  915. }
  916. }
  917. }
  918. }
  919. } // namespace opt
  920. } // namespace spvtools