fix_storage_class.cpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. // Copyright (c) 2019 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "fix_storage_class.h"
  15. #include "source/opt/instruction.h"
  16. #include "source/opt/ir_context.h"
  17. namespace spvtools {
  18. namespace opt {
  19. Pass::Status FixStorageClass::Process() {
  20. bool modified = false;
  21. get_module()->ForEachInst([this, &modified](Instruction* inst) {
  22. if (inst->opcode() == SpvOpVariable) {
  23. std::vector<Instruction*> uses;
  24. get_def_use_mgr()->ForEachUser(
  25. inst, [&uses](Instruction* use) { uses.push_back(use); });
  26. for (Instruction* use : uses) {
  27. modified |= PropagateStorageClass(
  28. use, static_cast<SpvStorageClass>(inst->GetSingleWordInOperand(0)));
  29. }
  30. }
  31. });
  32. return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
  33. }
  34. bool FixStorageClass::PropagateStorageClass(Instruction* inst,
  35. SpvStorageClass storage_class) {
  36. if (!IsPointerResultType(inst)) {
  37. return false;
  38. }
  39. if (IsPointerToStorageClass(inst, storage_class)) {
  40. return false;
  41. }
  42. switch (inst->opcode()) {
  43. case SpvOpAccessChain:
  44. case SpvOpPtrAccessChain:
  45. case SpvOpInBoundsAccessChain:
  46. case SpvOpCopyObject:
  47. case SpvOpPhi:
  48. case SpvOpSelect:
  49. FixInstruction(inst, storage_class);
  50. return true;
  51. case SpvOpFunctionCall:
  52. // We cannot be sure of the actual connection between the storage class
  53. // of the parameter and the storage class of the result, so we should not
  54. // do anything. If the result type needs to be fixed, the function call
  55. // should be inlined.
  56. return false;
  57. case SpvOpImageTexelPointer:
  58. case SpvOpLoad:
  59. case SpvOpStore:
  60. case SpvOpCopyMemory:
  61. case SpvOpCopyMemorySized:
  62. case SpvOpVariable:
  63. case SpvOpBitcast:
  64. // Nothing to change for these opcode. The result type is the same
  65. // regardless of the storage class of the operand.
  66. return false;
  67. default:
  68. assert(false &&
  69. "Not expecting instruction to have a pointer result type.");
  70. return false;
  71. }
  72. }
  73. void FixStorageClass::FixInstruction(Instruction* inst,
  74. SpvStorageClass storage_class) {
  75. assert(IsPointerResultType(inst) &&
  76. "The result type of the instruction must be a pointer.");
  77. ChangeResultStorageClass(inst, storage_class);
  78. std::vector<Instruction*> uses;
  79. get_def_use_mgr()->ForEachUser(
  80. inst, [&uses](Instruction* use) { uses.push_back(use); });
  81. for (Instruction* use : uses) {
  82. PropagateStorageClass(use, storage_class);
  83. }
  84. }
  85. void FixStorageClass::ChangeResultStorageClass(
  86. Instruction* inst, SpvStorageClass storage_class) const {
  87. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  88. Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
  89. assert(result_type_inst->opcode() == SpvOpTypePointer);
  90. uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
  91. uint32_t new_result_type_id =
  92. type_mgr->FindPointerToType(pointee_type_id, storage_class);
  93. inst->SetResultType(new_result_type_id);
  94. context()->UpdateDefUse(inst);
  95. }
  96. bool FixStorageClass::IsPointerResultType(Instruction* inst) {
  97. if (inst->type_id() == 0) {
  98. return false;
  99. }
  100. const analysis::Type* ret_type =
  101. context()->get_type_mgr()->GetType(inst->type_id());
  102. return ret_type->AsPointer() != nullptr;
  103. }
  104. bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
  105. SpvStorageClass storage_class) {
  106. analysis::TypeManager* type_mgr = context()->get_type_mgr();
  107. analysis::Type* pType = type_mgr->GetType(inst->type_id());
  108. const analysis::Pointer* result_type = pType->AsPointer();
  109. if (result_type == nullptr) {
  110. return false;
  111. }
  112. return (result_type->storage_class() == storage_class);
  113. }
  114. // namespace opt
  115. } // namespace opt
  116. } // namespace spvtools