desc_sroa_util.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. // Copyright (c) 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/desc_sroa_util.h"
  15. namespace spvtools {
  16. namespace opt {
  17. namespace {
  18. const uint32_t kOpAccessChainInOperandIndexes = 1;
  19. // Returns the length of array type |type|.
  20. uint32_t GetLengthOfArrayType(IRContext* context, Instruction* type) {
  21. assert(type->opcode() == SpvOpTypeArray && "type must be array");
  22. uint32_t length_id = type->GetSingleWordInOperand(1);
  23. const analysis::Constant* length_const =
  24. context->get_constant_mgr()->FindDeclaredConstant(length_id);
  25. assert(length_const != nullptr);
  26. return length_const->GetU32();
  27. }
  28. } // namespace
  29. namespace descsroautil {
  30. bool IsDescriptorArray(IRContext* context, Instruction* var) {
  31. if (var->opcode() != SpvOpVariable) {
  32. return false;
  33. }
  34. uint32_t ptr_type_id = var->type_id();
  35. Instruction* ptr_type_inst = context->get_def_use_mgr()->GetDef(ptr_type_id);
  36. if (ptr_type_inst->opcode() != SpvOpTypePointer) {
  37. return false;
  38. }
  39. uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
  40. Instruction* var_type_inst = context->get_def_use_mgr()->GetDef(var_type_id);
  41. if (var_type_inst->opcode() != SpvOpTypeArray &&
  42. var_type_inst->opcode() != SpvOpTypeStruct) {
  43. return false;
  44. }
  45. // All structures with descriptor assignments must be replaced by variables,
  46. // one for each of their members - with the exceptions of buffers.
  47. if (IsTypeOfStructuredBuffer(context, var_type_inst)) {
  48. return false;
  49. }
  50. if (!context->get_decoration_mgr()->HasDecoration(
  51. var->result_id(), SpvDecorationDescriptorSet)) {
  52. return false;
  53. }
  54. return context->get_decoration_mgr()->HasDecoration(var->result_id(),
  55. SpvDecorationBinding);
  56. }
  57. bool IsTypeOfStructuredBuffer(IRContext* context, const Instruction* type) {
  58. if (type->opcode() != SpvOpTypeStruct) {
  59. return false;
  60. }
  61. // All buffers have offset decorations for members of their structure types.
  62. // This is how we distinguish it from a structure of descriptors.
  63. return context->get_decoration_mgr()->HasDecoration(type->result_id(),
  64. SpvDecorationOffset);
  65. }
  66. const analysis::Constant* GetAccessChainIndexAsConst(
  67. IRContext* context, Instruction* access_chain) {
  68. if (access_chain->NumInOperands() <= 1) {
  69. return nullptr;
  70. }
  71. uint32_t idx_id = GetFirstIndexOfAccessChain(access_chain);
  72. const analysis::Constant* idx_const =
  73. context->get_constant_mgr()->FindDeclaredConstant(idx_id);
  74. return idx_const;
  75. }
  76. uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain) {
  77. assert(access_chain->NumInOperands() > 1 &&
  78. "OpAccessChain does not have Indexes operand");
  79. return access_chain->GetSingleWordInOperand(kOpAccessChainInOperandIndexes);
  80. }
  81. uint32_t GetNumberOfElementsForArrayOrStruct(IRContext* context,
  82. Instruction* var) {
  83. uint32_t ptr_type_id = var->type_id();
  84. Instruction* ptr_type_inst = context->get_def_use_mgr()->GetDef(ptr_type_id);
  85. assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
  86. "Variable should be a pointer to an array or structure.");
  87. uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
  88. Instruction* pointee_type_inst =
  89. context->get_def_use_mgr()->GetDef(pointee_type_id);
  90. if (pointee_type_inst->opcode() == SpvOpTypeArray) {
  91. return GetLengthOfArrayType(context, pointee_type_inst);
  92. }
  93. assert(pointee_type_inst->opcode() == SpvOpTypeStruct &&
  94. "Variable should be a pointer to an array or structure.");
  95. return pointee_type_inst->NumInOperands();
  96. }
  97. } // namespace descsroautil
  98. } // namespace opt
  99. } // namespace spvtools