RemoveBufferBlockVisitor.cpp 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. //===-- RemoveBufferBlockVisitor.cpp - RemoveBufferBlock Visitor -*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "RemoveBufferBlockVisitor.h"
  10. #include "clang/SPIRV/SpirvContext.h"
  11. namespace {
  12. bool isBufferBlockDecorationDeprecated(
  13. const clang::spirv::SpirvCodeGenOptions &opts) {
  14. return opts.targetEnv.compare("vulkan1.2") >= 0;
  15. }
  16. } // end anonymous namespace
  17. namespace clang {
  18. namespace spirv {
  19. bool RemoveBufferBlockVisitor::visit(SpirvModule *mod, Phase phase) {
  20. // If the target environment is Vulkan 1.2 or later, BufferBlock decoration is
  21. // deprecated and should be removed from the module.
  22. // Otherwise, no action is needed by this IMR visitor.
  23. if (phase == Visitor::Phase::Init)
  24. if (!isBufferBlockDecorationDeprecated(spvOptions))
  25. return false;
  26. return true;
  27. }
  28. bool RemoveBufferBlockVisitor::visitInstruction(SpirvInstruction *inst) {
  29. if (!inst->getResultType())
  30. return true;
  31. // OpAccessChain can obtain pointers to any type. Its result type is
  32. // OpTypePointer, and it should get the same storage class as its base.
  33. if (auto *accessChain = dyn_cast<SpirvAccessChain>(inst)) {
  34. auto *accessChainType = accessChain->getResultType();
  35. auto *baseType = accessChain->getBase()->getResultType();
  36. // The result type of OpAccessChain and the result type of its base must be
  37. // OpTypePointer.
  38. assert(isa<SpirvPointerType>(accessChainType));
  39. assert(isa<SpirvPointerType>(baseType));
  40. auto *accessChainPtr = dyn_cast<SpirvPointerType>(accessChainType);
  41. auto *basePtr = dyn_cast<SpirvPointerType>(baseType);
  42. auto baseStorageClass = basePtr->getStorageClass();
  43. if (accessChainPtr->getStorageClass() != baseStorageClass) {
  44. auto *newAccessChainType = context.getPointerType(
  45. accessChainPtr->getPointeeType(), baseStorageClass);
  46. inst->setStorageClass(baseStorageClass);
  47. inst->setResultType(newAccessChainType);
  48. }
  49. }
  50. // For all instructions, if the result type is a pointer pointing to a struct
  51. // with StorageBuffer interface, the storage class must be updated.
  52. if (auto *ptrResultType = dyn_cast<SpirvPointerType>(inst->getResultType())) {
  53. if (auto *structPointeeType =
  54. dyn_cast<StructType>(ptrResultType->getPointeeType())) {
  55. // Update the instruction's storage class if necessary
  56. if (structPointeeType->getInterfaceType() ==
  57. StructInterfaceType::StorageBuffer &&
  58. ptrResultType->getStorageClass() !=
  59. spv::StorageClass::StorageBuffer) {
  60. inst->setStorageClass(spv::StorageClass::StorageBuffer);
  61. inst->setResultType(context.getPointerType(
  62. ptrResultType->getPointeeType(), spv::StorageClass::StorageBuffer));
  63. }
  64. }
  65. }
  66. return true;
  67. }
  68. } // end namespace spirv
  69. } // end namespace clang