interface_var_sroa.cpp 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. // Copyright (c) 2022 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/interface_var_sroa.h"
  15. #include <iostream>
  16. #include "source/opt/decoration_manager.h"
  17. #include "source/opt/def_use_manager.h"
  18. #include "source/opt/function.h"
  19. #include "source/opt/log.h"
  20. #include "source/opt/type_manager.h"
  21. #include "source/util/make_unique.h"
  22. namespace spvtools {
  23. namespace opt {
  24. namespace {
  // In-operand positions (operand index after any result type / result id)
  // of the instruction fields this pass reads, grouped by opcode.

  // OpDecorate: <target> <decoration> <literal...>
  constexpr uint32_t kOpDecorateDecorationInOperandIndex{1};
  constexpr uint32_t kOpDecorateLiteralInOperandIndex{2};
  // OpEntryPoint: <execution model> <function> <name> <interface ids...>
  constexpr uint32_t kOpEntryPointInOperandInterface{3};
  // OpVariable: <storage class> [initializer]
  constexpr uint32_t kOpVariableStorageClassInOperandIndex{0};
  // OpTypeArray: <element type> <length>
  constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex{0};
  constexpr uint32_t kOpTypeArrayLengthInOperandIndex{1};
  // OpTypeMatrix: <column type> <column count>
  constexpr uint32_t kOpTypeMatrixColCountInOperandIndex{1};
  constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex{0};
  // OpTypePointer: <storage class> <pointee type>
  constexpr uint32_t kOpTypePtrTypeInOperandIndex{1};
  // OpConstant: <value>
  constexpr uint32_t kOpConstantValueInOperandIndex{0};
  35. // Get the length of the OpTypeArray |array_type|.
  36. uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
  37. Instruction* array_type) {
  38. assert(array_type->opcode() == spv::Op::OpTypeArray);
  39. uint32_t const_int_id =
  40. array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex);
  41. Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id);
  42. assert(array_length_inst->opcode() == spv::Op::OpConstant);
  43. return array_length_inst->GetSingleWordInOperand(
  44. kOpConstantValueInOperandIndex);
  45. }
  46. // Get the element type instruction of the OpTypeArray |array_type|.
  47. Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr,
  48. Instruction* array_type) {
  49. assert(array_type->opcode() == spv::Op::OpTypeArray);
  50. uint32_t elem_type_id =
  51. array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
  52. return def_use_mgr->GetDef(elem_type_id);
  53. }
  54. // Get the column type instruction of the OpTypeMatrix |matrix_type|.
  55. Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr,
  56. Instruction* matrix_type) {
  57. assert(matrix_type->opcode() == spv::Op::OpTypeMatrix);
  58. uint32_t column_type_id =
  59. matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
  60. return def_use_mgr->GetDef(column_type_id);
  61. }
  62. // Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it
  63. // |depth_to_component| times recursively and returns the component type.
  64. // |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction.
  65. uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr,
  66. uint32_t type_id,
  67. uint32_t depth_to_component) {
  68. if (depth_to_component == 0) return type_id;
  69. Instruction* type_inst = def_use_mgr->GetDef(type_id);
  70. if (type_inst->opcode() == spv::Op::OpTypeArray) {
  71. uint32_t elem_type_id =
  72. type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
  73. return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id,
  74. depth_to_component - 1);
  75. }
  76. assert(type_inst->opcode() == spv::Op::OpTypeMatrix);
  77. uint32_t column_type_id =
  78. type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
  79. return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id,
  80. depth_to_component - 1);
  81. }
  82. // Creates an OpDecorate instruction whose Target is |var_id| and Decoration is
  83. // |decoration|. Adds |literal| as an extra operand of the instruction.
  84. void CreateDecoration(analysis::DecorationManager* decoration_mgr,
  85. uint32_t var_id, spv::Decoration decoration,
  86. uint32_t literal) {
  87. std::vector<Operand> operands({
  88. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
  89. {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION,
  90. {static_cast<uint32_t>(decoration)}},
  91. {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}},
  92. });
  93. decoration_mgr->AddDecoration(spv::Op::OpDecorate, std::move(operands));
  94. }
  95. // Replaces load instructions with composite construct instructions in all the
  96. // users of the loads. |loads_to_composites| is the mapping from each load to
  97. // its corresponding OpCompositeConstruct.
  98. void ReplaceLoadWithCompositeConstruct(
  99. IRContext* context,
  100. const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) {
  101. for (const auto& load_and_composite : loads_to_composites) {
  102. Instruction* load = load_and_composite.first;
  103. Instruction* composite_construct = load_and_composite.second;
  104. std::vector<Instruction*> users;
  105. context->get_def_use_mgr()->ForEachUse(
  106. load, [&users, composite_construct](Instruction* user, uint32_t index) {
  107. user->GetOperand(index).words[0] = composite_construct->result_id();
  108. users.push_back(user);
  109. });
  110. for (Instruction* user : users)
  111. context->get_def_use_mgr()->AnalyzeInstUse(user);
  112. }
  113. }
  114. // Returns the storage class of the instruction |var|.
  115. spv::StorageClass GetStorageClass(Instruction* var) {
  116. return static_cast<spv::StorageClass>(
  117. var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex));
  118. }
  119. } // namespace
  120. bool InterfaceVariableScalarReplacement::HasExtraArrayness(
  121. Instruction& entry_point, Instruction* var) {
  122. spv::ExecutionModel execution_model =
  123. static_cast<spv::ExecutionModel>(entry_point.GetSingleWordInOperand(0));
  124. if (execution_model != spv::ExecutionModel::TessellationEvaluation &&
  125. execution_model != spv::ExecutionModel::TessellationControl) {
  126. return false;
  127. }
  128. if (!context()->get_decoration_mgr()->HasDecoration(
  129. var->result_id(), uint32_t(spv::Decoration::Patch))) {
  130. if (execution_model == spv::ExecutionModel::TessellationControl)
  131. return true;
  132. return GetStorageClass(var) != spv::StorageClass::Output;
  133. }
  134. return false;
  135. }
  136. bool InterfaceVariableScalarReplacement::
  137. CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
  138. bool has_extra_arrayness) {
  139. if (has_extra_arrayness) {
  140. return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var);
  141. }
  142. return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var);
  143. }
  144. bool InterfaceVariableScalarReplacement::GetVariableLocation(
  145. Instruction* var, uint32_t* location) {
  146. return !context()->get_decoration_mgr()->WhileEachDecoration(
  147. var->result_id(), uint32_t(spv::Decoration::Location),
  148. [location](const Instruction& inst) {
  149. *location =
  150. inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
  151. return false;
  152. });
  153. }
  154. bool InterfaceVariableScalarReplacement::GetVariableComponent(
  155. Instruction* var, uint32_t* component) {
  156. return !context()->get_decoration_mgr()->WhileEachDecoration(
  157. var->result_id(), uint32_t(spv::Decoration::Component),
  158. [component](const Instruction& inst) {
  159. *component =
  160. inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
  161. return false;
  162. });
  163. }
  164. std::vector<Instruction*>
  165. InterfaceVariableScalarReplacement::CollectInterfaceVariables(
  166. Instruction& entry_point) {
  167. std::vector<Instruction*> interface_vars;
  168. for (uint32_t i = kOpEntryPointInOperandInterface;
  169. i < entry_point.NumInOperands(); ++i) {
  170. Instruction* interface_var = context()->get_def_use_mgr()->GetDef(
  171. entry_point.GetSingleWordInOperand(i));
  172. assert(interface_var->opcode() == spv::Op::OpVariable);
  173. spv::StorageClass storage_class = GetStorageClass(interface_var);
  174. if (storage_class != spv::StorageClass::Input &&
  175. storage_class != spv::StorageClass::Output) {
  176. continue;
  177. }
  178. interface_vars.push_back(interface_var);
  179. }
  180. return interface_vars;
  181. }
  182. void InterfaceVariableScalarReplacement::KillInstructionAndUsers(
  183. Instruction* inst) {
  184. if (inst->opcode() == spv::Op::OpEntryPoint) {
  185. return;
  186. }
  187. if (inst->opcode() != spv::Op::OpAccessChain) {
  188. context()->KillInst(inst);
  189. return;
  190. }
  191. std::vector<Instruction*> users;
  192. context()->get_def_use_mgr()->ForEachUser(
  193. inst, [&users](Instruction* user) { users.push_back(user); });
  194. for (auto user : users) {
  195. context()->KillInst(user);
  196. }
  197. context()->KillInst(inst);
  198. }
  199. void InterfaceVariableScalarReplacement::KillInstructionsAndUsers(
  200. const std::vector<Instruction*>& insts) {
  201. for (Instruction* inst : insts) {
  202. KillInstructionAndUsers(inst);
  203. }
  204. }
  205. void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
  206. uint32_t var_id) {
  207. context()->get_decoration_mgr()->RemoveDecorationsFrom(
  208. var_id, [](const Instruction& inst) {
  209. spv::Decoration decoration = spv::Decoration(
  210. inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex));
  211. return decoration == spv::Decoration::Location ||
  212. decoration == spv::Decoration::Component;
  213. });
  214. }
  215. bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
  216. Instruction* interface_var, Instruction* interface_var_type,
  217. uint32_t location, uint32_t component, uint32_t extra_array_length) {
  218. NestedCompositeComponents scalar_interface_vars =
  219. CreateScalarInterfaceVarsForReplacement(interface_var_type,
  220. GetStorageClass(interface_var),
  221. extra_array_length);
  222. AddLocationAndComponentDecorations(scalar_interface_vars, &location,
  223. component);
  224. KillLocationAndComponentDecorations(interface_var->result_id());
  225. if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
  226. scalar_interface_vars)) {
  227. return false;
  228. }
  229. context()->KillInst(interface_var);
  230. return true;
  231. }
  232. bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
  233. Instruction* interface_var, uint32_t extra_array_length,
  234. const NestedCompositeComponents& scalar_interface_vars) {
  235. std::vector<Instruction*> users;
  236. context()->get_def_use_mgr()->ForEachUser(
  237. interface_var, [&users](Instruction* user) { users.push_back(user); });
  238. std::vector<uint32_t> interface_var_component_indices;
  239. std::unordered_map<Instruction*, Instruction*> loads_to_composites;
  240. std::unordered_map<Instruction*, Instruction*>
  241. loads_for_access_chain_to_composites;
  242. if (extra_array_length != 0) {
  243. // Note that the extra arrayness is the first dimension of the array
  244. // interface variable.
  245. for (uint32_t index = 0; index < extra_array_length; ++index) {
  246. std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
  247. if (!ReplaceComponentsOfInterfaceVarWith(
  248. interface_var, users, scalar_interface_vars,
  249. interface_var_component_indices, &index,
  250. &loads_to_component_values,
  251. &loads_for_access_chain_to_composites)) {
  252. return false;
  253. }
  254. AddComponentsToCompositesForLoads(loads_to_component_values,
  255. &loads_to_composites, 0);
  256. }
  257. } else if (!ReplaceComponentsOfInterfaceVarWith(
  258. interface_var, users, scalar_interface_vars,
  259. interface_var_component_indices, nullptr, &loads_to_composites,
  260. &loads_for_access_chain_to_composites)) {
  261. return false;
  262. }
  263. ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
  264. ReplaceLoadWithCompositeConstruct(context(),
  265. loads_for_access_chain_to_composites);
  266. KillInstructionsAndUsers(users);
  267. return true;
  268. }
  269. void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
  270. const NestedCompositeComponents& vars, uint32_t* location,
  271. uint32_t component) {
  272. if (!vars.HasMultipleComponents()) {
  273. uint32_t var_id = vars.GetComponentVariable()->result_id();
  274. CreateDecoration(context()->get_decoration_mgr(), var_id,
  275. spv::Decoration::Location, *location);
  276. CreateDecoration(context()->get_decoration_mgr(), var_id,
  277. spv::Decoration::Component, component);
  278. ++(*location);
  279. return;
  280. }
  281. for (const auto& var : vars.GetComponents()) {
  282. AddLocationAndComponentDecorations(var, location, component);
  283. }
  284. }
  285. bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
  286. Instruction* interface_var,
  287. const std::vector<Instruction*>& interface_var_users,
  288. const NestedCompositeComponents& scalar_interface_vars,
  289. std::vector<uint32_t>& interface_var_component_indices,
  290. const uint32_t* extra_array_index,
  291. std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
  292. std::unordered_map<Instruction*, Instruction*>*
  293. loads_for_access_chain_to_composites) {
  294. if (!scalar_interface_vars.HasMultipleComponents()) {
  295. for (Instruction* interface_var_user : interface_var_users) {
  296. if (!ReplaceComponentOfInterfaceVarWith(
  297. interface_var, interface_var_user,
  298. scalar_interface_vars.GetComponentVariable(),
  299. interface_var_component_indices, extra_array_index,
  300. loads_to_composites, loads_for_access_chain_to_composites)) {
  301. return false;
  302. }
  303. }
  304. return true;
  305. }
  306. return ReplaceMultipleComponentsOfInterfaceVarWith(
  307. interface_var, interface_var_users, scalar_interface_vars.GetComponents(),
  308. interface_var_component_indices, extra_array_index, loads_to_composites,
  309. loads_for_access_chain_to_composites);
  310. }
  311. bool InterfaceVariableScalarReplacement::
  312. ReplaceMultipleComponentsOfInterfaceVarWith(
  313. Instruction* interface_var,
  314. const std::vector<Instruction*>& interface_var_users,
  315. const std::vector<NestedCompositeComponents>& components,
  316. std::vector<uint32_t>& interface_var_component_indices,
  317. const uint32_t* extra_array_index,
  318. std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
  319. std::unordered_map<Instruction*, Instruction*>*
  320. loads_for_access_chain_to_composites) {
  321. for (uint32_t i = 0; i < components.size(); ++i) {
  322. interface_var_component_indices.push_back(i);
  323. std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
  324. std::unordered_map<Instruction*, Instruction*>
  325. loads_for_access_chain_to_component_values;
  326. if (!ReplaceComponentsOfInterfaceVarWith(
  327. interface_var, interface_var_users, components[i],
  328. interface_var_component_indices, extra_array_index,
  329. &loads_to_component_values,
  330. &loads_for_access_chain_to_component_values)) {
  331. return false;
  332. }
  333. interface_var_component_indices.pop_back();
  334. uint32_t depth_to_component =
  335. static_cast<uint32_t>(interface_var_component_indices.size());
  336. AddComponentsToCompositesForLoads(
  337. loads_for_access_chain_to_component_values,
  338. loads_for_access_chain_to_composites, depth_to_component);
  339. if (extra_array_index) ++depth_to_component;
  340. AddComponentsToCompositesForLoads(loads_to_component_values,
  341. loads_to_composites, depth_to_component);
  342. }
  343. return true;
  344. }
  345. bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
  346. Instruction* interface_var, Instruction* interface_var_user,
  347. Instruction* scalar_var,
  348. const std::vector<uint32_t>& interface_var_component_indices,
  349. const uint32_t* extra_array_index,
  350. std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
  351. std::unordered_map<Instruction*, Instruction*>*
  352. loads_for_access_chain_to_component_values) {
  353. spv::Op opcode = interface_var_user->opcode();
  354. if (opcode == spv::Op::OpStore) {
  355. uint32_t value_id = interface_var_user->GetSingleWordInOperand(1);
  356. StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
  357. scalar_var, extra_array_index,
  358. interface_var_user);
  359. return true;
  360. }
  361. if (opcode == spv::Op::OpLoad) {
  362. Instruction* scalar_load =
  363. LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
  364. loads_to_component_values->insert({interface_var_user, scalar_load});
  365. return true;
  366. }
  367. // Copy OpName and annotation instructions only once. Therefore, we create
  368. // them only for the first element of the extra array.
  369. if (extra_array_index && *extra_array_index != 0) return true;
  370. if (opcode == spv::Op::OpDecorateId || opcode == spv::Op::OpDecorateString ||
  371. opcode == spv::Op::OpDecorate) {
  372. CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
  373. return true;
  374. }
  375. if (opcode == spv::Op::OpName) {
  376. std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
  377. new_inst->SetInOperand(0, {scalar_var->result_id()});
  378. context()->AddDebug2Inst(std::move(new_inst));
  379. return true;
  380. }
  381. if (opcode == spv::Op::OpEntryPoint) {
  382. return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
  383. scalar_var->result_id());
  384. }
  385. if (opcode == spv::Op::OpAccessChain) {
  386. ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
  387. scalar_var,
  388. loads_for_access_chain_to_component_values);
  389. return true;
  390. }
  391. std::string message("Unhandled instruction");
  392. message += "\n " + interface_var_user->PrettyPrint(
  393. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  394. message +=
  395. "\nfor interface variable scalar replacement\n " +
  396. interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  397. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  398. return false;
  399. }
  400. void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
  401. Instruction* access_chain, Instruction* base_access_chain) {
  402. assert(base_access_chain->opcode() == spv::Op::OpAccessChain &&
  403. access_chain->opcode() == spv::Op::OpAccessChain &&
  404. access_chain->GetSingleWordInOperand(0) ==
  405. base_access_chain->result_id());
  406. Instruction::OperandList new_operands;
  407. for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) {
  408. new_operands.emplace_back(base_access_chain->GetInOperand(i));
  409. }
  410. for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
  411. new_operands.emplace_back(access_chain->GetInOperand(i));
  412. }
  413. access_chain->SetInOperands(std::move(new_operands));
  414. }
  415. Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
  416. uint32_t var_type_id, Instruction* var,
  417. const std::vector<uint32_t>& index_ids, Instruction* insert_before,
  418. uint32_t* component_type_id) {
  419. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  420. *component_type_id = GetComponentTypeOfArrayMatrix(
  421. def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size()));
  422. uint32_t ptr_type_id =
  423. GetPointerType(*component_type_id, GetStorageClass(var));
  424. std::unique_ptr<Instruction> new_access_chain(new Instruction(
  425. context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
  426. std::initializer_list<Operand>{
  427. {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
  428. for (uint32_t index_id : index_ids) {
  429. new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
  430. }
  431. Instruction* inst = new_access_chain.get();
  432. def_use_mgr->AnalyzeInstDefUse(inst);
  433. insert_before->InsertBefore(std::move(new_access_chain));
  434. return inst;
  435. }
  436. Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
  437. uint32_t component_type_id, Instruction* var, uint32_t index,
  438. Instruction* insert_before) {
  439. uint32_t ptr_type_id =
  440. GetPointerType(component_type_id, GetStorageClass(var));
  441. uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(index);
  442. std::unique_ptr<Instruction> new_access_chain(new Instruction(
  443. context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
  444. std::initializer_list<Operand>{
  445. {SPV_OPERAND_TYPE_ID, {var->result_id()}},
  446. {SPV_OPERAND_TYPE_ID, {index_id}},
  447. }));
  448. Instruction* inst = new_access_chain.get();
  449. context()->get_def_use_mgr()->AnalyzeInstDefUse(inst);
  450. insert_before->InsertBefore(std::move(new_access_chain));
  451. return inst;
  452. }
  453. void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
  454. Instruction* access_chain,
  455. const std::vector<uint32_t>& interface_var_component_indices,
  456. Instruction* scalar_var,
  457. std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
  458. std::vector<uint32_t> indexes;
  459. for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
  460. indexes.push_back(access_chain->GetSingleWordInOperand(i));
  461. }
  462. // Note that we have a strong assumption that |access_chain| has only a single
  463. // index that is for the extra arrayness.
  464. context()->get_def_use_mgr()->ForEachUser(
  465. access_chain,
  466. [this, access_chain, &indexes, &interface_var_component_indices,
  467. scalar_var, loads_to_component_values](Instruction* user) {
  468. switch (user->opcode()) {
  469. case spv::Op::OpAccessChain: {
  470. UseBaseAccessChainForAccessChain(user, access_chain);
  471. ReplaceAccessChainWith(user, interface_var_component_indices,
  472. scalar_var, loads_to_component_values);
  473. return;
  474. }
  475. case spv::Op::OpStore: {
  476. uint32_t value_id = user->GetSingleWordInOperand(1);
  477. StoreComponentOfValueToAccessChainToScalarVar(
  478. value_id, interface_var_component_indices, scalar_var, indexes,
  479. user);
  480. return;
  481. }
  482. case spv::Op::OpLoad: {
  483. Instruction* value =
  484. LoadAccessChainToVar(scalar_var, indexes, user);
  485. loads_to_component_values->insert({user, value});
  486. return;
  487. }
  488. default:
  489. break;
  490. }
  491. });
  492. }
  493. void InterfaceVariableScalarReplacement::CloneAnnotationForVariable(
  494. Instruction* annotation_inst, uint32_t var_id) {
  495. assert(annotation_inst->opcode() == spv::Op::OpDecorate ||
  496. annotation_inst->opcode() == spv::Op::OpDecorateId ||
  497. annotation_inst->opcode() == spv::Op::OpDecorateString);
  498. std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context()));
  499. new_inst->SetInOperand(0, {var_id});
  500. context()->AddAnnotationInst(std::move(new_inst));
  501. }
  502. bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint(
  503. Instruction* interface_var, Instruction* entry_point,
  504. uint32_t scalar_var_id) {
  505. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  506. uint32_t interface_var_id = interface_var->result_id();
  507. if (interface_vars_removed_from_entry_point_operands_.find(
  508. interface_var_id) !=
  509. interface_vars_removed_from_entry_point_operands_.end()) {
  510. entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}});
  511. def_use_mgr->AnalyzeInstUse(entry_point);
  512. return true;
  513. }
  514. bool success = !entry_point->WhileEachInId(
  515. [&interface_var_id, &scalar_var_id](uint32_t* id) {
  516. if (*id == interface_var_id) {
  517. *id = scalar_var_id;
  518. return false;
  519. }
  520. return true;
  521. });
  522. if (!success) {
  523. std::string message(
  524. "interface variable is not an operand of the entry point");
  525. message += "\n " + interface_var->PrettyPrint(
  526. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  527. message += "\n " + entry_point->PrettyPrint(
  528. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  529. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  530. return false;
  531. }
  532. def_use_mgr->AnalyzeInstUse(entry_point);
  533. interface_vars_removed_from_entry_point_operands_.insert(interface_var_id);
  534. return true;
  535. }
  536. uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar(
  537. Instruction* var) {
  538. assert(var->opcode() == spv::Op::OpVariable);
  539. uint32_t ptr_type_id = var->type_id();
  540. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  541. Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id);
  542. assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer &&
  543. "Variable must have a pointer type.");
  544. return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex);
  545. }
  546. void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
  547. uint32_t value_id, const std::vector<uint32_t>& component_indices,
  548. Instruction* scalar_var, const uint32_t* extra_array_index,
  549. Instruction* insert_before) {
  550. uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
  551. Instruction* ptr = scalar_var;
  552. if (extra_array_index) {
  553. auto* ty_mgr = context()->get_type_mgr();
  554. analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
  555. assert(array_type != nullptr);
  556. component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
  557. ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
  558. *extra_array_index, insert_before);
  559. }
  560. StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
  561. extra_array_index, insert_before);
  562. }
  563. Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
  564. Instruction* scalar_var, const uint32_t* extra_array_index,
  565. Instruction* insert_before) {
  566. uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
  567. Instruction* ptr = scalar_var;
  568. if (extra_array_index) {
  569. auto* ty_mgr = context()->get_type_mgr();
  570. analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
  571. assert(array_type != nullptr);
  572. component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
  573. ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
  574. *extra_array_index, insert_before);
  575. }
  576. return CreateLoad(component_type_id, ptr, insert_before);
  577. }
  578. Instruction* InterfaceVariableScalarReplacement::CreateLoad(
  579. uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
  580. std::unique_ptr<Instruction> load(
  581. new Instruction(context(), spv::Op::OpLoad, type_id, TakeNextId(),
  582. std::initializer_list<Operand>{
  583. {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
  584. Instruction* load_inst = load.get();
  585. context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst);
  586. insert_before->InsertBefore(std::move(load));
  587. return load_inst;
  588. }
  589. void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
  590. uint32_t component_type_id, uint32_t value_id,
  591. const std::vector<uint32_t>& component_indices, Instruction* ptr,
  592. const uint32_t* extra_array_index, Instruction* insert_before) {
  593. std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
  594. component_type_id, value_id, component_indices, extra_array_index));
  595. std::unique_ptr<Instruction> new_store(
  596. new Instruction(context(), spv::Op::OpStore));
  597. new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}});
  598. new_store->AddOperand(
  599. {SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}});
  600. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  601. def_use_mgr->AnalyzeInstDefUse(composite_extract.get());
  602. def_use_mgr->AnalyzeInstDefUse(new_store.get());
  603. insert_before->InsertBefore(std::move(composite_extract));
  604. insert_before->InsertBefore(std::move(new_store));
  605. }
  606. Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
  607. uint32_t type_id, uint32_t composite_id,
  608. const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
  609. uint32_t component_id = TakeNextId();
  610. Instruction* composite_extract = new Instruction(
  611. context(), spv::Op::OpCompositeExtract, type_id, component_id,
  612. std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
  613. if (extra_first_index) {
  614. composite_extract->AddOperand(
  615. {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}});
  616. }
  617. for (uint32_t index : indexes) {
  618. composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}});
  619. }
  620. return composite_extract;
  621. }
  622. void InterfaceVariableScalarReplacement::
  623. StoreComponentOfValueToAccessChainToScalarVar(
  624. uint32_t value_id, const std::vector<uint32_t>& component_indices,
  625. Instruction* scalar_var,
  626. const std::vector<uint32_t>& access_chain_indices,
  627. Instruction* insert_before) {
  628. uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
  629. Instruction* ptr = scalar_var;
  630. if (!access_chain_indices.empty()) {
  631. ptr = CreateAccessChainToVar(component_type_id, scalar_var,
  632. access_chain_indices, insert_before,
  633. &component_type_id);
  634. }
  635. StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
  636. nullptr, insert_before);
  637. }
  638. Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
  639. Instruction* var, const std::vector<uint32_t>& indexes,
  640. Instruction* insert_before) {
  641. uint32_t component_type_id = GetPointeeTypeIdOfVar(var);
  642. Instruction* ptr = var;
  643. if (!indexes.empty()) {
  644. ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before,
  645. &component_type_id);
  646. }
  647. return CreateLoad(component_type_id, ptr, insert_before);
  648. }
  649. Instruction*
  650. InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
  651. Instruction* load, uint32_t depth_to_component) {
  652. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  653. uint32_t type_id = load->type_id();
  654. if (depth_to_component != 0) {
  655. type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
  656. depth_to_component);
  657. }
  658. uint32_t new_id = context()->TakeNextId();
  659. std::unique_ptr<Instruction> new_composite_construct(new Instruction(
  660. context(), spv::Op::OpCompositeConstruct, type_id, new_id, {}));
  661. Instruction* composite_construct = new_composite_construct.get();
  662. def_use_mgr->AnalyzeInstDefUse(composite_construct);
  663. // Insert |new_composite_construct| after |load|. When there are multiple
  664. // recursive composite construct instructions for a load, we have to place the
  665. // composite construct with a lower depth later because it constructs the
  666. // composite that contains other composites with lower depths.
  667. auto* insert_before = load->NextNode();
  668. while (true) {
  669. auto itr =
  670. composite_ids_to_component_depths.find(insert_before->result_id());
  671. if (itr == composite_ids_to_component_depths.end()) break;
  672. if (itr->second <= depth_to_component) break;
  673. insert_before = insert_before->NextNode();
  674. }
  675. insert_before->InsertBefore(std::move(new_composite_construct));
  676. composite_ids_to_component_depths.insert({new_id, depth_to_component});
  677. return composite_construct;
  678. }
  679. void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
  680. const std::unordered_map<Instruction*, Instruction*>&
  681. loads_to_component_values,
  682. std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
  683. uint32_t depth_to_component) {
  684. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  685. for (auto& load_and_component_vale : loads_to_component_values) {
  686. Instruction* load = load_and_component_vale.first;
  687. Instruction* component_value = load_and_component_vale.second;
  688. Instruction* composite_construct = nullptr;
  689. auto itr = loads_to_composites->find(load);
  690. if (itr == loads_to_composites->end()) {
  691. composite_construct =
  692. CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
  693. loads_to_composites->insert({load, composite_construct});
  694. } else {
  695. composite_construct = itr->second;
  696. }
  697. composite_construct->AddOperand(
  698. {SPV_OPERAND_TYPE_ID, {component_value->result_id()}});
  699. def_use_mgr->AnalyzeInstDefUse(composite_construct);
  700. }
  701. }
  702. uint32_t InterfaceVariableScalarReplacement::GetArrayType(
  703. uint32_t elem_type_id, uint32_t array_length) {
  704. analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
  705. uint32_t array_length_id =
  706. context()->get_constant_mgr()->GetUIntConstId(array_length);
  707. analysis::Array array_type(
  708. elem_type,
  709. analysis::Array::LengthInfo{array_length_id, {0, array_length}});
  710. return context()->get_type_mgr()->GetTypeInstruction(&array_type);
  711. }
  712. uint32_t InterfaceVariableScalarReplacement::GetPointerType(
  713. uint32_t type_id, spv::StorageClass storage_class) {
  714. analysis::Type* type = context()->get_type_mgr()->GetType(type_id);
  715. analysis::Pointer ptr_type(type, storage_class);
  716. return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
  717. }
  718. InterfaceVariableScalarReplacement::NestedCompositeComponents
  719. InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
  720. Instruction* interface_var_type, spv::StorageClass storage_class,
  721. uint32_t extra_array_length) {
  722. assert(interface_var_type->opcode() == spv::Op::OpTypeArray);
  723. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  724. uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type);
  725. Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type);
  726. NestedCompositeComponents scalar_vars;
  727. while (array_length > 0) {
  728. NestedCompositeComponents scalar_vars_for_element =
  729. CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
  730. extra_array_length);
  731. scalar_vars.AddComponent(scalar_vars_for_element);
  732. --array_length;
  733. }
  734. return scalar_vars;
  735. }
  736. InterfaceVariableScalarReplacement::NestedCompositeComponents
  737. InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
  738. Instruction* interface_var_type, spv::StorageClass storage_class,
  739. uint32_t extra_array_length) {
  740. assert(interface_var_type->opcode() == spv::Op::OpTypeMatrix);
  741. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  742. uint32_t column_count = interface_var_type->GetSingleWordInOperand(
  743. kOpTypeMatrixColCountInOperandIndex);
  744. Instruction* column_type =
  745. GetMatrixColumnType(def_use_mgr, interface_var_type);
  746. NestedCompositeComponents scalar_vars;
  747. while (column_count > 0) {
  748. NestedCompositeComponents scalar_vars_for_column =
  749. CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
  750. extra_array_length);
  751. scalar_vars.AddComponent(scalar_vars_for_column);
  752. --column_count;
  753. }
  754. return scalar_vars;
  755. }
  756. InterfaceVariableScalarReplacement::NestedCompositeComponents
  757. InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
  758. Instruction* interface_var_type, spv::StorageClass storage_class,
  759. uint32_t extra_array_length) {
  760. // Handle array case.
  761. if (interface_var_type->opcode() == spv::Op::OpTypeArray) {
  762. return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class,
  763. extra_array_length);
  764. }
  765. // Handle matrix case.
  766. if (interface_var_type->opcode() == spv::Op::OpTypeMatrix) {
  767. return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class,
  768. extra_array_length);
  769. }
  770. // Handle scalar or vector case.
  771. NestedCompositeComponents scalar_var;
  772. uint32_t type_id = interface_var_type->result_id();
  773. if (extra_array_length != 0) {
  774. type_id = GetArrayType(type_id, extra_array_length);
  775. }
  776. uint32_t ptr_type_id =
  777. context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
  778. uint32_t id = TakeNextId();
  779. std::unique_ptr<Instruction> variable(
  780. new Instruction(context(), spv::Op::OpVariable, ptr_type_id, id,
  781. std::initializer_list<Operand>{
  782. {SPV_OPERAND_TYPE_STORAGE_CLASS,
  783. {static_cast<uint32_t>(storage_class)}}}));
  784. scalar_var.SetSingleComponentVariable(variable.get());
  785. context()->AddGlobalValue(std::move(variable));
  786. return scalar_var;
  787. }
  788. Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable(
  789. Instruction* var) {
  790. uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var);
  791. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  792. return def_use_mgr->GetDef(pointee_type_id);
  793. }
  794. Pass::Status InterfaceVariableScalarReplacement::Process() {
  795. Pass::Status status = Status::SuccessWithoutChange;
  796. for (Instruction& entry_point : get_module()->entry_points()) {
  797. status =
  798. CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point));
  799. }
  800. return status;
  801. }
  802. bool InterfaceVariableScalarReplacement::
  803. ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) {
  804. if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end())
  805. return false;
  806. std::string message(
  807. "A variable is arrayed for an entry point but it is not "
  808. "arrayed for another entry point");
  809. message +=
  810. "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  811. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  812. return true;
  813. }
  814. bool InterfaceVariableScalarReplacement::
  815. ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) {
  816. if (vars_without_extra_arrayness.find(var) ==
  817. vars_without_extra_arrayness.end())
  818. return false;
  819. std::string message(
  820. "A variable is not arrayed for an entry point but it is "
  821. "arrayed for another entry point");
  822. message +=
  823. "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  824. context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  825. return true;
  826. }
  827. Pass::Status
  828. InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
  829. Instruction& entry_point) {
  830. std::vector<Instruction*> interface_vars =
  831. CollectInterfaceVariables(entry_point);
  832. Pass::Status status = Status::SuccessWithoutChange;
  833. for (Instruction* interface_var : interface_vars) {
  834. uint32_t location, component;
  835. if (!GetVariableLocation(interface_var, &location)) continue;
  836. if (!GetVariableComponent(interface_var, &component)) component = 0;
  837. Instruction* interface_var_type = GetTypeOfVariable(interface_var);
  838. uint32_t extra_array_length = 0;
  839. if (HasExtraArrayness(entry_point, interface_var)) {
  840. extra_array_length =
  841. GetArrayLength(context()->get_def_use_mgr(), interface_var_type);
  842. interface_var_type =
  843. GetArrayElementType(context()->get_def_use_mgr(), interface_var_type);
  844. vars_with_extra_arrayness.insert(interface_var);
  845. } else {
  846. vars_without_extra_arrayness.insert(interface_var);
  847. }
  848. if (!CheckExtraArraynessConflictBetweenEntries(interface_var,
  849. extra_array_length != 0)) {
  850. return Pass::Status::Failure;
  851. }
  852. if (interface_var_type->opcode() != spv::Op::OpTypeArray &&
  853. interface_var_type->opcode() != spv::Op::OpTypeMatrix) {
  854. continue;
  855. }
  856. if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type,
  857. location, component,
  858. extra_array_length)) {
  859. return Pass::Status::Failure;
  860. }
  861. status = Pass::Status::SuccessWithChange;
  862. }
  863. return status;
  864. }
  865. } // namespace opt
  866. } // namespace spvtools