local_access_chain_convert_pass.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. // Copyright (c) 2017 The Khronos Group Inc.
  2. // Copyright (c) 2017 Valve Corporation
  3. // Copyright (c) 2017 LunarG Inc.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. #include "source/opt/local_access_chain_convert_pass.h"
  17. #include "ir_builder.h"
  18. #include "ir_context.h"
  19. #include "iterator.h"
  20. namespace spvtools {
  21. namespace opt {
  namespace {

  // In-operand index of the value id written by an OpStore.
  constexpr uint32_t kStoreValIdInIdx = 1;
  // In-operand index of the base pointer id of an access chain.
  constexpr uint32_t kAccessChainPtrIdInIdx = 0;
  // In-operand index of the literal value word of an OpConstant.
  constexpr uint32_t kConstantValueInIdx = 0;
  // In-operand index of the bit width of an OpTypeInt.
  constexpr uint32_t kTypeIntWidthInIdx = 0;

  }  // anonymous namespace
  28. void LocalAccessChainConvertPass::BuildAndAppendInst(
  29. SpvOp opcode, uint32_t typeId, uint32_t resultId,
  30. const std::vector<Operand>& in_opnds,
  31. std::vector<std::unique_ptr<Instruction>>* newInsts) {
  32. std::unique_ptr<Instruction> newInst(
  33. new Instruction(context(), opcode, typeId, resultId, in_opnds));
  34. get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
  35. newInsts->emplace_back(std::move(newInst));
  36. }
  37. uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
  38. const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
  39. std::vector<std::unique_ptr<Instruction>>* newInsts) {
  40. const uint32_t ldResultId = TakeNextId();
  41. if (ldResultId == 0) {
  42. return 0;
  43. }
  44. *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
  45. const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
  46. assert(varInst->opcode() == SpvOpVariable);
  47. *varPteTypeId = GetPointeeTypeId(varInst);
  48. BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
  49. {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
  50. newInsts);
  51. return ldResultId;
  52. }
  53. void LocalAccessChainConvertPass::AppendConstantOperands(
  54. const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
  55. uint32_t iidIdx = 0;
  56. ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
  57. if (iidIdx > 0) {
  58. const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
  59. uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
  60. in_opnds->push_back(
  61. {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
  62. }
  63. ++iidIdx;
  64. });
  65. }
  66. bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
  67. const Instruction* address_inst, Instruction* original_load) {
  68. // Build and append load of variable in ptrInst
  69. std::vector<std::unique_ptr<Instruction>> new_inst;
  70. uint32_t varId;
  71. uint32_t varPteTypeId;
  72. const uint32_t ldResultId =
  73. BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
  74. if (ldResultId == 0) {
  75. return false;
  76. }
  77. context()->get_decoration_mgr()->CloneDecorations(
  78. original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision});
  79. original_load->InsertBefore(std::move(new_inst));
  80. // Rewrite |original_load| into an extract.
  81. Instruction::OperandList new_operands;
  82. // copy the result id and the type id to the new operand list.
  83. new_operands.emplace_back(original_load->GetOperand(0));
  84. new_operands.emplace_back(original_load->GetOperand(1));
  85. new_operands.emplace_back(
  86. Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
  87. AppendConstantOperands(address_inst, &new_operands);
  88. original_load->SetOpcode(SpvOpCompositeExtract);
  89. original_load->ReplaceOperands(new_operands);
  90. context()->UpdateDefUse(original_load);
  91. return true;
  92. }
  93. bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
  94. const Instruction* ptrInst, uint32_t valId,
  95. std::vector<std::unique_ptr<Instruction>>* newInsts) {
  96. // Build and append load of variable in ptrInst
  97. uint32_t varId;
  98. uint32_t varPteTypeId;
  99. const uint32_t ldResultId =
  100. BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
  101. if (ldResultId == 0) {
  102. return false;
  103. }
  104. context()->get_decoration_mgr()->CloneDecorations(
  105. varId, ldResultId, {SpvDecorationRelaxedPrecision});
  106. // Build and append Insert
  107. const uint32_t insResultId = TakeNextId();
  108. if (insResultId == 0) {
  109. return false;
  110. }
  111. std::vector<Operand> ins_in_opnds = {
  112. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
  113. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
  114. AppendConstantOperands(ptrInst, &ins_in_opnds);
  115. BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId,
  116. ins_in_opnds, newInsts);
  117. context()->get_decoration_mgr()->CloneDecorations(
  118. varId, insResultId, {SpvDecorationRelaxedPrecision});
  119. // Build and append Store
  120. BuildAndAppendInst(SpvOpStore, 0, 0,
  121. {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
  122. {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
  123. newInsts);
  124. return true;
  125. }
  126. bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
  127. const Instruction* acp) const {
  128. uint32_t inIdx = 0;
  129. return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
  130. if (inIdx > 0) {
  131. Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
  132. if (opInst->opcode() != SpvOpConstant) return false;
  133. }
  134. ++inIdx;
  135. return true;
  136. });
  137. }
  138. bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
  139. if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
  140. if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
  141. SpvOp op = user->opcode();
  142. if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
  143. if (!HasOnlySupportedRefs(user->result_id())) {
  144. return false;
  145. }
  146. } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
  147. !IsNonTypeDecorate(op)) {
  148. return false;
  149. }
  150. return true;
  151. })) {
  152. supported_ref_ptrs_.insert(ptrId);
  153. return true;
  154. }
  155. return false;
  156. }
  157. void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
  158. for (auto bi = func->begin(); bi != func->end(); ++bi) {
  159. for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
  160. switch (ii->opcode()) {
  161. case SpvOpStore:
  162. case SpvOpLoad: {
  163. uint32_t varId;
  164. Instruction* ptrInst = GetPtr(&*ii, &varId);
  165. if (!IsTargetVar(varId)) break;
  166. const SpvOp op = ptrInst->opcode();
  167. // Rule out variables with non-supported refs eg function calls
  168. if (!HasOnlySupportedRefs(varId)) {
  169. seen_non_target_vars_.insert(varId);
  170. seen_target_vars_.erase(varId);
  171. break;
  172. }
  173. // Rule out variables with nested access chains
  174. // TODO(): Convert nested access chains
  175. if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand(
  176. kAccessChainPtrIdInIdx) != varId) {
  177. seen_non_target_vars_.insert(varId);
  178. seen_target_vars_.erase(varId);
  179. break;
  180. }
  181. // Rule out variables accessed with non-constant indices
  182. if (!IsConstantIndexAccessChain(ptrInst)) {
  183. seen_non_target_vars_.insert(varId);
  184. seen_target_vars_.erase(varId);
  185. break;
  186. }
  187. } break;
  188. default:
  189. break;
  190. }
  191. }
  192. }
  193. }
  194. Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
  195. Function* func) {
  196. FindTargetVars(func);
  197. // Replace access chains of all targeted variables with equivalent
  198. // extract and insert sequences
  199. bool modified = false;
  200. for (auto bi = func->begin(); bi != func->end(); ++bi) {
  201. std::vector<Instruction*> dead_instructions;
  202. for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
  203. switch (ii->opcode()) {
  204. case SpvOpLoad: {
  205. uint32_t varId;
  206. Instruction* ptrInst = GetPtr(&*ii, &varId);
  207. if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
  208. if (!IsTargetVar(varId)) break;
  209. std::vector<std::unique_ptr<Instruction>> newInsts;
  210. if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
  211. return Status::Failure;
  212. }
  213. modified = true;
  214. } break;
  215. case SpvOpStore: {
  216. uint32_t varId;
  217. Instruction* ptrInst = GetPtr(&*ii, &varId);
  218. if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
  219. if (!IsTargetVar(varId)) break;
  220. std::vector<std::unique_ptr<Instruction>> newInsts;
  221. uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx);
  222. if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
  223. return Status::Failure;
  224. }
  225. dead_instructions.push_back(&*ii);
  226. ++ii;
  227. ii = ii.InsertBefore(std::move(newInsts));
  228. ++ii;
  229. ++ii;
  230. modified = true;
  231. } break;
  232. default:
  233. break;
  234. }
  235. }
  236. while (!dead_instructions.empty()) {
  237. Instruction* inst = dead_instructions.back();
  238. dead_instructions.pop_back();
  239. DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
  240. auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
  241. other_inst);
  242. if (i != dead_instructions.end()) {
  243. dead_instructions.erase(i);
  244. }
  245. });
  246. }
  247. }
  248. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  249. }
  250. void LocalAccessChainConvertPass::Initialize() {
  251. // Initialize Target Variable Caches
  252. seen_target_vars_.clear();
  253. seen_non_target_vars_.clear();
  254. // Initialize collections
  255. supported_ref_ptrs_.clear();
  256. // Initialize extension whitelist
  257. InitExtensions();
  258. }
  259. bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
  260. // This capability can now exist without the extension, so we have to check
  261. // for the capability. This pass is only looking at function scope symbols,
  262. // so we do not care if there are variable pointers on storage buffers.
  263. if (context()->get_feature_mgr()->HasCapability(
  264. SpvCapabilityVariablePointers))
  265. return false;
  266. // If any extension not in whitelist, return false
  267. for (auto& ei : get_module()->extensions()) {
  268. const char* extName =
  269. reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
  270. if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
  271. return false;
  272. }
  273. return true;
  274. }
  275. Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
  276. // If non-32-bit integer type in module, terminate processing
  277. // TODO(): Handle non-32-bit integer constants in access chains
  278. for (const Instruction& inst : get_module()->types_values())
  279. if (inst.opcode() == SpvOpTypeInt &&
  280. inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
  281. return Status::SuccessWithoutChange;
  282. // Do not process if module contains OpGroupDecorate. Additional
  283. // support required in KillNamesAndDecorates().
  284. // TODO(greg-lunarg): Add support for OpGroupDecorate
  285. for (auto& ai : get_module()->annotations())
  286. if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
  287. // Do not process if any disallowed extensions are enabled
  288. if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
  289. // Process all functions in the module.
  290. Status status = Status::SuccessWithoutChange;
  291. for (Function& func : *get_module()) {
  292. status = CombineStatus(status, ConvertLocalAccessChains(&func));
  293. if (status == Status::Failure) {
  294. break;
  295. }
  296. }
  297. return status;
  298. }
  299. LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
  300. Pass::Status LocalAccessChainConvertPass::Process() {
  301. Initialize();
  302. return ProcessImpl();
  303. }
  304. void LocalAccessChainConvertPass::InitExtensions() {
  305. extensions_whitelist_.clear();
  306. extensions_whitelist_.insert({
  307. "SPV_AMD_shader_explicit_vertex_parameter",
  308. "SPV_AMD_shader_trinary_minmax",
  309. "SPV_AMD_gcn_shader",
  310. "SPV_KHR_shader_ballot",
  311. "SPV_AMD_shader_ballot",
  312. "SPV_AMD_gpu_shader_half_float",
  313. "SPV_KHR_shader_draw_parameters",
  314. "SPV_KHR_subgroup_vote",
  315. "SPV_KHR_16bit_storage",
  316. "SPV_KHR_device_group",
  317. "SPV_KHR_multiview",
  318. "SPV_NVX_multiview_per_view_attributes",
  319. "SPV_NV_viewport_array2",
  320. "SPV_NV_stereo_view_rendering",
  321. "SPV_NV_sample_mask_override_coverage",
  322. "SPV_NV_geometry_shader_passthrough",
  323. "SPV_AMD_texture_gather_bias_lod",
  324. "SPV_KHR_storage_buffer_storage_class",
  325. // SPV_KHR_variable_pointers
  326. // Currently do not support extended pointer expressions
  327. "SPV_AMD_gpu_shader_int16",
  328. "SPV_KHR_post_depth_coverage",
  329. "SPV_KHR_shader_atomic_counter_ops",
  330. "SPV_EXT_shader_stencil_export",
  331. "SPV_EXT_shader_viewport_index_layer",
  332. "SPV_AMD_shader_image_load_store_lod",
  333. "SPV_AMD_shader_fragment_mask",
  334. "SPV_EXT_fragment_fully_covered",
  335. "SPV_AMD_gpu_shader_half_float_fetch",
  336. "SPV_GOOGLE_decorate_string",
  337. "SPV_GOOGLE_hlsl_functionality1",
  338. "SPV_GOOGLE_user_type",
  339. "SPV_NV_shader_subgroup_partitioned",
  340. "SPV_EXT_descriptor_indexing",
  341. "SPV_NV_fragment_shader_barycentric",
  342. "SPV_NV_compute_shader_derivatives",
  343. "SPV_NV_shader_image_footprint",
  344. "SPV_NV_shading_rate",
  345. "SPV_NV_mesh_shader",
  346. "SPV_NV_ray_tracing",
  347. "SPV_EXT_fragment_invocation_density",
  348. });
  349. }
  350. } // namespace opt
  351. } // namespace spvtools