3
0

QueryPool.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/Factory.h>
  9. #include <Atom/RHI/FrameGraphInterface.h>
  10. #include <Atom/RHI/Query.h>
  11. #include <Atom/RHI/RHISystemInterface.h>
  12. #include <Atom/RHI/Scope.h>
  13. #include <Atom/RPI.Public/GpuQuery/GpuQuerySystem.h>
  14. #include <Atom/RPI.Public/GpuQuery/QueryPool.h>
  15. namespace AZ
  16. {
  17. namespace RPI
  18. {
  19. static const char* GetQueryTypeString(RHI::QueryType queryType)
  20. {
  21. switch (queryType)
  22. {
  23. case RHI::QueryType::Occlusion:
  24. {
  25. return "Occlusion";
  26. }
  27. case RHI::QueryType::Timestamp:
  28. {
  29. return "Timestamp";
  30. }
  31. case RHI::QueryType::PipelineStatistics:
  32. {
  33. return "PipelineStatistics";
  34. }
  35. default:
  36. {
  37. AZ_Assert(false, "Unknown QueryType supplied");
  38. return "UnknownQueryType";
  39. }
  40. };
  41. }
  42. QueryPoolPtr QueryPool::CreateQueryPool(uint32_t queryCount, uint32_t rhiQueriesPerResult, RHI::QueryType queryType, RHI::PipelineStatisticsFlags pipelineStatisticsFlags)
  43. {
  44. return AZStd::unique_ptr<QueryPool>(aznew QueryPool(queryCount, rhiQueriesPerResult, queryType, pipelineStatisticsFlags));
  45. }
  46. QueryPool::QueryPool(uint32_t queryCapacity, uint32_t queriesPerResult, RHI::QueryType queryType, RHI::PipelineStatisticsFlags statisticsFlags)
  47. {
  48. RHI::Device* device = RHI::RHISystemInterface::Get()->GetDevice();
  49. m_queryCapacity = queryCapacity;
  50. m_queriesPerResult = queriesPerResult;
  51. m_statisticsFlags = statisticsFlags;
  52. m_queryType = queryType;
  53. // Calculate the total amount of RHI queries the RPI QueryPool needs to initialize.
  54. m_rhiQueryCapacity = m_queryCapacity * m_queriesPerResult * RPI::Query::BufferedFrames;
  55. m_rhiQueryArray.resize(m_rhiQueryCapacity);
  56. m_availableIntervalArray.reserve(queryCapacity);
  57. // Calculate the query result size.
  58. CalculateResultSize();
  59. // Populate the array with available RHI Query intervals.
  60. CreateRhiQueryIntervals();
  61. // Setup the query pool.
  62. {
  63. RHI::QueryPoolDescriptor queryPoolDesc;
  64. queryPoolDesc.m_queriesCount = m_rhiQueryCapacity;
  65. queryPoolDesc.m_type = m_queryType;
  66. queryPoolDesc.m_pipelineStatisticsMask = m_statisticsFlags;
  67. m_rhiQueryPool = RHI::Factory::Get().CreateQueryPool();
  68. AZStd::string poolName = AZStd::string::format("%sQueryPool", GetQueryTypeString(queryType));
  69. m_rhiQueryPool->SetName(AZ::Name(poolName));
  70. [[maybe_unused]] auto result = m_rhiQueryPool->Init(*device, queryPoolDesc);
  71. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create the query pool");
  72. }
  73. // Create the RHI queries.
  74. {
  75. AZStd::vector<RHI::Query*> rawQueryArray;
  76. rawQueryArray.reserve(m_rhiQueryArray.size());
  77. for (RHI::Ptr<RHI::Query>& query : m_rhiQueryArray)
  78. {
  79. query = RHI::Factory::Get().CreateQuery();
  80. rawQueryArray.emplace_back(query.get());
  81. }
  82. m_rhiQueryPool->InitQuery(rawQueryArray.data(), static_cast<uint32_t>(rawQueryArray.size()));
  83. }
  84. }
  85. QueryPool::~QueryPool()
  86. {
  87. // Unregister the queries first
  88. for (auto& query : m_queryRegistry)
  89. {
  90. query->UnregisterFromPool();
  91. }
  92. AZ_Assert(m_queryRegistry.empty(), "The QueryRegistry should be empty.");
  93. m_availableIntervalArray.clear();
  94. m_rhiQueryArray.clear();
  95. m_rhiQueryPool = nullptr;
  96. }
  97. void QueryPool::Update()
  98. {
  99. // Increment the QueryPool's FrameIndex.
  100. m_poolFrameIndex++;
  101. }
  102. RHI::Ptr<RPI::Query> QueryPool::CreateQuery(RHI::QueryPoolScopeAttachmentType attachmentType, RHI::ScopeAttachmentAccess attachmentAccess)
  103. {
  104. AZStd::unique_lock<AZStd::mutex> lock(m_queryRegistryMutex);
  105. // Get an available RHI Query interval.
  106. if (m_availableIntervalArray.empty())
  107. {
  108. AZ_WarningOnce("Gpu QueryPool", false,
  109. "There are no more available query indices left. This will result in Query data not being available for certain passes. \
  110. Initialize the RPI::QueryPool with a bigger capacity.");
  111. return nullptr;
  112. }
  113. RHI::Interval rhiQueryIndices = m_availableIntervalArray.back();
  114. m_availableIntervalArray.pop_back();
  115. // Create the RPI Query, and add it to the registry.
  116. auto* query = aznew RPI::Query(this, rhiQueryIndices, m_queryType, attachmentType, attachmentAccess);
  117. m_queryRegistry.emplace(query);
  118. return query;
  119. }
  120. void QueryPool::UnregisterQuery(RPI::Query* query)
  121. {
  122. AZ_Assert(query, "The RPI::Query has to be valid");
  123. AZStd::unique_lock<AZStd::mutex> lock(m_queryRegistryMutex);
  124. // Push the RHI Query indices back into the array of available indices for reuse.
  125. m_availableIntervalArray.emplace_back(query->m_rhiQueryIndices);
  126. // Invalidate the RPI Query's QueryPool.
  127. query->m_queryPool = nullptr;
  128. // Remove the RPI Query from the registry.
  129. m_queryRegistry.erase(query);
  130. }
  131. RHI::ResultCode QueryPool::BeginQueryInternal(RHI::Interval rhiQueryIndices, RHI::CommandList& commandList)
  132. {
  133. auto rhiQueryArray = GetRhiQueryArray();
  134. RHI::Ptr<RHI::Query> beginQuery = rhiQueryArray[rhiQueryIndices.m_min];
  135. return beginQuery->Begin(commandList);
  136. }
  137. RHI::ResultCode QueryPool::EndQueryInternal(RHI::Interval rhiQueryIndices, RHI::CommandList& commandList)
  138. {
  139. auto rhiQueryArray = GetRhiQueryArray();
  140. RHI::Ptr<RHI::Query> endQuery = rhiQueryArray[rhiQueryIndices.m_max];
  141. return endQuery->End(commandList);
  142. }
  143. AZStd::span<const RHI::Ptr<RHI::Query>> RPI::QueryPool::GetRhiQueryArray() const
  144. {
  145. return m_rhiQueryArray;
  146. }
  147. QueryResultCode QueryPool::GetQueryResultFromIndices(uint64_t* result, RHI::Interval rhiQueryIndices, RHI::QueryResultFlagBits queryResultFlag)
  148. {
  149. // Get the raw RHI Query pointers.
  150. AZStd::vector<RHI::Query*> queryArray = GetRawRhiQueriesFromInterval(rhiQueryIndices);
  151. // RHI Query results are readback with values that are a multiple of uint64_t.
  152. const uint32_t resultCount = m_queryResultSize / sizeof(uint64_t);
  153. const RHI::ResultCode resultCode = m_rhiQueryPool->GetResults(queryArray.data(), m_queriesPerResult, result, resultCount, queryResultFlag);
  154. return resultCode == RHI::ResultCode::Success ? QueryResultCode::Success : QueryResultCode::Fail;
  155. }
  156. uint32_t QueryPool::GetQueryResultSize() const
  157. {
  158. return m_queryResultSize;
  159. }
  160. void QueryPool::CalculateResultSize()
  161. {
  162. using namespace RHI;
  163. // Query result element count per QueryType.
  164. const uint32_t TimestampResultCount = 2u;
  165. const uint32_t OcclusionResultCount = 1u;
  166. uint32_t resultCount = 0u;
  167. // Determine the result size in uint64 by the QueryType.
  168. switch (m_queryType)
  169. {
  170. case QueryType::PipelineStatistics:
  171. // Each bit set, is translated to an additional result.
  172. resultCount = CountBitsSet(static_cast<uint64_t>(m_statisticsFlags));
  173. break;
  174. case QueryType::Timestamp:
  175. // A single timestamp result consists of two values.
  176. resultCount = TimestampResultCount;
  177. break;
  178. case QueryType::Occlusion:
  179. // A single occlusion result consists of one value.
  180. resultCount = OcclusionResultCount;
  181. break;
  182. default:
  183. AZ_Assert(false, "Unsupported QueryType");
  184. break;
  185. }
  186. m_queryResultSize = resultCount * sizeof(uint64_t);
  187. }
  188. void QueryPool::CreateRhiQueryIntervals()
  189. {
  190. // Calculates the RHI Query indices that are associated with the RPI Query.
  191. const auto getRhiQuriesFromRpiQueryIndex = [this](uint32_t rpiQueryIndex)
  192. {
  193. // The amount of RHI Queries that are required for a single RPI Query.
  194. const uint32_t queryIntervalSize = m_queriesPerResult * RPI::Query::BufferedFrames;
  195. const uint32_t queryIntervalOffset = rpiQueryIndex * queryIntervalSize;
  196. return RHI::Interval(queryIntervalOffset, queryIntervalOffset + queryIntervalSize - 1u);
  197. };
  198. for (uint32_t i = 0u; i < m_queryCapacity; i++)
  199. {
  200. m_availableIntervalArray.emplace_back(getRhiQuriesFromRpiQueryIndex(i));
  201. }
  202. }
  203. uint64_t QueryPool::GetPoolFrameIndex() const
  204. {
  205. return m_poolFrameIndex;
  206. }
  207. uint32_t QueryPool::GetQueriesPerResult() const
  208. {
  209. return m_queriesPerResult;
  210. }
  211. AZStd::span<const RHI::Ptr<RHI::Query>> QueryPool::GetRhiQueriesFromInterval(const RHI::Interval& rhiQueryIndices) const
  212. {
  213. const uint32_t queryCount = rhiQueryIndices.m_max - rhiQueryIndices.m_min + 1u;
  214. AZ_Assert(rhiQueryIndices.m_max < m_rhiQueryCapacity, "Query array index is going over the limit");
  215. return AZStd::span<const RHI::Ptr<RHI::Query>>(m_rhiQueryArray.begin() + rhiQueryIndices.m_min, queryCount);
  216. }
  217. AZStd::vector<RHI::Query*> QueryPool::GetRawRhiQueriesFromInterval(const RHI::Interval& rhiQueryIndices) const
  218. {
  219. auto rhiQueries = GetRhiQueriesFromInterval(rhiQueryIndices);
  220. AZStd::vector<RHI::Query*> queryArray;
  221. queryArray.reserve(rhiQueries.size());
  222. for (RHI::Ptr<RHI::Query> rhiQuery : rhiQueries)
  223. {
  224. queryArray.emplace_back(rhiQuery.get());
  225. }
  226. return queryArray;
  227. }
  228. }; // Namespace RPI
  229. }; // Namespace AZ