desc_sroa_util.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. // Copyright (c) 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opt/desc_sroa_util.h"
  15. namespace spvtools {
  16. namespace opt {
  17. namespace {
  18. constexpr uint32_t kOpAccessChainInOperandIndexes = 1;
  19. // Returns the length of array type |type|.
  20. uint32_t GetLengthOfArrayType(IRContext* context, Instruction* type) {
  21. assert(type->opcode() == spv::Op::OpTypeArray && "type must be array");
  22. uint32_t length_id = type->GetSingleWordInOperand(1);
  23. const analysis::Constant* length_const =
  24. context->get_constant_mgr()->FindDeclaredConstant(length_id);
  25. assert(length_const != nullptr);
  26. return length_const->GetU32();
  27. }
  28. bool HasDescriptorDecorations(IRContext* context, Instruction* var) {
  29. const auto& decoration_mgr = context->get_decoration_mgr();
  30. return decoration_mgr->HasDecoration(
  31. var->result_id(), uint32_t(spv::Decoration::DescriptorSet)) &&
  32. decoration_mgr->HasDecoration(var->result_id(),
  33. uint32_t(spv::Decoration::Binding));
  34. }
  35. Instruction* GetVariableType(IRContext* context, Instruction* var) {
  36. if (var->opcode() != spv::Op::OpVariable) {
  37. return nullptr;
  38. }
  39. uint32_t ptr_type_id = var->type_id();
  40. Instruction* ptr_type_inst = context->get_def_use_mgr()->GetDef(ptr_type_id);
  41. if (ptr_type_inst->opcode() != spv::Op::OpTypePointer) {
  42. return nullptr;
  43. }
  44. uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
  45. return context->get_def_use_mgr()->GetDef(var_type_id);
  46. }
  47. } // namespace
  48. namespace descsroautil {
  49. bool IsDescriptorArray(IRContext* context, Instruction* var) {
  50. Instruction* var_type_inst = GetVariableType(context, var);
  51. if (var_type_inst == nullptr) return false;
  52. return var_type_inst->opcode() == spv::Op::OpTypeArray &&
  53. HasDescriptorDecorations(context, var);
  54. }
  55. bool IsDescriptorStruct(IRContext* context, Instruction* var) {
  56. Instruction* var_type_inst = GetVariableType(context, var);
  57. if (var_type_inst == nullptr) return false;
  58. while (var_type_inst->opcode() == spv::Op::OpTypeArray) {
  59. var_type_inst = context->get_def_use_mgr()->GetDef(
  60. var_type_inst->GetInOperand(0).AsId());
  61. }
  62. if (var_type_inst->opcode() != spv::Op::OpTypeStruct) return false;
  63. // All structures with descriptor assignments must be replaced by variables,
  64. // one for each of their members - with the exceptions of buffers.
  65. if (IsTypeOfStructuredBuffer(context, var_type_inst)) {
  66. return false;
  67. }
  68. return HasDescriptorDecorations(context, var);
  69. }
  70. bool IsTypeOfStructuredBuffer(IRContext* context, const Instruction* type) {
  71. if (type->opcode() != spv::Op::OpTypeStruct) {
  72. return false;
  73. }
  74. // All buffers have offset decorations for members of their structure types.
  75. // This is how we distinguish it from a structure of descriptors.
  76. return context->get_decoration_mgr()->HasDecoration(
  77. type->result_id(), uint32_t(spv::Decoration::Offset));
  78. }
  79. const analysis::Constant* GetAccessChainIndexAsConst(
  80. IRContext* context, Instruction* access_chain) {
  81. if (access_chain->NumInOperands() <= 1) {
  82. return nullptr;
  83. }
  84. uint32_t idx_id = GetFirstIndexOfAccessChain(access_chain);
  85. const analysis::Constant* idx_const =
  86. context->get_constant_mgr()->FindDeclaredConstant(idx_id);
  87. return idx_const;
  88. }
  89. uint32_t GetFirstIndexOfAccessChain(Instruction* access_chain) {
  90. assert(access_chain->NumInOperands() > 1 &&
  91. "OpAccessChain does not have Indexes operand");
  92. return access_chain->GetSingleWordInOperand(kOpAccessChainInOperandIndexes);
  93. }
  94. uint32_t GetNumberOfElementsForArrayOrStruct(IRContext* context,
  95. Instruction* var) {
  96. uint32_t ptr_type_id = var->type_id();
  97. Instruction* ptr_type_inst = context->get_def_use_mgr()->GetDef(ptr_type_id);
  98. assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer &&
  99. "Variable should be a pointer to an array or structure.");
  100. uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);
  101. Instruction* pointee_type_inst =
  102. context->get_def_use_mgr()->GetDef(pointee_type_id);
  103. if (pointee_type_inst->opcode() == spv::Op::OpTypeArray) {
  104. return GetLengthOfArrayType(context, pointee_type_inst);
  105. }
  106. assert(pointee_type_inst->opcode() == spv::Op::OpTypeStruct &&
  107. "Variable should be a pointer to an array or structure.");
  108. return pointee_type_inst->NumInOperands();
  109. }
  110. } // namespace descsroautil
  111. } // namespace opt
  112. } // namespace spvtools