interface_var_sroa.cpp 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. // Copyright (c) 2022 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include "source/opt/interface_var_sroa.h"

  #include <iostream>
  #include <unordered_set>
  #include <vector>

  #include "source/opt/decoration_manager.h"
  #include "source/opt/def_use_manager.h"
  #include "source/opt/function.h"
  #include "source/opt/log.h"
  #include "source/opt/type_manager.h"
  #include "source/util/make_unique.h"
  // In-operand positions used by this pass. In-operand indices skip the
  // result type id and result id of an instruction.
  //
  // OpDecorate: <target> <decoration> <literal...>
  static constexpr uint32_t kOpDecorateDecorationInOperandIndex = 1;
  static constexpr uint32_t kOpDecorateLiteralInOperandIndex = 2;
  // OpEntryPoint: <execution model> <function> <name> <interface ids...>
  static constexpr uint32_t kOpEntryPointInOperandInterface = 3;
  // OpVariable: <storage class> [initializer]
  static constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0;
  // OpTypeArray: <element type> <length>
  static constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
  static constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1;
  // OpTypeMatrix: <column type> <column count>
  static constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
  static constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
  // OpTypePointer: <storage class> <pointee type>
  static constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1;
  // OpConstant: <value words...>
  static constexpr uint32_t kOpConstantValueInOperandIndex = 0;
  32. namespace spvtools {
  33. namespace opt {
  34. namespace {
  35. // Get the length of the OpTypeArray |array_type|.
  36. uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
  37. Instruction* array_type) {
  38. assert(array_type->opcode() == SpvOpTypeArray);
  39. uint32_t const_int_id =
  40. array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex);
  41. Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id);
  42. assert(array_length_inst->opcode() == SpvOpConstant);
  43. return array_length_inst->GetSingleWordInOperand(
  44. kOpConstantValueInOperandIndex);
  45. }
  46. // Get the element type instruction of the OpTypeArray |array_type|.
  47. Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr,
  48. Instruction* array_type) {
  49. assert(array_type->opcode() == SpvOpTypeArray);
  50. uint32_t elem_type_id =
  51. array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
  52. return def_use_mgr->GetDef(elem_type_id);
  53. }
  54. // Get the column type instruction of the OpTypeMatrix |matrix_type|.
  55. Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr,
  56. Instruction* matrix_type) {
  57. assert(matrix_type->opcode() == SpvOpTypeMatrix);
  58. uint32_t column_type_id =
  59. matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
  60. return def_use_mgr->GetDef(column_type_id);
  61. }
  62. // Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it
  63. // |depth_to_component| times recursively and returns the component type.
  64. // |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction.
  65. uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr,
  66. uint32_t type_id,
  67. uint32_t depth_to_component) {
  68. if (depth_to_component == 0) return type_id;
  69. Instruction* type_inst = def_use_mgr->GetDef(type_id);
  70. if (type_inst->opcode() == SpvOpTypeArray) {
  71. uint32_t elem_type_id =
  72. type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
  73. return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id,
  74. depth_to_component - 1);
  75. }
  76. assert(type_inst->opcode() == SpvOpTypeMatrix);
  77. uint32_t column_type_id =
  78. type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
  79. return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id,
  80. depth_to_component - 1);
  81. }
  82. // Creates an OpDecorate instruction whose Target is |var_id| and Decoration is
  83. // |decoration|. Adds |literal| as an extra operand of the instruction.
  84. void CreateDecoration(analysis::DecorationManager* decoration_mgr,
  85. uint32_t var_id, SpvDecoration decoration,
  86. uint32_t literal) {
  87. std::vector<Operand> operands({
  88. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
  89. {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION,
  90. {static_cast<uint32_t>(decoration)}},
  91. {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}},
  92. });
  93. decoration_mgr->AddDecoration(SpvOpDecorate, std::move(operands));
  94. }
  95. // Replaces load instructions with composite construct instructions in all the
  96. // users of the loads. |loads_to_composites| is the mapping from each load to
  97. // its corresponding OpCompositeConstruct.
  98. void ReplaceLoadWithCompositeConstruct(
  99. IRContext* context,
  100. const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) {
  101. for (const auto& load_and_composite : loads_to_composites) {
  102. Instruction* load = load_and_composite.first;
  103. Instruction* composite_construct = load_and_composite.second;
  104. std::vector<Instruction*> users;
  105. context->get_def_use_mgr()->ForEachUse(
  106. load, [&users, composite_construct](Instruction* user, uint32_t index) {
  107. user->GetOperand(index).words[0] = composite_construct->result_id();
  108. users.push_back(user);
  109. });
  110. for (Instruction* user : users)
  111. context->get_def_use_mgr()->AnalyzeInstUse(user);
  112. }
  113. }
  114. // Returns the storage class of the instruction |var|.
  115. SpvStorageClass GetStorageClass(Instruction* var) {
  116. return static_cast<SpvStorageClass>(
  117. var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex));
  118. }
  119. } // namespace
  120. bool InterfaceVariableScalarReplacement::HasExtraArrayness(
  121. Instruction& entry_point, Instruction* var) {
  122. SpvExecutionModel execution_model =
  123. static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
  124. if (execution_model != SpvExecutionModelTessellationEvaluation &&
  125. execution_model != SpvExecutionModelTessellationControl) {
  126. return false;
  127. }
  128. if (!context()->get_decoration_mgr()->HasDecoration(var->result_id(),
  129. SpvDecorationPatch)) {
  130. if (execution_model == SpvExecutionModelTessellationControl) return true;
  131. return GetStorageClass(var) != SpvStorageClassOutput;
  132. }
  133. return false;
  134. }
  135. bool InterfaceVariableScalarReplacement::
  136. CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
  137. bool has_extra_arrayness) {
  138. if (has_extra_arrayness) {
  139. return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var);
  140. }
  141. return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var);
  142. }
  143. bool InterfaceVariableScalarReplacement::GetVariableLocation(
  144. Instruction* var, uint32_t* location) {
  145. return !context()->get_decoration_mgr()->WhileEachDecoration(
  146. var->result_id(), SpvDecorationLocation,
  147. [location](const Instruction& inst) {
  148. *location =
  149. inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
  150. return false;
  151. });
  152. }
  153. bool InterfaceVariableScalarReplacement::GetVariableComponent(
  154. Instruction* var, uint32_t* component) {
  155. return !context()->get_decoration_mgr()->WhileEachDecoration(
  156. var->result_id(), SpvDecorationComponent,
  157. [component](const Instruction& inst) {
  158. *component =
  159. inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
  160. return false;
  161. });
  162. }
  163. std::vector<Instruction*>
  164. InterfaceVariableScalarReplacement::CollectInterfaceVariables(
  165. Instruction& entry_point) {
  166. std::vector<Instruction*> interface_vars;
  167. for (uint32_t i = kOpEntryPointInOperandInterface;
  168. i < entry_point.NumInOperands(); ++i) {
  169. Instruction* interface_var = context()->get_def_use_mgr()->GetDef(
  170. entry_point.GetSingleWordInOperand(i));
  171. assert(interface_var->opcode() == SpvOpVariable);
  172. SpvStorageClass storage_class = GetStorageClass(interface_var);
  173. if (storage_class != SpvStorageClassInput &&
  174. storage_class != SpvStorageClassOutput) {
  175. continue;
  176. }
  177. interface_vars.push_back(interface_var);
  178. }
  179. return interface_vars;
  180. }
  181. void InterfaceVariableScalarReplacement::KillInstructionAndUsers(
  182. Instruction* inst) {
  183. if (inst->opcode() == SpvOpEntryPoint) {
  184. return;
  185. }
  186. if (inst->opcode() != SpvOpAccessChain) {
  187. context()->KillInst(inst);
  188. return;
  189. }
  190. std::vector<Instruction*> users;
  191. context()->get_def_use_mgr()->ForEachUser(
  192. inst, [&users](Instruction* user) { users.push_back(user); });
  193. for (auto user : users) {
  194. context()->KillInst(user);
  195. }
  196. context()->KillInst(inst);
  197. }
  198. void InterfaceVariableScalarReplacement::KillInstructionsAndUsers(
  199. const std::vector<Instruction*>& insts) {
  200. for (Instruction* inst : insts) {
  201. KillInstructionAndUsers(inst);
  202. }
  203. }
  204. void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
  205. uint32_t var_id) {
  206. context()->get_decoration_mgr()->RemoveDecorationsFrom(
  207. var_id, [](const Instruction& inst) {
  208. uint32_t decoration =
  209. inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex);
  210. return decoration == SpvDecorationLocation ||
  211. decoration == SpvDecorationComponent;
  212. });
  213. }
  214. bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
  215. Instruction* interface_var, Instruction* interface_var_type,
  216. uint32_t location, uint32_t component, uint32_t extra_array_length) {
  217. NestedCompositeComponents scalar_interface_vars =
  218. CreateScalarInterfaceVarsForReplacement(interface_var_type,
  219. GetStorageClass(interface_var),
  220. extra_array_length);
  221. AddLocationAndComponentDecorations(scalar_interface_vars, &location,
  222. component);
  223. KillLocationAndComponentDecorations(interface_var->result_id());
  224. if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
  225. scalar_interface_vars)) {
  226. return false;
  227. }
  228. context()->KillInst(interface_var);
  229. return true;
  230. }
  231. bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
  232. Instruction* interface_var, uint32_t extra_array_length,
  233. const NestedCompositeComponents& scalar_interface_vars) {
  234. std::vector<Instruction*> users;
  235. context()->get_def_use_mgr()->ForEachUser(
  236. interface_var, [&users](Instruction* user) { users.push_back(user); });
  237. std::vector<uint32_t> interface_var_component_indices;
  238. std::unordered_map<Instruction*, Instruction*> loads_to_composites;
  239. std::unordered_map<Instruction*, Instruction*>
  240. loads_for_access_chain_to_composites;
  241. if (extra_array_length != 0) {
  242. // Note that the extra arrayness is the first dimension of the array
  243. // interface variable.
  244. for (uint32_t index = 0; index < extra_array_length; ++index) {
  245. std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
  246. if (!ReplaceComponentsOfInterfaceVarWith(
  247. interface_var, users, scalar_interface_vars,
  248. interface_var_component_indices, &index,
  249. &loads_to_component_values,
  250. &loads_for_access_chain_to_composites)) {
  251. return false;
  252. }
  253. AddComponentsToCompositesForLoads(loads_to_component_values,
  254. &loads_to_composites, 0);
  255. }
  256. } else if (!ReplaceComponentsOfInterfaceVarWith(
  257. interface_var, users, scalar_interface_vars,
  258. interface_var_component_indices, nullptr, &loads_to_composites,
  259. &loads_for_access_chain_to_composites)) {
  260. return false;
  261. }
  262. ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
  263. ReplaceLoadWithCompositeConstruct(context(),
  264. loads_for_access_chain_to_composites);
  265. KillInstructionsAndUsers(users);
  266. return true;
  267. }
  268. void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
  269. const NestedCompositeComponents& vars, uint32_t* location,
  270. uint32_t component) {
  271. if (!vars.HasMultipleComponents()) {
  272. uint32_t var_id = vars.GetComponentVariable()->result_id();
  273. CreateDecoration(context()->get_decoration_mgr(), var_id,
  274. SpvDecorationLocation, *location);
  275. CreateDecoration(context()->get_decoration_mgr(), var_id,
  276. SpvDecorationComponent, component);
  277. ++(*location);
  278. return;
  279. }
  280. for (const auto& var : vars.GetComponents()) {
  281. AddLocationAndComponentDecorations(var, location, component);
  282. }
  283. }
  284. bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
  285. Instruction* interface_var,
  286. const std::vector<Instruction*>& interface_var_users,
  287. const NestedCompositeComponents& scalar_interface_vars,
  288. std::vector<uint32_t>& interface_var_component_indices,
  289. const uint32_t* extra_array_index,
  290. std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
  291. std::unordered_map<Instruction*, Instruction*>*
  292. loads_for_access_chain_to_composites) {
  293. if (!scalar_interface_vars.HasMultipleComponents()) {
  294. for (Instruction* interface_var_user : interface_var_users) {
  295. if (!ReplaceComponentOfInterfaceVarWith(
  296. interface_var, interface_var_user,
  297. scalar_interface_vars.GetComponentVariable(),
  298. interface_var_component_indices, extra_array_index,
  299. loads_to_composites, loads_for_access_chain_to_composites)) {
  300. return false;
  301. }
  302. }
  303. return true;
  304. }
  305. return ReplaceMultipleComponentsOfInterfaceVarWith(
  306. interface_var, interface_var_users, scalar_interface_vars.GetComponents(),
  307. interface_var_component_indices, extra_array_index, loads_to_composites,
  308. loads_for_access_chain_to_composites);
  309. }
  310. bool InterfaceVariableScalarReplacement::
  311. ReplaceMultipleComponentsOfInterfaceVarWith(
  312. Instruction* interface_var,
  313. const std::vector<Instruction*>& interface_var_users,
  314. const std::vector<NestedCompositeComponents>& components,
  315. std::vector<uint32_t>& interface_var_component_indices,
  316. const uint32_t* extra_array_index,
  317. std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
  318. std::unordered_map<Instruction*, Instruction*>*
  319. loads_for_access_chain_to_composites) {
  320. for (uint32_t i = 0; i < components.size(); ++i) {
  321. interface_var_component_indices.push_back(i);
  322. std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
  323. std::unordered_map<Instruction*, Instruction*>
  324. loads_for_access_chain_to_component_values;
  325. if (!ReplaceComponentsOfInterfaceVarWith(
  326. interface_var, interface_var_users, components[i],
  327. interface_var_component_indices, extra_array_index,
  328. &loads_to_component_values,
  329. &loads_for_access_chain_to_component_values)) {
  330. return false;
  331. }
  332. interface_var_component_indices.pop_back();
  333. uint32_t depth_to_component =
  334. static_cast<uint32_t>(interface_var_component_indices.size());
  335. AddComponentsToCompositesForLoads(
  336. loads_for_access_chain_to_component_values,
  337. loads_for_access_chain_to_composites, depth_to_component);
  338. if (extra_array_index) ++depth_to_component;
  339. AddComponentsToCompositesForLoads(loads_to_component_values,
  340. loads_to_composites, depth_to_component);
  341. }
  342. return true;
  343. }
  344. bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
  345. Instruction* interface_var, Instruction* interface_var_user,
  346. Instruction* scalar_var,
  347. const std::vector<uint32_t>& interface_var_component_indices,
  348. const uint32_t* extra_array_index,
  349. std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
  350. std::unordered_map<Instruction*, Instruction*>*
  351. loads_for_access_chain_to_component_values) {
  352. SpvOp opcode = interface_var_user->opcode();
  353. if (opcode == SpvOpStore) {
  354. uint32_t value_id = interface_var_user->GetSingleWordInOperand(1);
  355. StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
  356. scalar_var, extra_array_index,
  357. interface_var_user);
  358. return true;
  359. }
  360. if (opcode == SpvOpLoad) {
  361. Instruction* scalar_load =
  362. LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
  363. loads_to_component_values->insert({interface_var_user, scalar_load});
  364. return true;
  365. }
  366. // Copy OpName and annotation instructions only once. Therefore, we create
  367. // them only for the first element of the extra array.
  368. if (extra_array_index && *extra_array_index != 0) return true;
  369. if (opcode == SpvOpDecorateId || opcode == SpvOpDecorateString ||
  370. opcode == SpvOpDecorate) {
  371. CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
  372. return true;
  373. }
  374. if (opcode == SpvOpName) {
  375. std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
  376. new_inst->SetInOperand(0, {scalar_var->result_id()});
  377. context()->AddDebug2Inst(std::move(new_inst));
  378. return true;
  379. }
  380. if (opcode == SpvOpEntryPoint) {
  381. return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
  382. scalar_var->result_id());
  383. }
  384. if (opcode == SpvOpAccessChain) {
  385. ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
  386. scalar_var,
  387. loads_for_access_chain_to_component_values);
  388. return true;
  389. }
  390. std::string message("Unhandled instruction");
  391. message += "\n " + interface_var_user->PrettyPrint(
  392. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  393. message +=
  394. "\nfor interface variable scalar replacement\n " +
  395. interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  396. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  397. return false;
  398. }
  399. void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
  400. Instruction* access_chain, Instruction* base_access_chain) {
  401. assert(base_access_chain->opcode() == SpvOpAccessChain &&
  402. access_chain->opcode() == SpvOpAccessChain &&
  403. access_chain->GetSingleWordInOperand(0) ==
  404. base_access_chain->result_id());
  405. Instruction::OperandList new_operands;
  406. for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) {
  407. new_operands.emplace_back(base_access_chain->GetInOperand(i));
  408. }
  409. for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
  410. new_operands.emplace_back(access_chain->GetInOperand(i));
  411. }
  412. access_chain->SetInOperands(std::move(new_operands));
  413. }
  414. Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
  415. uint32_t var_type_id, Instruction* var,
  416. const std::vector<uint32_t>& index_ids, Instruction* insert_before,
  417. uint32_t* component_type_id) {
  418. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  419. *component_type_id = GetComponentTypeOfArrayMatrix(
  420. def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size()));
  421. uint32_t ptr_type_id =
  422. GetPointerType(*component_type_id, GetStorageClass(var));
  423. std::unique_ptr<Instruction> new_access_chain(
  424. new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
  425. std::initializer_list<Operand>{
  426. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  427. for (uint32_t index_id : index_ids) {
  428. new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
  429. }
  430. Instruction* inst = new_access_chain.get();
  431. def_use_mgr->AnalyzeInstDefUse(inst);
  432. insert_before->InsertBefore(std::move(new_access_chain));
  433. return inst;
  434. }
  435. Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
  436. uint32_t component_type_id, Instruction* var, uint32_t index,
  437. Instruction* insert_before) {
  438. uint32_t ptr_type_id =
  439. GetPointerType(component_type_id, GetStorageClass(var));
  440. uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index);
  441. std::unique_ptr<Instruction> new_access_chain(
  442. new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
  443. std::initializer_list<Operand>{
  444. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  445. {SPV_OPERAND_TYPE_ID, {index_id}},
  446. }));
  447. Instruction* inst = new_access_chain.get();
  448. context()->get_def_use_mgr()->AnalyzeInstDefUse(inst);
  449. insert_before->InsertBefore(std::move(new_access_chain));
  450. return inst;
  451. }
  452. void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
  453. Instruction* access_chain,
  454. const std::vector<uint32_t>& interface_var_component_indices,
  455. Instruction* scalar_var,
  456. std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
  457. std::vector<uint32_t> indexes;
  458. for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
  459. indexes.push_back(access_chain->GetSingleWordInOperand(i));
  460. }
  461. // Note that we have a strong assumption that |access_chain| has only a single
  462. // index that is for the extra arrayness.
  463. context()->get_def_use_mgr()->ForEachUser(
  464. access_chain,
  465. [this, access_chain, &indexes, &interface_var_component_indices,
  466. scalar_var, loads_to_component_values](Instruction* user) {
  467. switch (user->opcode()) {
  468. case SpvOpAccessChain: {
  469. UseBaseAccessChainForAccessChain(user, access_chain);
  470. ReplaceAccessChainWith(user, interface_var_component_indices,
  471. scalar_var, loads_to_component_values);
  472. return;
  473. }
  474. case SpvOpStore: {
  475. uint32_t value_id = user->GetSingleWordInOperand(1);
  476. StoreComponentOfValueToAccessChainToScalarVar(
  477. value_id, interface_var_component_indices, scalar_var, indexes,
  478. user);
  479. return;
  480. }
  481. case SpvOpLoad: {
  482. Instruction* value =
  483. LoadAccessChainToVar(scalar_var, indexes, user);
  484. loads_to_component_values->insert({user, value});
  485. return;
  486. }
  487. default:
  488. break;
  489. }
  490. });
  491. }
  492. void InterfaceVariableScalarReplacement::CloneAnnotationForVariable(
  493. Instruction* annotation_inst, uint32_t var_id) {
  494. assert(annotation_inst->opcode() == SpvOpDecorate ||
  495. annotation_inst->opcode() == SpvOpDecorateId ||
  496. annotation_inst->opcode() == SpvOpDecorateString);
  497. std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context()));
  498. new_inst->SetInOperand(0, {var_id});
  499. context()->AddAnnotationInst(std::move(new_inst));
  500. }
  501. bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint(
  502. Instruction* interface_var, Instruction* entry_point,
  503. uint32_t scalar_var_id) {
  504. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  505. uint32_t interface_var_id = interface_var->result_id();
  506. if (interface_vars_removed_from_entry_point_operands_.find(
  507. interface_var_id) !=
  508. interface_vars_removed_from_entry_point_operands_.end()) {
  509. entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}});
  510. def_use_mgr->AnalyzeInstUse(entry_point);
  511. return true;
  512. }
  513. bool success = !entry_point->WhileEachInId(
  514. [&interface_var_id, &scalar_var_id](uint32_t* id) {
  515. if (*id == interface_var_id) {
  516. *id = scalar_var_id;
  517. return false;
  518. }
  519. return true;
  520. });
  521. if (!success) {
  522. std::string message(
  523. "interface variable is not an operand of the entry point");
  524. message += "\n " + interface_var->PrettyPrint(
  525. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  526. message += "\n " + entry_point->PrettyPrint(
  527. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  528. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  529. return false;
  530. }
  531. def_use_mgr->AnalyzeInstUse(entry_point);
  532. interface_vars_removed_from_entry_point_operands_.insert(interface_var_id);
  533. return true;
  534. }
  535. uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar(
  536. Instruction* var) {
  537. assert(var->opcode() == SpvOpVariable);
  538. uint32_t ptr_type_id = var->type_id();
  539. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  540. Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id);
  541. assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
  542. "Variable must have a pointer type.");
  543. return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex);
  544. }
  545. void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
  546. uint32_t value_id, const std::vector<uint32_t>& component_indices,
  547. Instruction* scalar_var, const uint32_t* extra_array_index,
  548. Instruction* insert_before) {
  549. uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
  550. Instruction* ptr = scalar_var;
  551. if (extra_array_index) {
  552. auto* ty_mgr = context()->get_type_mgr();
  553. analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
  554. assert(array_type != nullptr);
  555. component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
  556. ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
  557. *extra_array_index, insert_before);
  558. }
  559. StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
  560. extra_array_index, insert_before);
  561. }
  562. Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
  563. Instruction* scalar_var, const uint32_t* extra_array_index,
  564. Instruction* insert_before) {
  565. uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
  566. Instruction* ptr = scalar_var;
  567. if (extra_array_index) {
  568. auto* ty_mgr = context()->get_type_mgr();
  569. analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
  570. assert(array_type != nullptr);
  571. component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
  572. ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
  573. *extra_array_index, insert_before);
  574. }
  575. return CreateLoad(component_type_id, ptr, insert_before);
  576. }
  577. Instruction* InterfaceVariableScalarReplacement::CreateLoad(
  578. uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
  579. std::unique_ptr<Instruction> load(
  580. new Instruction(context(), SpvOpLoad, type_id, TakeNextId(),
  581. std::initializer_list<Operand>{
  582. {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
  583. Instruction* load_inst = load.get();
  584. context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst);
  585. insert_before->InsertBefore(std::move(load));
  586. return load_inst;
  587. }
  588. void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
  589. uint32_t component_type_id, uint32_t value_id,
  590. const std::vector<uint32_t>& component_indices, Instruction* ptr,
  591. const uint32_t* extra_array_index, Instruction* insert_before) {
  592. std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
  593. component_type_id, value_id, component_indices, extra_array_index));
  594. std::unique_ptr<Instruction> new_store(
  595. new Instruction(context(), SpvOpStore));
  596. new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}});
  597. new_store->AddOperand(
  598. {SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}});
  599. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  600. def_use_mgr->AnalyzeInstDefUse(composite_extract.get());
  601. def_use_mgr->AnalyzeInstDefUse(new_store.get());
  602. insert_before->InsertBefore(std::move(composite_extract));
  603. insert_before->InsertBefore(std::move(new_store));
  604. }
  605. Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
  606. uint32_t type_id, uint32_t composite_id,
  607. const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
  608. uint32_t component_id = TakeNextId();
  609. Instruction* composite_extract = new Instruction(
  610. context(), SpvOpCompositeExtract, type_id, component_id,
  611. std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
  612. if (extra_first_index) {
  613. composite_extract->AddOperand(
  614. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}});
  615. }
  616. for (uint32_t index : indexes) {
  617. composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}});
  618. }
  619. return composite_extract;
  620. }
  621. void InterfaceVariableScalarReplacement::
  622. StoreComponentOfValueToAccessChainToScalarVar(
  623. uint32_t value_id, const std::vector<uint32_t>& component_indices,
  624. Instruction* scalar_var,
  625. const std::vector<uint32_t>& access_chain_indices,
  626. Instruction* insert_before) {
  627. uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
  628. Instruction* ptr = scalar_var;
  629. if (!access_chain_indices.empty()) {
  630. ptr = CreateAccessChainToVar(component_type_id, scalar_var,
  631. access_chain_indices, insert_before,
  632. &component_type_id);
  633. }
  634. StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
  635. nullptr, insert_before);
  636. }
  637. Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
  638. Instruction* var, const std::vector<uint32_t>& indexes,
  639. Instruction* insert_before) {
  640. uint32_t component_type_id = GetPointeeTypeIdOfVar(var);
  641. Instruction* ptr = var;
  642. if (!indexes.empty()) {
  643. ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before,
  644. &component_type_id);
  645. }
  646. return CreateLoad(component_type_id, ptr, insert_before);
  647. }
  648. Instruction*
  649. InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
  650. Instruction* load, uint32_t depth_to_component) {
  651. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  652. uint32_t type_id = load->type_id();
  653. if (depth_to_component != 0) {
  654. type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
  655. depth_to_component);
  656. }
  657. uint32_t new_id = context()->TakeNextId();
  658. std::unique_ptr<Instruction> new_composite_construct(
  659. new Instruction(context(), SpvOpCompositeConstruct, type_id, new_id, {}));
  660. Instruction* composite_construct = new_composite_construct.get();
  661. def_use_mgr->AnalyzeInstDefUse(composite_construct);
  662. // Insert |new_composite_construct| after |load|. When there are multiple
  663. // recursive composite construct instructions for a load, we have to place the
  664. // composite construct with a lower depth later because it constructs the
  665. // composite that contains other composites with lower depths.
  666. auto* insert_before = load->NextNode();
  667. while (true) {
  668. auto itr =
  669. composite_ids_to_component_depths.find(insert_before->result_id());
  670. if (itr == composite_ids_to_component_depths.end()) break;
  671. if (itr->second <= depth_to_component) break;
  672. insert_before = insert_before->NextNode();
  673. }
  674. insert_before->InsertBefore(std::move(new_composite_construct));
  675. composite_ids_to_component_depths.insert({new_id, depth_to_component});
  676. return composite_construct;
  677. }
  678. void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
  679. const std::unordered_map<Instruction*, Instruction*>&
  680. loads_to_component_values,
  681. std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
  682. uint32_t depth_to_component) {
  683. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  684. for (auto& load_and_component_vale : loads_to_component_values) {
  685. Instruction* load = load_and_component_vale.first;
  686. Instruction* component_value = load_and_component_vale.second;
  687. Instruction* composite_construct = nullptr;
  688. auto itr = loads_to_composites->find(load);
  689. if (itr == loads_to_composites->end()) {
  690. composite_construct =
  691. CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
  692. loads_to_composites->insert({load, composite_construct});
  693. } else {
  694. composite_construct = itr->second;
  695. }
  696. composite_construct->AddOperand(
  697. {SPV_OPERAND_TYPE_ID, {component_value->result_id()}});
  698. def_use_mgr->AnalyzeInstDefUse(composite_construct);
  699. }
  700. }
  701. uint32_t InterfaceVariableScalarReplacement::GetArrayType(
  702. uint32_t elem_type_id, uint32_t array_length) {
  703. analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
  704. uint32_t array_length_id =
  705. context()->get_constant_mgr()->GetUIntConst(array_length);
  706. analysis::Array array_type(
  707. elem_type,
  708. analysis::Array::LengthInfo{array_length_id, {0, array_length}});
  709. return context()->get_type_mgr()->GetTypeInstruction(&array_type);
  710. }
  711. uint32_t InterfaceVariableScalarReplacement::GetPointerType(
  712. uint32_t type_id, SpvStorageClass storage_class) {
  713. analysis::Type* type = context()->get_type_mgr()->GetType(type_id);
  714. analysis::Pointer ptr_type(type, storage_class);
  715. return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
  716. }
  717. InterfaceVariableScalarReplacement::NestedCompositeComponents
  718. InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
  719. Instruction* interface_var_type, SpvStorageClass storage_class,
  720. uint32_t extra_array_length) {
  721. assert(interface_var_type->opcode() == SpvOpTypeArray);
  722. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  723. uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type);
  724. Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type);
  725. NestedCompositeComponents scalar_vars;
  726. while (array_length > 0) {
  727. NestedCompositeComponents scalar_vars_for_element =
  728. CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
  729. extra_array_length);
  730. scalar_vars.AddComponent(scalar_vars_for_element);
  731. --array_length;
  732. }
  733. return scalar_vars;
  734. }
  735. InterfaceVariableScalarReplacement::NestedCompositeComponents
  736. InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
  737. Instruction* interface_var_type, SpvStorageClass storage_class,
  738. uint32_t extra_array_length) {
  739. assert(interface_var_type->opcode() == SpvOpTypeMatrix);
  740. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  741. uint32_t column_count = interface_var_type->GetSingleWordInOperand(
  742. kOpTypeMatrixColCountInOperandIndex);
  743. Instruction* column_type =
  744. GetMatrixColumnType(def_use_mgr, interface_var_type);
  745. NestedCompositeComponents scalar_vars;
  746. while (column_count > 0) {
  747. NestedCompositeComponents scalar_vars_for_column =
  748. CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
  749. extra_array_length);
  750. scalar_vars.AddComponent(scalar_vars_for_column);
  751. --column_count;
  752. }
  753. return scalar_vars;
  754. }
  755. InterfaceVariableScalarReplacement::NestedCompositeComponents
  756. InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
  757. Instruction* interface_var_type, SpvStorageClass storage_class,
  758. uint32_t extra_array_length) {
  759. // Handle array case.
  760. if (interface_var_type->opcode() == SpvOpTypeArray) {
  761. return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class,
  762. extra_array_length);
  763. }
  764. // Handle matrix case.
  765. if (interface_var_type->opcode() == SpvOpTypeMatrix) {
  766. return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class,
  767. extra_array_length);
  768. }
  769. // Handle scalar or vector case.
  770. NestedCompositeComponents scalar_var;
  771. uint32_t type_id = interface_var_type->result_id();
  772. if (extra_array_length != 0) {
  773. type_id = GetArrayType(type_id, extra_array_length);
  774. }
  775. uint32_t ptr_type_id =
  776. context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
  777. uint32_t id = TakeNextId();
  778. std::unique_ptr<Instruction> variable(
  779. new Instruction(context(), SpvOpVariable, ptr_type_id, id,
  780. std::initializer_list<Operand>{
  781. {SPV_OPERAND_TYPE_STORAGE_CLASS,
  782. {static_cast<uint32_t>(storage_class)}}}));
  783. scalar_var.SetSingleComponentVariable(variable.get());
  784. context()->AddGlobalValue(std::move(variable));
  785. return scalar_var;
  786. }
  787. Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable(
  788. Instruction* var) {
  789. uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var);
  790. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  791. return def_use_mgr->GetDef(pointee_type_id);
  792. }
  793. Pass::Status InterfaceVariableScalarReplacement::Process() {
  794. Pass::Status status = Status::SuccessWithoutChange;
  795. for (Instruction& entry_point : get_module()->entry_points()) {
  796. status =
  797. CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point));
  798. }
  799. return status;
  800. }
  801. bool InterfaceVariableScalarReplacement::
  802. ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) {
  803. if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end())
  804. return false;
  805. std::string message(
  806. "A variable is arrayed for an entry point but it is not "
  807. "arrayed for another entry point");
  808. message +=
  809. "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  810. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  811. return true;
  812. }
  813. bool InterfaceVariableScalarReplacement::
  814. ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) {
  815. if (vars_without_extra_arrayness.find(var) ==
  816. vars_without_extra_arrayness.end())
  817. return false;
  818. std::string message(
  819. "A variable is not arrayed for an entry point but it is "
  820. "arrayed for another entry point");
  821. message +=
  822. "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  823. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  824. return true;
  825. }
  826. Pass::Status
  827. InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
  828. Instruction& entry_point) {
  829. std::vector<Instruction*> interface_vars =
  830. CollectInterfaceVariables(entry_point);
  831. Pass::Status status = Status::SuccessWithoutChange;
  832. for (Instruction* interface_var : interface_vars) {
  833. uint32_t location, component;
  834. if (!GetVariableLocation(interface_var, &location)) continue;
  835. if (!GetVariableComponent(interface_var, &component)) component = 0;
  836. Instruction* interface_var_type = GetTypeOfVariable(interface_var);
  837. uint32_t extra_array_length = 0;
  838. if (HasExtraArrayness(entry_point, interface_var)) {
  839. extra_array_length =
  840. GetArrayLength(context()->get_def_use_mgr(), interface_var_type);
  841. interface_var_type =
  842. GetArrayElementType(context()->get_def_use_mgr(), interface_var_type);
  843. vars_with_extra_arrayness.insert(interface_var);
  844. } else {
  845. vars_without_extra_arrayness.insert(interface_var);
  846. }
  847. if (!CheckExtraArraynessConflictBetweenEntries(interface_var,
  848. extra_array_length != 0)) {
  849. return Pass::Status::Failure;
  850. }
  851. if (interface_var_type->opcode() != SpvOpTypeArray &&
  852. interface_var_type->opcode() != SpvOpTypeMatrix) {
  853. continue;
  854. }
  855. if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type,
  856. location, component,
  857. extra_array_length)) {
  858. return Pass::Status::Failure;
  859. }
  860. status = Pass::Status::SuccessWithChange;
  861. }
  862. return status;
  863. }
  864. } // namespace opt
  865. } // namespace spvtools