local_access_chain_convert_pass.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. // Copyright (c) 2017 The Khronos Group Inc.
  2. // Copyright (c) 2017 Valve Corporation
  3. // Copyright (c) 2017 LunarG Inc.
  4. // Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
  5. // reserved.
  6. //
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. //
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. //
  13. // Unless required by applicable law or agreed to in writing, software
  14. // distributed under the License is distributed on an "AS IS" BASIS,
  15. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. // See the License for the specific language governing permissions and
  17. // limitations under the License.
  18. #include "source/opt/local_access_chain_convert_pass.h"
  19. #include "ir_context.h"
  20. #include "iterator.h"
  21. #include "source/util/string_utils.h"
  22. namespace spvtools {
  23. namespace opt {
  namespace {
  // In-operand index of the base pointer of an OpAccessChain /
  // OpInBoundsAccessChain; the index operands follow it.
  constexpr uint32_t kAccessChainPtrIdInIdx = 0;
  // In-operand index of the value written by an OpStore (in-operand 0 is the
  // pointer being stored through).
  constexpr uint32_t kStoreValIdInIdx = 1;
  }  // namespace
  28. void LocalAccessChainConvertPass::BuildAndAppendInst(
  29. spv::Op opcode, uint32_t typeId, uint32_t resultId,
  30. const std::vector<Operand>& in_opnds,
  31. std::vector<std::unique_ptr<Instruction>>* newInsts) {
  32. std::unique_ptr<Instruction> newInst(
  33. new Instruction(context(), opcode, typeId, resultId, in_opnds));
  34. get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
  35. newInsts->emplace_back(std::move(newInst));
  36. }
  37. uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
  38. const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
  39. std::vector<std::unique_ptr<Instruction>>* newInsts) {
  40. const uint32_t ldResultId = TakeNextId();
  41. if (ldResultId == 0) {
  42. return 0;
  43. }
  44. *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
  45. const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
  46. assert(varInst->opcode() == spv::Op::OpVariable);
  47. *varPteTypeId = GetPointeeTypeId(varInst);
  48. BuildAndAppendInst(spv::Op::OpLoad, *varPteTypeId, ldResultId,
  49. {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
  50. newInsts);
  51. return ldResultId;
  52. }
  53. void LocalAccessChainConvertPass::AppendConstantOperands(
  54. const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
  55. uint32_t iidIdx = 0;
  56. ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
  57. if (iidIdx > 0) {
  58. const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
  59. const auto* constant_value =
  60. context()->get_constant_mgr()->GetConstantFromInst(cInst);
  61. assert(constant_value != nullptr &&
  62. "Expecting the index to be a constant.");
  63. // We take the sign extended value because OpAccessChain interprets the
  64. // index as signed.
  65. int64_t long_value = constant_value->GetSignExtendedValue();
  66. assert(long_value <= UINT32_MAX && long_value >= 0 &&
  67. "The index value is too large for a composite insert or extract "
  68. "instruction.");
  69. uint32_t val = static_cast<uint32_t>(long_value);
  70. in_opnds->push_back(
  71. {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
  72. }
  73. ++iidIdx;
  74. });
  75. }
  76. bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
  77. const Instruction* address_inst, Instruction* original_load) {
  78. // Build and append load of variable in ptrInst
  79. if (address_inst->NumInOperands() == 1) {
  80. // An access chain with no indices is essentially a copy. All that is
  81. // needed is to propagate the address.
  82. context()->ReplaceAllUsesWith(
  83. address_inst->result_id(),
  84. address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
  85. return true;
  86. }
  87. std::vector<std::unique_ptr<Instruction>> new_inst;
  88. uint32_t varId;
  89. uint32_t varPteTypeId;
  90. const uint32_t ldResultId =
  91. BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
  92. if (ldResultId == 0) {
  93. return false;
  94. }
  95. new_inst[0]->UpdateDebugInfoFrom(original_load);
  96. context()->get_decoration_mgr()->CloneDecorations(
  97. original_load->result_id(), ldResultId,
  98. {spv::Decoration::RelaxedPrecision});
  99. original_load->InsertBefore(std::move(new_inst));
  100. context()->get_debug_info_mgr()->AnalyzeDebugInst(
  101. original_load->PreviousNode());
  102. // Rewrite |original_load| into an extract.
  103. Instruction::OperandList new_operands;
  104. // copy the result id and the type id to the new operand list.
  105. new_operands.emplace_back(original_load->GetOperand(0));
  106. new_operands.emplace_back(original_load->GetOperand(1));
  107. new_operands.emplace_back(
  108. Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
  109. AppendConstantOperands(address_inst, &new_operands);
  110. original_load->SetOpcode(spv::Op::OpCompositeExtract);
  111. original_load->ReplaceOperands(new_operands);
  112. context()->UpdateDefUse(original_load);
  113. return true;
  114. }
  115. bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
  116. const Instruction* ptrInst, uint32_t valId,
  117. std::vector<std::unique_ptr<Instruction>>* newInsts) {
  118. if (ptrInst->NumInOperands() == 1) {
  119. // An access chain with no indices is essentially a copy. However, we still
  120. // have to create a new store because the old ones will be deleted.
  121. BuildAndAppendInst(
  122. spv::Op::OpStore, 0, 0,
  123. {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
  124. {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
  125. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
  126. newInsts);
  127. return true;
  128. }
  129. // Build and append load of variable in ptrInst
  130. uint32_t varId;
  131. uint32_t varPteTypeId;
  132. const uint32_t ldResultId =
  133. BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
  134. if (ldResultId == 0) {
  135. return false;
  136. }
  137. context()->get_decoration_mgr()->CloneDecorations(
  138. varId, ldResultId, {spv::Decoration::RelaxedPrecision});
  139. // Build and append Insert
  140. const uint32_t insResultId = TakeNextId();
  141. if (insResultId == 0) {
  142. return false;
  143. }
  144. std::vector<Operand> ins_in_opnds = {
  145. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
  146. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
  147. AppendConstantOperands(ptrInst, &ins_in_opnds);
  148. BuildAndAppendInst(spv::Op::OpCompositeInsert, varPteTypeId, insResultId,
  149. ins_in_opnds, newInsts);
  150. context()->get_decoration_mgr()->CloneDecorations(
  151. varId, insResultId, {spv::Decoration::RelaxedPrecision});
  152. // Build and append Store
  153. BuildAndAppendInst(spv::Op::OpStore, 0, 0,
  154. {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
  155. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
  156. newInsts);
  157. return true;
  158. }
  159. bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
  160. const Instruction* acp) const {
  161. uint32_t inIdx = 0;
  162. return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
  163. if (inIdx > 0) {
  164. Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
  165. if (opInst->opcode() != spv::Op::OpConstant) return false;
  166. const auto* index =
  167. context()->get_constant_mgr()->GetConstantFromInst(opInst);
  168. int64_t index_value = index->GetSignExtendedValue();
  169. if (index_value > UINT32_MAX) return false;
  170. if (index_value < 0) return false;
  171. }
  172. ++inIdx;
  173. return true;
  174. });
  175. }
  176. bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
  177. if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
  178. if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
  179. if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue ||
  180. user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
  181. return true;
  182. }
  183. spv::Op op = user->opcode();
  184. if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
  185. if (!HasOnlySupportedRefs(user->result_id())) {
  186. return false;
  187. }
  188. } else if (op != spv::Op::OpStore && op != spv::Op::OpLoad &&
  189. op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
  190. return false;
  191. }
  192. return true;
  193. })) {
  194. supported_ref_ptrs_.insert(ptrId);
  195. return true;
  196. }
  197. return false;
  198. }
  199. void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
  200. for (auto bi = func->begin(); bi != func->end(); ++bi) {
  201. for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
  202. switch (ii->opcode()) {
  203. case spv::Op::OpStore:
  204. case spv::Op::OpLoad: {
  205. uint32_t varId;
  206. Instruction* ptrInst = GetPtr(&*ii, &varId);
  207. if (!IsTargetVar(varId)) break;
  208. const spv::Op op = ptrInst->opcode();
  209. // Rule out variables with non-supported refs eg function calls
  210. if (!HasOnlySupportedRefs(varId)) {
  211. seen_non_target_vars_.insert(varId);
  212. seen_target_vars_.erase(varId);
  213. break;
  214. }
  215. // Rule out variables with nested access chains
  216. // TODO(): Convert nested access chains
  217. bool is_non_ptr_access_chain = IsNonPtrAccessChain(op);
  218. if (is_non_ptr_access_chain && ptrInst->GetSingleWordInOperand(
  219. kAccessChainPtrIdInIdx) != varId) {
  220. seen_non_target_vars_.insert(varId);
  221. seen_target_vars_.erase(varId);
  222. break;
  223. }
  224. // Rule out variables accessed with non-constant indices
  225. if (!Is32BitConstantIndexAccessChain(ptrInst)) {
  226. seen_non_target_vars_.insert(varId);
  227. seen_target_vars_.erase(varId);
  228. break;
  229. }
  230. if (is_non_ptr_access_chain && AnyIndexIsOutOfBounds(ptrInst)) {
  231. seen_non_target_vars_.insert(varId);
  232. seen_target_vars_.erase(varId);
  233. break;
  234. }
  235. } break;
  236. default:
  237. break;
  238. }
  239. }
  240. }
  241. }
  242. Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
  243. Function* func) {
  244. FindTargetVars(func);
  245. // Replace access chains of all targeted variables with equivalent
  246. // extract and insert sequences
  247. bool modified = false;
  248. for (auto bi = func->begin(); bi != func->end(); ++bi) {
  249. std::vector<Instruction*> dead_instructions;
  250. for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
  251. switch (ii->opcode()) {
  252. case spv::Op::OpLoad: {
  253. uint32_t varId;
  254. Instruction* ptrInst = GetPtr(&*ii, &varId);
  255. if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
  256. if (!IsTargetVar(varId)) break;
  257. if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
  258. return Status::Failure;
  259. }
  260. modified = true;
  261. } break;
  262. case spv::Op::OpStore: {
  263. uint32_t varId;
  264. Instruction* store = &*ii;
  265. Instruction* ptrInst = GetPtr(store, &varId);
  266. if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
  267. if (!IsTargetVar(varId)) break;
  268. std::vector<std::unique_ptr<Instruction>> newInsts;
  269. uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
  270. if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
  271. return Status::Failure;
  272. }
  273. size_t num_of_instructions_to_skip = newInsts.size() - 1;
  274. dead_instructions.push_back(store);
  275. ++ii;
  276. ii = ii.InsertBefore(std::move(newInsts));
  277. for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
  278. ii->UpdateDebugInfoFrom(store);
  279. context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
  280. ++ii;
  281. }
  282. ii->UpdateDebugInfoFrom(store);
  283. context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
  284. modified = true;
  285. } break;
  286. default:
  287. break;
  288. }
  289. }
  290. while (!dead_instructions.empty()) {
  291. Instruction* inst = dead_instructions.back();
  292. dead_instructions.pop_back();
  293. DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
  294. auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
  295. other_inst);
  296. if (i != dead_instructions.end()) {
  297. dead_instructions.erase(i);
  298. }
  299. });
  300. }
  301. }
  302. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  303. }
  304. void LocalAccessChainConvertPass::Initialize() {
  305. // Initialize Target Variable Caches
  306. seen_target_vars_.clear();
  307. seen_non_target_vars_.clear();
  308. // Initialize collections
  309. supported_ref_ptrs_.clear();
  310. // Initialize extension allowlist
  311. InitExtensions();
  312. }
  313. bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
  314. // This capability can now exist without the extension, so we have to check
  315. // for the capability. This pass is only looking at function scope symbols,
  316. // so we do not care if there are variable pointers on storage buffers.
  317. if (context()->get_feature_mgr()->HasCapability(
  318. spv::Capability::VariablePointers))
  319. return false;
  320. // If any extension not in allowlist, return false
  321. for (auto& ei : get_module()->extensions()) {
  322. const std::string extName = ei.GetInOperand(0).AsString();
  323. if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
  324. return false;
  325. }
  326. // only allow NonSemantic.Shader.DebugInfo.100, we cannot safely optimise
  327. // around unknown extended
  328. // instruction sets even if they are non-semantic
  329. for (auto& inst : context()->module()->ext_inst_imports()) {
  330. assert(inst.opcode() == spv::Op::OpExtInstImport &&
  331. "Expecting an import of an extension's instruction set.");
  332. const std::string extension_name = inst.GetInOperand(0).AsString();
  333. if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
  334. extension_name != "NonSemantic.Shader.DebugInfo.100") {
  335. return false;
  336. }
  337. }
  338. return true;
  339. }
  340. Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
  341. // Do not process if module contains OpGroupDecorate. Additional
  342. // support required in KillNamesAndDecorates().
  343. // TODO(greg-lunarg): Add support for OpGroupDecorate
  344. for (auto& ai : get_module()->annotations())
  345. if (ai.opcode() == spv::Op::OpGroupDecorate)
  346. return Status::SuccessWithoutChange;
  347. // Do not process if any disallowed extensions are enabled
  348. if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
  349. // Process all functions in the module.
  350. Status status = Status::SuccessWithoutChange;
  351. for (Function& func : *get_module()) {
  352. status = CombineStatus(status, ConvertLocalAccessChains(&func));
  353. if (status == Status::Failure) {
  354. break;
  355. }
  356. }
  357. return status;
  358. }
  359. LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
  360. Pass::Status LocalAccessChainConvertPass::Process() {
  361. Initialize();
  362. return ProcessImpl();
  363. }
  364. void LocalAccessChainConvertPass::InitExtensions() {
  365. extensions_allowlist_.clear();
  366. extensions_allowlist_.insert(
  367. {"SPV_AMD_shader_explicit_vertex_parameter",
  368. "SPV_AMD_shader_trinary_minmax", "SPV_AMD_gcn_shader",
  369. "SPV_KHR_shader_ballot", "SPV_AMD_shader_ballot",
  370. "SPV_AMD_gpu_shader_half_float", "SPV_KHR_shader_draw_parameters",
  371. "SPV_KHR_subgroup_vote", "SPV_KHR_8bit_storage", "SPV_KHR_16bit_storage",
  372. "SPV_KHR_device_group", "SPV_KHR_multiview",
  373. "SPV_NVX_multiview_per_view_attributes", "SPV_NV_viewport_array2",
  374. "SPV_NV_stereo_view_rendering", "SPV_NV_sample_mask_override_coverage",
  375. "SPV_NV_geometry_shader_passthrough", "SPV_AMD_texture_gather_bias_lod",
  376. "SPV_KHR_storage_buffer_storage_class",
  377. // SPV_KHR_variable_pointers
  378. // Currently do not support extended pointer expressions
  379. "SPV_AMD_gpu_shader_int16", "SPV_KHR_post_depth_coverage",
  380. "SPV_KHR_shader_atomic_counter_ops", "SPV_EXT_shader_stencil_export",
  381. "SPV_EXT_shader_viewport_index_layer",
  382. "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask",
  383. "SPV_EXT_fragment_fully_covered", "SPV_AMD_gpu_shader_half_float_fetch",
  384. "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1",
  385. "SPV_GOOGLE_user_type", "SPV_NV_shader_subgroup_partitioned",
  386. "SPV_EXT_demote_to_helper_invocation", "SPV_EXT_descriptor_indexing",
  387. "SPV_NV_fragment_shader_barycentric",
  388. "SPV_NV_compute_shader_derivatives", "SPV_NV_shader_image_footprint",
  389. "SPV_NV_shading_rate", "SPV_NV_mesh_shader", "SPV_EXT_mesh_shader",
  390. "SPV_NV_ray_tracing", "SPV_KHR_ray_tracing", "SPV_KHR_ray_query",
  391. "SPV_EXT_fragment_invocation_density", "SPV_KHR_terminate_invocation",
  392. "SPV_KHR_subgroup_uniform_control_flow", "SPV_KHR_integer_dot_product",
  393. "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info",
  394. "SPV_KHR_uniform_group_instructions",
  395. "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
  396. "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
  397. "SPV_EXT_fragment_shader_interlock",
  398. "SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix",
  399. "SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch",
  400. "SPV_AMDX_shader_enqueue", "SPV_KHR_fragment_shading_rate"});
  401. }
  402. bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
  403. const Instruction* access_chain_inst) {
  404. assert(IsNonPtrAccessChain(access_chain_inst->opcode()));
  405. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  406. analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
  407. auto constants = const_mgr->GetOperandConstants(access_chain_inst);
  408. uint32_t base_pointer_id = access_chain_inst->GetSingleWordInOperand(0);
  409. Instruction* base_pointer = get_def_use_mgr()->GetDef(base_pointer_id);
  410. const analysis::Pointer* base_pointer_type =
  411. type_mgr->GetType(base_pointer->type_id())->AsPointer();
  412. assert(base_pointer_type != nullptr &&
  413. "The base of the access chain is not a pointer.");
  414. const analysis::Type* current_type = base_pointer_type->pointee_type();
  415. for (uint32_t i = 1; i < access_chain_inst->NumInOperands(); ++i) {
  416. if (IsIndexOutOfBounds(constants[i], current_type)) {
  417. return true;
  418. }
  419. uint32_t index =
  420. (constants[i]
  421. ? static_cast<uint32_t>(constants[i]->GetZeroExtendedValue())
  422. : 0);
  423. current_type = type_mgr->GetMemberType(current_type, {index});
  424. }
  425. return false;
  426. }
  427. bool LocalAccessChainConvertPass::IsIndexOutOfBounds(
  428. const analysis::Constant* index, const analysis::Type* type) const {
  429. if (index == nullptr) {
  430. return false;
  431. }
  432. return index->GetZeroExtendedValue() >= type->NumberOfComponents();
  433. }
  434. } // namespace opt
  435. } // namespace spvtools