PassUtils.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/std/sort.h>
  9. #include <Atom/RPI.Reflect/Pass/RenderPassData.h>
  10. #include <Atom/RPI.Public/Pass/Pass.h>
  11. #include <Atom/RPI.Public/Pass/PassUtils.h>
  12. namespace AZ
  13. {
  14. namespace RPI
  15. {
  16. namespace PassUtils
  17. {
  18. const PassData* GetPassData(const PassDescriptor& descriptor)
  19. {
  20. const PassData* passData = nullptr;
  21. // Try custom data from PassRequest
  22. if (descriptor.m_passRequest != nullptr)
  23. {
  24. passData = descriptor.m_passRequest->m_passData.get();
  25. }
  26. // Try custom data from PassTemplate
  27. if (passData == nullptr && descriptor.m_passTemplate != nullptr)
  28. {
  29. passData = descriptor.m_passTemplate->m_passData.get();
  30. }
  31. if (passData == nullptr)
  32. {
  33. passData = descriptor.m_passData.get();
  34. }
  35. return passData;
  36. }
  37. AZStd::shared_ptr<PassData> GetPassDataPtr(const PassDescriptor& descriptor)
  38. {
  39. AZStd::shared_ptr<PassData> passData = nullptr;
  40. if (descriptor.m_passRequest != nullptr)
  41. {
  42. passData = descriptor.m_passRequest->m_passData;
  43. }
  44. if (passData == nullptr && descriptor.m_passTemplate != nullptr)
  45. {
  46. passData = descriptor.m_passTemplate->m_passData;
  47. }
  48. if (passData == nullptr)
  49. {
  50. passData = descriptor.m_passData;
  51. }
  52. return passData;
  53. }
  54. void ExtractPipelineGlobalConnections(const AZStd::shared_ptr<PassData>& passData, PipelineGlobalConnectionList& outList)
  55. {
  56. for (const PipelineGlobalConnection& connection : passData->m_pipelineGlobalConnections)
  57. {
  58. outList.push_back(connection);
  59. }
  60. }
  61. bool BindDataMappingsToSrg(const PassDescriptor& descriptor, ShaderResourceGroup* shaderResourceGroup)
  62. {
  63. bool success = true;
  64. // Apply mappings from PassTemplate
  65. const RenderPassData* passData = nullptr;
  66. if (descriptor.m_passTemplate != nullptr)
  67. {
  68. passData = azrtti_cast<const RenderPassData*>(descriptor.m_passTemplate->m_passData.get());
  69. if (passData)
  70. {
  71. success = shaderResourceGroup->ApplyDataMappings(passData->m_mappings);
  72. }
  73. }
  74. // Apply mappings from PassRequest
  75. passData = nullptr;
  76. if (descriptor.m_passRequest != nullptr)
  77. {
  78. passData = azrtti_cast<const RenderPassData*>(descriptor.m_passRequest->m_passData.get());
  79. if (passData)
  80. {
  81. success = success && shaderResourceGroup->ApplyDataMappings(passData->m_mappings);
  82. }
  83. }
  84. // Apply mappings from custom data in the descriptor
  85. passData = azrtti_cast<const RenderPassData*>(descriptor.m_passData.get());
  86. if (passData)
  87. {
  88. success = success && shaderResourceGroup->ApplyDataMappings(passData->m_mappings);
  89. }
  90. return success;
  91. }
  92. // Sort so passes with less depth (closer to the root) are first. Used when changes
  93. // in the parent passes can affect the child passes, like with attachment building.
  94. void SortPassListAscending(AZStd::vector< Ptr<Pass> >& passList)
  95. {
  96. AZStd::sort(passList.begin(), passList.end(),
  97. [](const Ptr<Pass>& lhs, const Ptr<Pass>& rhs)
  98. {
  99. if ((lhs->GetTreeDepth() == rhs->GetTreeDepth()))
  100. {
  101. return lhs->GetParentChildIndex() < rhs->GetParentChildIndex();
  102. }
  103. return (lhs->GetTreeDepth() < rhs->GetTreeDepth());
  104. });
  105. }
  106. // Sort so passes with greater depth (further from the root) get called first. Used in the case of
  107. // delete, as we want to avoid deleting the parent first since this invalidates the child pointer.
  108. void SortPassListDescending(AZStd::vector< Ptr<Pass> >& passList)
  109. {
  110. AZStd::sort(passList.begin(), passList.end(),
  111. [](const Ptr<Pass>& lhs, const Ptr<Pass>& rhs)
  112. {
  113. if ((lhs->GetTreeDepth() == rhs->GetTreeDepth()))
  114. {
  115. return lhs->GetParentChildIndex() > rhs->GetParentChildIndex();
  116. }
  117. return (lhs->GetTreeDepth() > rhs->GetTreeDepth());
  118. }
  119. );
  120. }
  121. }
  122. }
  123. }