RemoveBufferBlockVisitor.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. //===-- RemoveBufferBlockVisitor.cpp - RemoveBufferBlock Visitor -*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "RemoveBufferBlockVisitor.h"
  10. #include "clang/SPIRV/SpirvContext.h"
  11. #include "clang/SPIRV/SpirvFunction.h"
  12. namespace clang {
  13. namespace spirv {
  14. bool RemoveBufferBlockVisitor::isBufferBlockDecorationDeprecated() {
  15. return featureManager.isTargetEnvVulkan1p2OrAbove();
  16. }
  17. bool RemoveBufferBlockVisitor::visit(SpirvModule *mod, Phase phase) {
  18. // If the target environment is Vulkan 1.2 or later, BufferBlock decoration is
  19. // deprecated and should be removed from the module.
  20. // Otherwise, no action is needed by this IMR visitor.
  21. if (phase == Visitor::Phase::Init)
  22. if (!isBufferBlockDecorationDeprecated())
  23. return false;
  24. return true;
  25. }
  26. bool RemoveBufferBlockVisitor::hasStorageBufferInterfaceType(
  27. const SpirvType *type) {
  28. while (type != nullptr) {
  29. if (const auto *structType = dyn_cast<StructType>(type)) {
  30. return structType->getInterfaceType() ==
  31. StructInterfaceType::StorageBuffer;
  32. } else if (const auto *elemType = dyn_cast<ArrayType>(type)) {
  33. type = elemType->getElementType();
  34. } else if (const auto *elemType = dyn_cast<RuntimeArrayType>(type)) {
  35. type = elemType->getElementType();
  36. } else {
  37. return false;
  38. }
  39. }
  40. return false;
  41. }
  42. bool RemoveBufferBlockVisitor::visitInstruction(SpirvInstruction *inst) {
  43. if (!inst->getResultType())
  44. return true;
  45. // OpAccessChain can obtain pointers to any type. Its result type is
  46. // OpTypePointer, and it should get the same storage class as its base.
  47. if (auto *accessChain = dyn_cast<SpirvAccessChain>(inst)) {
  48. auto *accessChainType = accessChain->getResultType();
  49. auto *baseType = accessChain->getBase()->getResultType();
  50. // The result type of OpAccessChain and the result type of its base must be
  51. // OpTypePointer.
  52. assert(isa<SpirvPointerType>(accessChainType));
  53. assert(isa<SpirvPointerType>(baseType));
  54. auto *accessChainPtr = dyn_cast<SpirvPointerType>(accessChainType);
  55. auto *basePtr = dyn_cast<SpirvPointerType>(baseType);
  56. auto baseStorageClass = basePtr->getStorageClass();
  57. if (accessChainPtr->getStorageClass() != baseStorageClass) {
  58. auto *newAccessChainType = context.getPointerType(
  59. accessChainPtr->getPointeeType(), baseStorageClass);
  60. inst->setStorageClass(baseStorageClass);
  61. inst->setResultType(newAccessChainType);
  62. }
  63. }
  64. // For all instructions, if the result type is a pointer pointing to a struct
  65. // with StorageBuffer interface, the storage class must be updated.
  66. const auto *instType = inst->getResultType();
  67. const auto *newInstType = instType;
  68. spv::StorageClass newInstStorageClass = spv::StorageClass::Max;
  69. if (updateStorageClass(instType, &newInstType, &newInstStorageClass)) {
  70. inst->setResultType(newInstType);
  71. inst->setStorageClass(newInstStorageClass);
  72. }
  73. return true;
  74. }
  75. bool RemoveBufferBlockVisitor::updateStorageClass(
  76. const SpirvType *type, const SpirvType **newType,
  77. spv::StorageClass *newStorageClass) {
  78. auto *ptrType = dyn_cast<SpirvPointerType>(type);
  79. if (ptrType == nullptr)
  80. return false;
  81. const auto *innerType = ptrType->getPointeeType();
  82. // For usual cases such as _ptr_Uniform_StructuredBuffer_float.
  83. if (hasStorageBufferInterfaceType(innerType) &&
  84. ptrType->getStorageClass() != spv::StorageClass::StorageBuffer) {
  85. *newType =
  86. context.getPointerType(innerType, spv::StorageClass::StorageBuffer);
  87. *newStorageClass = spv::StorageClass::StorageBuffer;
  88. return true;
  89. }
  90. // For pointer-to-pointer cases (which need legalization), we could have a
  91. // type like: _ptr_Function__ptr_Uniform_type_StructuredBuffer_float.
  92. // In such cases, we need to update the storage class for the inner pointer.
  93. if (const auto *innerPtrType = dyn_cast<SpirvPointerType>(innerType)) {
  94. if (hasStorageBufferInterfaceType(innerPtrType->getPointeeType()) &&
  95. innerPtrType->getStorageClass() != spv::StorageClass::StorageBuffer) {
  96. auto *newInnerType = context.getPointerType(
  97. innerPtrType->getPointeeType(), spv::StorageClass::StorageBuffer);
  98. *newType =
  99. context.getPointerType(newInnerType, ptrType->getStorageClass());
  100. *newStorageClass = ptrType->getStorageClass();
  101. return true;
  102. }
  103. }
  104. return false;
  105. }
  106. bool RemoveBufferBlockVisitor::visit(SpirvFunction *fn, Phase phase) {
  107. if (phase == Visitor::Phase::Init) {
  108. llvm::SmallVector<const SpirvType *, 4> paramTypes;
  109. bool updatedParamTypes = false;
  110. for (auto *param : fn->getParameters()) {
  111. const auto *paramType = param->getResultType();
  112. // This pass is run after all types are lowered.
  113. assert(paramType != nullptr);
  114. // Update the parameter type if needed (update storage class of pointers).
  115. const auto *newParamType = paramType;
  116. spv::StorageClass newParamSC = spv::StorageClass::Max;
  117. if (updateStorageClass(paramType, &newParamType, &newParamSC)) {
  118. param->setStorageClass(newParamSC);
  119. param->setResultType(newParamType);
  120. updatedParamTypes = true;
  121. }
  122. paramTypes.push_back(newParamType);
  123. }
  124. // Update the return type if needed (update storage class of pointers).
  125. const auto *returnType = fn->getReturnType();
  126. const auto *newReturnType = returnType;
  127. spv::StorageClass newReturnSC = spv::StorageClass::Max;
  128. bool updatedReturnType =
  129. updateStorageClass(returnType, &newReturnType, &newReturnSC);
  130. if (updatedReturnType) {
  131. fn->setReturnType(newReturnType);
  132. }
  133. if (updatedParamTypes || updatedReturnType) {
  134. fn->setFunctionType(context.getFunctionType(newReturnType, paramTypes));
  135. }
  136. return true;
  137. }
  138. return true;
  139. }
  140. } // end namespace spirv
  141. } // end namespace clang