combine_access_chains.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/combine_access_chains.h"
  15. #include <utility>
  16. #include "source/opt/constants.h"
  17. #include "source/opt/ir_builder.h"
  18. #include "source/opt/ir_context.h"
  19. namespace spvtools {
  20. namespace opt {
  21. Pass::Status CombineAccessChains::Process() {
  22. bool modified = false;
  23. for (auto& function : *get_module()) {
  24. modified |= ProcessFunction(function);
  25. }
  26. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  27. }
  28. bool CombineAccessChains::ProcessFunction(Function& function) {
  29. bool modified = false;
  30. cfg()->ForEachBlockInReversePostOrder(
  31. function.entry().get(), [&modified, this](BasicBlock* block) {
  32. block->ForEachInst([&modified, this](Instruction* inst) {
  33. switch (inst->opcode()) {
  34. case SpvOpAccessChain:
  35. case SpvOpInBoundsAccessChain:
  36. case SpvOpPtrAccessChain:
  37. case SpvOpInBoundsPtrAccessChain:
  38. modified |= CombineAccessChain(inst);
  39. break;
  40. default:
  41. break;
  42. }
  43. });
  44. });
  45. return modified;
  46. }
  47. uint32_t CombineAccessChains::GetConstantValue(
  48. const analysis::Constant* constant_inst) {
  49. if (constant_inst->type()->AsInteger()->width() <= 32) {
  50. if (constant_inst->type()->AsInteger()->IsSigned()) {
  51. return static_cast<uint32_t>(constant_inst->GetS32());
  52. } else {
  53. return constant_inst->GetU32();
  54. }
  55. } else {
  56. assert(false);
  57. return 0u;
  58. }
  59. }
  60. uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) {
  61. uint32_t array_stride = 0;
  62. context()->get_decoration_mgr()->WhileEachDecoration(
  63. inst->type_id(), SpvDecorationArrayStride,
  64. [&array_stride](const Instruction& decoration) {
  65. assert(decoration.opcode() != SpvOpDecorateId);
  66. if (decoration.opcode() == SpvOpDecorate) {
  67. array_stride = decoration.GetSingleWordInOperand(1);
  68. } else {
  69. array_stride = decoration.GetSingleWordInOperand(2);
  70. }
  71. return false;
  72. });
  73. return array_stride;
  74. }
  75. const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) {
  76. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  77. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  78. Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  79. const analysis::Type* type = type_mgr->GetType(base_ptr->type_id());
  80. assert(type->AsPointer());
  81. type = type->AsPointer()->pointee_type();
  82. std::vector<uint32_t> element_indices;
  83. uint32_t starting_index = 1;
  84. if (IsPtrAccessChain(inst->opcode())) {
  85. // Skip the first index of OpPtrAccessChain as it does not affect type
  86. // resolution.
  87. starting_index = 2;
  88. }
  89. for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
  90. Instruction* index_inst =
  91. def_use_mgr->GetDef(inst->GetSingleWordInOperand(i));
  92. const analysis::Constant* index_constant =
  93. context()->get_constant_mgr()->GetConstantFromInst(index_inst);
  94. if (index_constant) {
  95. uint32_t index_value = GetConstantValue(index_constant);
  96. element_indices.push_back(index_value);
  97. } else {
  98. // This index must not matter to resolve the type in valid SPIR-V.
  99. element_indices.push_back(0);
  100. }
  101. }
  102. type = type_mgr->GetMemberType(type, element_indices);
  103. return type;
  104. }
  105. bool CombineAccessChains::CombineIndices(Instruction* ptr_input,
  106. Instruction* inst,
  107. std::vector<Operand>* new_operands) {
  108. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  109. analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();
  110. Instruction* last_index_inst = def_use_mgr->GetDef(
  111. ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1));
  112. const analysis::Constant* last_index_constant =
  113. constant_mgr->GetConstantFromInst(last_index_inst);
  114. Instruction* element_inst =
  115. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  116. const analysis::Constant* element_constant =
  117. constant_mgr->GetConstantFromInst(element_inst);
  118. // Combine the last index of the AccessChain (|ptr_inst|) with the element
  119. // operand of the PtrAccessChain (|inst|).
  120. const bool combining_element_operands =
  121. IsPtrAccessChain(inst->opcode()) &&
  122. IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2;
  123. uint32_t new_value_id = 0;
  124. const analysis::Type* type = GetIndexedType(ptr_input);
  125. if (last_index_constant && element_constant) {
  126. // Combine the constants.
  127. uint32_t new_value = GetConstantValue(last_index_constant) +
  128. GetConstantValue(element_constant);
  129. const analysis::Constant* new_value_constant =
  130. constant_mgr->GetConstant(last_index_constant->type(), {new_value});
  131. Instruction* new_value_inst =
  132. constant_mgr->GetDefiningInstruction(new_value_constant);
  133. new_value_id = new_value_inst->result_id();
  134. } else if (!type->AsStruct() || combining_element_operands) {
  135. // Generate an addition of the two indices.
  136. InstructionBuilder builder(
  137. context(), inst,
  138. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  139. Instruction* addition = builder.AddIAdd(last_index_inst->type_id(),
  140. last_index_inst->result_id(),
  141. element_inst->result_id());
  142. new_value_id = addition->result_id();
  143. } else {
  144. // Indexing into structs must be constant, so bail out here.
  145. return false;
  146. }
  147. new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}});
  148. return true;
  149. }
  150. bool CombineAccessChains::CreateNewInputOperands(
  151. Instruction* ptr_input, Instruction* inst,
  152. std::vector<Operand>* new_operands) {
  153. // Start by copying all the input operands of the feeder access chain.
  154. for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) {
  155. new_operands->push_back(ptr_input->GetInOperand(i));
  156. }
  157. // Deal with the last index of the feeder access chain.
  158. if (IsPtrAccessChain(inst->opcode())) {
  159. // The last index of the feeder should be combined with the element operand
  160. // of |inst|.
  161. if (!CombineIndices(ptr_input, inst, new_operands)) return false;
  162. } else {
  163. // The indices aren't being combined so now add the last index operand of
  164. // |ptr_input|.
  165. new_operands->push_back(
  166. ptr_input->GetInOperand(ptr_input->NumInOperands() - 1));
  167. }
  168. // Copy the remaining index operands.
  169. uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1;
  170. for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
  171. new_operands->push_back(inst->GetInOperand(i));
  172. }
  173. return true;
  174. }
  175. bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
  176. assert((inst->opcode() == SpvOpPtrAccessChain ||
  177. inst->opcode() == SpvOpAccessChain ||
  178. inst->opcode() == SpvOpInBoundsAccessChain ||
  179. inst->opcode() == SpvOpInBoundsPtrAccessChain) &&
  180. "Wrong opcode. Expected an access chain.");
  181. Instruction* ptr_input =
  182. context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
  183. if (ptr_input->opcode() != SpvOpAccessChain &&
  184. ptr_input->opcode() != SpvOpInBoundsAccessChain &&
  185. ptr_input->opcode() != SpvOpPtrAccessChain &&
  186. ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) {
  187. return false;
  188. }
  189. if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;
  190. // Handles the following cases:
  191. // 1. |ptr_input| is an index-less access chain. Replace the pointer
  192. // in |inst| with |ptr_input|'s pointer.
  193. // 2. |inst| is a index-less access chain. Change |inst| to an
  194. // OpCopyObject.
  195. // 3. |inst| is not a pointer access chain.
  196. // |inst|'s indices are appended to |ptr_input|'s indices.
  197. // 4. |ptr_input| is not pointer access chain.
  198. // |inst| is a pointer access chain.
  199. // |inst|'s element operand is combined with the last index in
  200. // |ptr_input| to form a new operand.
  201. // 5. |ptr_input| is a pointer access chain.
  202. // Like the above scenario, |inst|'s element operand is combined
  203. // with |ptr_input|'s last index. This results is either a
  204. // combined element operand or combined regular index.
  205. // TODO(alan-baker): Support this properly. Requires analyzing the
  206. // size/alignment of the type and converting the stride into an element
  207. // index.
  208. uint32_t array_stride = GetArrayStride(ptr_input);
  209. if (array_stride != 0) return false;
  210. if (ptr_input->NumInOperands() == 1) {
  211. // The input is effectively a no-op.
  212. inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
  213. context()->AnalyzeUses(inst);
  214. } else if (inst->NumInOperands() == 1) {
  215. // |inst| is a no-op, change it to a copy. Instruction simplification will
  216. // clean it up.
  217. inst->SetOpcode(SpvOpCopyObject);
  218. } else {
  219. std::vector<Operand> new_operands;
  220. if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false;
  221. // Update the instruction.
  222. inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
  223. inst->SetInOperands(std::move(new_operands));
  224. context()->AnalyzeUses(inst);
  225. }
  226. return true;
  227. }
  228. SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) {
  229. auto IsInBounds = [](SpvOp opcode) {
  230. return opcode == SpvOpInBoundsPtrAccessChain ||
  231. opcode == SpvOpInBoundsAccessChain;
  232. };
  233. if (input_opcode == SpvOpInBoundsPtrAccessChain) {
  234. if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain;
  235. } else if (input_opcode == SpvOpInBoundsAccessChain) {
  236. if (!IsInBounds(base_opcode)) return SpvOpAccessChain;
  237. }
  238. return input_opcode;
  239. }
  240. bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) {
  241. return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain;
  242. }
  243. bool CombineAccessChains::Has64BitIndices(Instruction* inst) {
  244. for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
  245. Instruction* index_inst =
  246. context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i));
  247. const analysis::Type* index_type =
  248. context()->get_type_mgr()->GetType(index_inst->type_id());
  249. if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32)
  250. return true;
  251. }
  252. return false;
  253. }
  254. } // namespace opt
  255. } // namespace spvtools