combine_access_chains.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/combine_access_chains.h"
  15. #include <utility>
  16. #include "source/opt/constants.h"
  17. #include "source/opt/ir_builder.h"
  18. #include "source/opt/ir_context.h"
  19. namespace spvtools {
  20. namespace opt {
  21. Pass::Status CombineAccessChains::Process() {
  22. bool modified = false;
  23. for (auto& function : *get_module()) {
  24. modified |= ProcessFunction(function);
  25. }
  26. return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
  27. }
  28. bool CombineAccessChains::ProcessFunction(Function& function) {
  29. if (function.IsDeclaration()) {
  30. return false;
  31. }
  32. bool modified = false;
  33. cfg()->ForEachBlockInReversePostOrder(
  34. function.entry().get(), [&modified, this](BasicBlock* block) {
  35. block->ForEachInst([&modified, this](Instruction* inst) {
  36. switch (inst->opcode()) {
  37. case spv::Op::OpAccessChain:
  38. case spv::Op::OpInBoundsAccessChain:
  39. case spv::Op::OpPtrAccessChain:
  40. case spv::Op::OpInBoundsPtrAccessChain:
  41. modified |= CombineAccessChain(inst);
  42. break;
  43. default:
  44. break;
  45. }
  46. });
  47. });
  48. return modified;
  49. }
  50. uint32_t CombineAccessChains::GetConstantValue(
  51. const analysis::Constant* constant_inst) {
  52. if (constant_inst->type()->AsInteger()->width() <= 32) {
  53. if (constant_inst->type()->AsInteger()->IsSigned()) {
  54. return static_cast<uint32_t>(constant_inst->GetS32());
  55. } else {
  56. return constant_inst->GetU32();
  57. }
  58. } else {
  59. assert(false);
  60. return 0u;
  61. }
  62. }
  63. uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) {
  64. uint32_t array_stride = 0;
  65. context()->get_decoration_mgr()->WhileEachDecoration(
  66. inst->type_id(), uint32_t(spv::Decoration::ArrayStride),
  67. [&array_stride](const Instruction& decoration) {
  68. assert(decoration.opcode() != spv::Op::OpDecorateId);
  69. if (decoration.opcode() == spv::Op::OpDecorate) {
  70. array_stride = decoration.GetSingleWordInOperand(1);
  71. } else {
  72. array_stride = decoration.GetSingleWordInOperand(2);
  73. }
  74. return false;
  75. });
  76. return array_stride;
  77. }
  78. const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) {
  79. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  80. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  81. Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  82. const analysis::Type* type = type_mgr->GetType(base_ptr->type_id());
  83. assert(type->AsPointer());
  84. type = type->AsPointer()->pointee_type();
  85. std::vector<uint32_t> element_indices;
  86. uint32_t starting_index = 1;
  87. if (IsPtrAccessChain(inst->opcode())) {
  88. // Skip the first index of OpPtrAccessChain as it does not affect type
  89. // resolution.
  90. starting_index = 2;
  91. }
  92. for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
  93. Instruction* index_inst =
  94. def_use_mgr->GetDef(inst->GetSingleWordInOperand(i));
  95. const analysis::Constant* index_constant =
  96. context()->get_constant_mgr()->GetConstantFromInst(index_inst);
  97. if (index_constant) {
  98. uint32_t index_value = GetConstantValue(index_constant);
  99. element_indices.push_back(index_value);
  100. } else {
  101. // This index must not matter to resolve the type in valid SPIR-V.
  102. element_indices.push_back(0);
  103. }
  104. }
  105. type = type_mgr->GetMemberType(type, element_indices);
  106. return type;
  107. }
  108. bool CombineAccessChains::CombineIndices(Instruction* ptr_input,
  109. Instruction* inst,
  110. std::vector<Operand>* new_operands) {
  111. analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  112. analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();
  113. Instruction* last_index_inst = def_use_mgr->GetDef(
  114. ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1));
  115. const analysis::Constant* last_index_constant =
  116. constant_mgr->GetConstantFromInst(last_index_inst);
  117. Instruction* element_inst =
  118. def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  119. const analysis::Constant* element_constant =
  120. constant_mgr->GetConstantFromInst(element_inst);
  121. // Combine the last index of the AccessChain (|ptr_inst|) with the element
  122. // operand of the PtrAccessChain (|inst|).
  123. const bool combining_element_operands =
  124. IsPtrAccessChain(inst->opcode()) &&
  125. IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2;
  126. uint32_t new_value_id = 0;
  127. const analysis::Type* type = GetIndexedType(ptr_input);
  128. if (last_index_constant && element_constant) {
  129. // Combine the constants.
  130. uint32_t new_value = GetConstantValue(last_index_constant) +
  131. GetConstantValue(element_constant);
  132. const analysis::Constant* new_value_constant =
  133. constant_mgr->GetConstant(last_index_constant->type(), {new_value});
  134. Instruction* new_value_inst =
  135. constant_mgr->GetDefiningInstruction(new_value_constant);
  136. new_value_id = new_value_inst->result_id();
  137. } else if (!type->AsStruct() || combining_element_operands) {
  138. // Generate an addition of the two indices.
  139. InstructionBuilder builder(
  140. context(), inst,
  141. IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  142. Instruction* addition = builder.AddIAdd(last_index_inst->type_id(),
  143. last_index_inst->result_id(),
  144. element_inst->result_id());
  145. new_value_id = addition->result_id();
  146. } else {
  147. // Indexing into structs must be constant, so bail out here.
  148. return false;
  149. }
  150. new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}});
  151. return true;
  152. }
  153. bool CombineAccessChains::CreateNewInputOperands(
  154. Instruction* ptr_input, Instruction* inst,
  155. std::vector<Operand>* new_operands) {
  156. // Start by copying all the input operands of the feeder access chain.
  157. for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) {
  158. new_operands->push_back(ptr_input->GetInOperand(i));
  159. }
  160. // Deal with the last index of the feeder access chain.
  161. if (IsPtrAccessChain(inst->opcode())) {
  162. // The last index of the feeder should be combined with the element operand
  163. // of |inst|.
  164. if (!CombineIndices(ptr_input, inst, new_operands)) return false;
  165. } else {
  166. // The indices aren't being combined so now add the last index operand of
  167. // |ptr_input|.
  168. new_operands->push_back(
  169. ptr_input->GetInOperand(ptr_input->NumInOperands() - 1));
  170. }
  171. // Copy the remaining index operands.
  172. uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1;
  173. for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
  174. new_operands->push_back(inst->GetInOperand(i));
  175. }
  176. return true;
  177. }
  178. bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
  179. assert((inst->opcode() == spv::Op::OpPtrAccessChain ||
  180. inst->opcode() == spv::Op::OpAccessChain ||
  181. inst->opcode() == spv::Op::OpInBoundsAccessChain ||
  182. inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) &&
  183. "Wrong opcode. Expected an access chain.");
  184. Instruction* ptr_input =
  185. context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
  186. if (ptr_input->opcode() != spv::Op::OpAccessChain &&
  187. ptr_input->opcode() != spv::Op::OpInBoundsAccessChain &&
  188. ptr_input->opcode() != spv::Op::OpPtrAccessChain &&
  189. ptr_input->opcode() != spv::Op::OpInBoundsPtrAccessChain) {
  190. return false;
  191. }
  192. if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;
  193. // Handles the following cases:
  194. // 1. |ptr_input| is an index-less access chain. Replace the pointer
  195. // in |inst| with |ptr_input|'s pointer.
  196. // 2. |inst| is a index-less access chain. Change |inst| to an
  197. // OpCopyObject.
  198. // 3. |inst| is not a pointer access chain.
  199. // |inst|'s indices are appended to |ptr_input|'s indices.
  200. // 4. |ptr_input| is not pointer access chain.
  201. // |inst| is a pointer access chain.
  202. // |inst|'s element operand is combined with the last index in
  203. // |ptr_input| to form a new operand.
  204. // 5. |ptr_input| is a pointer access chain.
  205. // Like the above scenario, |inst|'s element operand is combined
  206. // with |ptr_input|'s last index. This results is either a
  207. // combined element operand or combined regular index.
  208. // TODO(alan-baker): Support this properly. Requires analyzing the
  209. // size/alignment of the type and converting the stride into an element
  210. // index.
  211. uint32_t array_stride = GetArrayStride(ptr_input);
  212. if (array_stride != 0) return false;
  213. if (ptr_input->NumInOperands() == 1) {
  214. // The input is effectively a no-op.
  215. inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
  216. context()->AnalyzeUses(inst);
  217. } else if (inst->NumInOperands() == 1) {
  218. // |inst| is a no-op, change it to a copy. Instruction simplification will
  219. // clean it up.
  220. inst->SetOpcode(spv::Op::OpCopyObject);
  221. } else {
  222. std::vector<Operand> new_operands;
  223. if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false;
  224. // Update the instruction.
  225. inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
  226. inst->SetInOperands(std::move(new_operands));
  227. context()->AnalyzeUses(inst);
  228. }
  229. return true;
  230. }
  231. spv::Op CombineAccessChains::UpdateOpcode(spv::Op base_opcode,
  232. spv::Op input_opcode) {
  233. auto IsInBounds = [](spv::Op opcode) {
  234. return opcode == spv::Op::OpInBoundsPtrAccessChain ||
  235. opcode == spv::Op::OpInBoundsAccessChain;
  236. };
  237. if (input_opcode == spv::Op::OpInBoundsPtrAccessChain) {
  238. if (!IsInBounds(base_opcode)) return spv::Op::OpPtrAccessChain;
  239. } else if (input_opcode == spv::Op::OpInBoundsAccessChain) {
  240. if (!IsInBounds(base_opcode)) return spv::Op::OpAccessChain;
  241. }
  242. return input_opcode;
  243. }
  244. bool CombineAccessChains::IsPtrAccessChain(spv::Op opcode) {
  245. return opcode == spv::Op::OpPtrAccessChain ||
  246. opcode == spv::Op::OpInBoundsPtrAccessChain;
  247. }
  248. bool CombineAccessChains::Has64BitIndices(Instruction* inst) {
  249. for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
  250. Instruction* index_inst =
  251. context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i));
  252. const analysis::Type* index_type =
  253. context()->get_type_mgr()->GetType(index_inst->type_id());
  254. if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32)
  255. return true;
  256. }
  257. return false;
  258. }
  259. } // namespace opt
  260. } // namespace spvtools