PipelineStateCache.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
  8. #include <Atom/RHI/PipelineStateCache.h>
  9. #include <Atom/RHI/Factory.h>
  10. #include <AzCore/Debug/Profiler.h>
  11. #include <AzCore/std/sort.h>
  12. #include <AzCore/std/parallel/exponential_backoff.h>
  13. namespace AZ::RHI
  14. {
  15. Ptr<PipelineStateCache> PipelineStateCache::Create(MultiDevice::DeviceMask deviceMask)
  16. {
  17. return aznew PipelineStateCache(deviceMask);
  18. }
  19. PipelineStateCache::PipelineStateCache(MultiDevice::DeviceMask deviceMask)
  20. : m_deviceMask{ deviceMask }
  21. {
  22. }
  23. void PipelineStateCache::ValidateCacheIntegrity() const
  24. {
  25. #if defined(AZ_ENABLE_TRACING)
  26. for (size_t i = 0; i < m_globalLibrarySet.size(); ++i)
  27. {
  28. const GlobalLibraryEntry& globalLibraryEntry = m_globalLibrarySet[i];
  29. const PipelineStateSet& readOnlyCache = globalLibraryEntry.m_readOnlyCache;
  30. AZ_Assert(globalLibraryEntry.m_pendingCompileCount == 0, "Compiles are pending for pipeline library");
  31. AZ_Assert(globalLibraryEntry.m_pendingCache.empty(), "Pending cache is not empty.");
  32. if (!m_globalLibraryActiveBits[i])
  33. {
  34. AZ_Assert(readOnlyCache.empty(), "Inactive library has pipeline states in its global entry.");
  35. }
  36. #if defined(AZ_DEBUG_BUILD)
  37. // the PipelineStateSet is expensive to duplicate, only do this in debug.
  38. PipelineStateSet readOnlyCacheCopy = readOnlyCache;
  39. AZ_Assert(AZStd::unique(readOnlyCacheCopy.begin(), readOnlyCacheCopy.end()) == readOnlyCacheCopy.end(),
  40. "'%d' Duplicates existed in the read-only cache!", readOnlyCache.size() - readOnlyCacheCopy.size());
  41. #endif
  42. }
  43. m_threadLibrarySet.ForEach([this](const ThreadLibrarySet& threadLibrarySet)
  44. {
  45. const size_t libraryCount = m_globalLibrarySet.size();
  46. for (size_t i = 0; i < libraryCount; ++i)
  47. {
  48. const ThreadLibraryEntry& threadLibraryEntry = threadLibrarySet[i];
  49. if (!m_globalLibraryActiveBits[i])
  50. {
  51. AZ_Assert(!threadLibraryEntry.m_library, "Inactive library has a valid RHI::PipelineLibrary instance.");
  52. }
  53. AZ_Assert(threadLibraryEntry.m_threadLocalCache.empty(), "Thread library should not have any items in its local cache.");
  54. }
  55. });
  56. #endif
  57. }
  58. void PipelineStateCache::Reset()
  59. {
  60. AZStd::unique_lock<AZStd::shared_mutex> lock(m_mutex);
  61. for (size_t i = 0; i < m_globalLibrarySet.size(); ++i)
  62. {
  63. if (m_globalLibraryActiveBits[i])
  64. {
  65. ResetLibraryImpl(PipelineLibraryHandle(i));
  66. }
  67. }
  68. }
  69. PipelineLibraryHandle PipelineStateCache::CreateLibrary(
  70. const AZStd::unordered_map<int, ConstPtr<RHI::PipelineLibraryData>>& serializedData, const AZStd::unordered_map<int, AZStd::string>& filePaths)
  71. {
  72. AZStd::unique_lock<AZStd::shared_mutex> lock(m_mutex);
  73. PipelineLibraryHandle handle;
  74. if (!m_libraryFreeList.empty())
  75. {
  76. handle = m_libraryFreeList.back();
  77. m_libraryFreeList.pop_back();
  78. }
  79. else
  80. {
  81. if (m_globalLibrarySet.size() == LibraryCountMax)
  82. {
  83. AZ_Error(
  84. "PipelineStateCache",
  85. false,
  86. "Exceeded maximum number of allowed pipeline libraries in "
  87. "cache. You must update LibraryCountMax to add more.");
  88. return {};
  89. }
  90. handle = PipelineLibraryHandle(m_globalLibrarySet.size());
  91. m_globalLibrarySet.emplace_back();
  92. }
  93. AZ_Assert(m_globalLibraryActiveBits[handle.GetIndex()] == false, "Attempted to allocate active library entry!");
  94. m_globalLibraryActiveBits[handle.GetIndex()] = true;
  95. GlobalLibraryEntry& libraryEntry = m_globalLibrarySet[handle.GetIndex()];
  96. libraryEntry.m_pipelineLibraryDescriptor.Init(m_deviceMask, serializedData, filePaths);
  97. AZ_Assert(libraryEntry.m_readOnlyCache.empty() && libraryEntry.m_pendingCache.empty(), "Library entry has entries in its caches!");
  98. return handle;
  99. }
  100. void PipelineStateCache::ReleaseLibrary(PipelineLibraryHandle handle)
  101. {
  102. if (handle.IsValid())
  103. {
  104. AZStd::unique_lock<AZStd::shared_mutex> lock(m_mutex);
  105. AZ_Assert(m_globalLibraryActiveBits[handle.GetIndex()], "Releasing a library that is no longer valid.");
  106. ResetLibraryImpl(handle);
  107. GlobalLibraryEntry& libraryEntry = m_globalLibrarySet[handle.GetIndex()];
  108. libraryEntry.m_readOnlyCache.clear();
  109. libraryEntry.m_pipelineLibraryDescriptor.Init(m_deviceMask, {}, {});
  110. m_globalLibraryActiveBits[handle.GetIndex()] = false;
  111. m_libraryFreeList.push_back(handle);
  112. }
  113. }
  114. void PipelineStateCache::ResetLibrary(PipelineLibraryHandle handle)
  115. {
  116. if (handle.IsValid())
  117. {
  118. AZStd::unique_lock<AZStd::shared_mutex> lock(m_mutex);
  119. ResetLibraryImpl(handle);
  120. }
  121. }
  122. void PipelineStateCache::ResetLibraryImpl(PipelineLibraryHandle handle)
  123. {
  124. m_threadLibrarySet.ForEach(
  125. [handle](ThreadLibrarySet& librarySet)
  126. {
  127. ThreadLibraryEntry& libraryEntry = librarySet[handle.GetIndex()];
  128. libraryEntry.m_library = nullptr;
  129. libraryEntry.m_threadLocalCache.clear();
  130. });
  131. GlobalLibraryEntry& libraryEntry = m_globalLibrarySet[handle.GetIndex()];
  132. AZ_Assert(libraryEntry.m_pendingCompileCount == 0, "Reseting library while compiles are still pending!");
  133. libraryEntry.m_readOnlyCache.clear();
  134. libraryEntry.m_pendingCacheMutex.lock();
  135. libraryEntry.m_pendingCache.clear();
  136. libraryEntry.m_pendingCacheMutex.unlock();
  137. }
  138. Ptr<PipelineLibrary> PipelineStateCache::GetMergedLibrary(PipelineLibraryHandle handle) const
  139. {
  140. if (handle.IsNull())
  141. {
  142. return nullptr;
  143. }
  144. AZStd::unique_lock<AZStd::shared_mutex> lock(m_mutex);
  145. const GlobalLibraryEntry& entry = m_globalLibrarySet[handle.GetIndex()];
  146. //! Each thread has its own PipelineLibrary instance. To produce the final serialized data, we
  147. //! coalesce data from each individual library by merging the thread-local ones into a single
  148. //! global (temporary) library. The data is then extracted from this global library and returned.
  149. //! This operation is designed to happen once at application shutdown; certainly not every frame.
  150. AZStd::vector<const PipelineLibrary*> threadLibraries;
  151. m_threadLibrarySet.ForEach(
  152. [handle, &threadLibraries](const ThreadLibrarySet& threadLibrarySet)
  153. {
  154. const ThreadLibraryEntry& threadLibraryEntry = threadLibrarySet[handle.GetIndex()];
  155. // Skip libraries that failed to initialize.
  156. if (threadLibraryEntry.m_library && threadLibraryEntry.m_library->IsInitialized())
  157. {
  158. threadLibraries.push_back(threadLibraryEntry.m_library.get());
  159. }
  160. });
  161. bool doesPSODataExist{ false };
  162. for (auto& [deviceIndex, devicePipelineLibraryDescriptor] : entry.m_pipelineLibraryDescriptor.m_devicePipelineLibraryDescriptors)
  163. {
  164. doesPSODataExist |= devicePipelineLibraryDescriptor.m_serializedData.get() != nullptr;
  165. }
  166. for (const RHI::PipelineLibrary* libraryBase : threadLibraries)
  167. {
  168. const PipelineLibrary* library = static_cast<const PipelineLibrary*>(libraryBase);
  169. doesPSODataExist |= library->IsMergeRequired();
  170. }
  171. if (doesPSODataExist)
  172. {
  173. Ptr<PipelineLibrary> pipelineLibrary = aznew PipelineLibrary;
  174. ResultCode resultCode = pipelineLibrary->Init(m_deviceMask, entry.m_pipelineLibraryDescriptor);
  175. if (resultCode == ResultCode::Success)
  176. {
  177. resultCode = pipelineLibrary->MergeInto(threadLibraries);
  178. if (resultCode == ResultCode::Success)
  179. {
  180. return pipelineLibrary;
  181. }
  182. }
  183. }
  184. return nullptr;
  185. }
  186. void PipelineStateCache::Compact()
  187. {
  188. AZ_PROFILE_SCOPE(RHI, "PipelineStateCache: Compact");
  189. AZStd::unique_lock<AZStd::shared_mutex> lock(m_mutex);
  190. // Merge the pending cache into the read-only cache.
  191. bool hasCompiledPipelineStates = false;
  192. for (size_t i = 0; i < m_globalLibrarySet.size(); ++i)
  193. {
  194. GlobalLibraryEntry& globalLibraryEntry = m_globalLibrarySet[i];
  195. // Skip inactive libraries and ones that didn't compile anything this cycle.
  196. if (m_globalLibraryActiveBits[i] && !globalLibraryEntry.m_pendingCache.empty())
  197. {
  198. hasCompiledPipelineStates = true;
  199. // Allocate a temporary staging set, perform the merge, and then move it back into the read-only cache.
  200. PipelineStateSet mergeResult;
  201. mergeResult.reserve(globalLibraryEntry.m_readOnlyCache.size() + globalLibraryEntry.m_pendingCache.size());
  202. AZStd::merge(
  203. globalLibraryEntry.m_readOnlyCache.begin(), globalLibraryEntry.m_readOnlyCache.end(),
  204. globalLibraryEntry.m_pendingCache.begin(), globalLibraryEntry.m_pendingCache.end(),
  205. AZStd::inserter(mergeResult, mergeResult.begin()));
  206. globalLibraryEntry.m_readOnlyCache.swap(mergeResult);
  207. globalLibraryEntry.m_pendingCache.clear();
  208. }
  209. }
  210. // If we had compilation events, then the thread-local caches are not empty and need to be cleared.
  211. if (hasCompiledPipelineStates)
  212. {
  213. const size_t libraryCount = m_globalLibrarySet.size();
  214. m_threadLibrarySet.ForEach([this, libraryCount](ThreadLibrarySet& threadLibrarySet)
  215. {
  216. for (size_t i = 0; i < libraryCount; ++i)
  217. {
  218. if (m_globalLibraryActiveBits[i])
  219. {
  220. threadLibrarySet[i].m_threadLocalCache.clear();
  221. }
  222. }
  223. });
  224. }
  225. ValidateCacheIntegrity();
  226. }
  227. const PipelineState* PipelineStateCache::FindPipelineState(
  228. const PipelineStateSet& pipelineStateSet, const PipelineStateDescriptor& descriptor)
  229. {
  230. auto pipelineStateIt = pipelineStateSet.find(PipelineStateEntry(descriptor.GetHash(), nullptr, descriptor));
  231. if (pipelineStateIt != pipelineStateSet.end())
  232. {
  233. return pipelineStateIt->m_pipelineState.get();
  234. }
  235. return nullptr;
  236. }
  237. bool PipelineStateCache::InsertPipelineState(PipelineStateSet& pipelineStateSet, PipelineStateEntry pipelineStateEntry)
  238. {
  239. auto ret = pipelineStateSet.insert(pipelineStateEntry);
  240. return ret.second;
  241. }
  242. const PipelineState* PipelineStateCache::AcquirePipelineState(
  243. PipelineLibraryHandle handle, const PipelineStateDescriptor& descriptor, const AZ::Name& name /*= AZ::Name()*/)
  244. {
  245. if (handle.IsNull())
  246. {
  247. return nullptr;
  248. }
  249. AZStd::shared_lock<AZStd::shared_mutex> lock(m_mutex);
  250. GlobalLibraryEntry& globalLibraryEntry = m_globalLibrarySet[handle.GetIndex()];
  251. PipelineStateHash pipelineStateHash = descriptor.GetHash();
  252. // Search the read-only cache first.
  253. if (const PipelineState* pipelineState = FindPipelineState(globalLibraryEntry.m_readOnlyCache, descriptor))
  254. {
  255. return pipelineState;
  256. }
  257. // Search the thread-local cache next.
  258. {
  259. ThreadLibrarySet& threadLibrarySet = m_threadLibrarySet.GetStorage();
  260. ThreadLibraryEntry& threadLibraryEntry = threadLibrarySet[handle.GetIndex()];
  261. PipelineStateSet& threadLocalCache = threadLibraryEntry.m_threadLocalCache;
  262. if (const PipelineState* pipelineState = FindPipelineState(threadLocalCache, descriptor))
  263. {
  264. return pipelineState;
  265. }
  266. // No entry in the thread-local set. Request a pipeline state from the pending cache and add
  267. // it to the thread-local cache to reduce contention on the pending cache.
  268. {
  269. // Lazy-init the library on first access.
  270. if (!threadLibraryEntry.m_library)
  271. {
  272. Ptr<PipelineLibrary> pipelineLibrary = aznew PipelineLibrary;
  273. RHI::ResultCode resultCode = pipelineLibrary->Init(m_deviceMask, globalLibraryEntry.m_pipelineLibraryDescriptor);
  274. if (resultCode != RHI::ResultCode::Success)
  275. {
  276. AZ_Warning(
  277. "PipelineStateCache",
  278. false,
  279. "Failed to initialize pipeline library. PipelineLibrary usage is disabled.");
  280. }
  281. // We store a valid pointer even if initialization failed, to avoid attempting
  282. // to re-create it with every access.
  283. threadLibraryEntry.m_library = AZStd::move(pipelineLibrary);
  284. }
  285. ConstPtr<PipelineState> pipelineState =
  286. CompilePipelineState(globalLibraryEntry, threadLibraryEntry, descriptor, pipelineStateHash, name);
  287. [[maybe_unused]] bool success =
  288. InsertPipelineState(threadLocalCache, PipelineStateEntry(pipelineStateHash, pipelineState, descriptor));
  289. AZ_Assert(success, "PipelineStateEntry already exists in the thread cache.");
  290. return pipelineState.get();
  291. }
  292. }
  293. }
  294. ConstPtr<PipelineState> PipelineStateCache::CompilePipelineState(
  295. GlobalLibraryEntry& globalLibraryEntry,
  296. ThreadLibraryEntry& threadLibraryEntry,
  297. const PipelineStateDescriptor& descriptor,
  298. PipelineStateHash pipelineStateHash,
  299. const AZ::Name& name)
  300. {
  301. Ptr<PipelineState> pipelineState;
  302. PipelineStateSet& pendingCache = globalLibraryEntry.m_pendingCache;
  303. {
  304. AZStd::lock_guard<AZStd::mutex> lock(globalLibraryEntry.m_pendingCacheMutex);
  305. // Another thread may have started compiling this pipeline state. Check the pending cache.
  306. if (const PipelineState* pipeline = FindPipelineState(pendingCache, descriptor))
  307. {
  308. return pipeline;
  309. }
  310. // We need to create and insert the pipeline state into the locked cache. Create the pipeline state
  311. // but don't initialize it yet. We can safely allocate the 'empty' instance and cache it.
  312. pipelineState = aznew PipelineState;
  313. [[maybe_unused]] bool success =
  314. InsertPipelineState(pendingCache, PipelineStateEntry(pipelineStateHash, pipelineState, descriptor));
  315. AZ_Assert(success, "PipelineStateEntry already exists in the pending cache.");
  316. }
  317. [[maybe_unused]] ResultCode resultCode = ResultCode::InvalidArgument;
  318. // Increment the pending compile count on the global entry, which tracks how many pipeline states
  319. // are currently being compiled across all threads.
  320. if (Validation::IsEnabled())
  321. {
  322. ++globalLibraryEntry.m_pendingCompileCount;
  323. }
  324. // If the pipeline library failed to initialize, then we don't use it.
  325. PipelineLibrary* pipelineLibrary = threadLibraryEntry.m_library.get();
  326. if (!pipelineLibrary->IsInitialized())
  327. {
  328. pipelineLibrary = nullptr;
  329. }
  330. // We no longer have the lock, but we own compilation of the pipeline state. Use the
  331. // thread-local library to perform compilation without blocking other threads.
  332. resultCode = pipelineState->Init(m_deviceMask, descriptor, pipelineLibrary);
  333. pipelineState->SetName(name);
  334. if (Validation::IsEnabled())
  335. {
  336. --globalLibraryEntry.m_pendingCompileCount;
  337. }
  338. // NOTE: We can't return null on a failure, since other threads will return the entry without compiling
  339. // it. Instead, the pipeline state remains uninitialized.
  340. AZ_Error(
  341. "PipelineStateCache",
  342. resultCode == ResultCode::Success,
  343. "Failed to compile pipeline state. It will remain in an initialized state.");
  344. return AZStd::move(pipelineState);
  345. }
  346. PipelineStateCache::PipelineStateEntry::PipelineStateEntry(
  347. PipelineStateHash hash, ConstPtr<PipelineState> pipelineState, const PipelineStateDescriptor& descriptor)
  348. : m_hash{ hash }
  349. , m_pipelineState{ AZStd::move(pipelineState) }
  350. {
  351. switch (descriptor.GetType())
  352. {
  353. case PipelineStateType::Dispatch:
  354. m_pipelineStateDescriptorVariant = static_cast<const AZ::RHI::PipelineStateDescriptorForDispatch&>(descriptor);
  355. break;
  356. case PipelineStateType::Draw:
  357. m_pipelineStateDescriptorVariant = static_cast<const AZ::RHI::PipelineStateDescriptorForDraw&>(descriptor);
  358. break;
  359. case PipelineStateType::RayTracing:
  360. m_pipelineStateDescriptorVariant = static_cast<const AZ::RHI::PipelineStateDescriptorForRayTracing&>(descriptor);
  361. break;
  362. }
  363. }
  364. bool PipelineStateCache::PipelineStateEntry::operator == (const PipelineStateCache::PipelineStateEntry& rhs) const
  365. {
  366. if(AZStd::get_if<AZ::RHI::PipelineStateDescriptorForDispatch>(&rhs.m_pipelineStateDescriptorVariant) &&
  367. AZStd::get_if<AZ::RHI::PipelineStateDescriptorForDispatch>(&m_pipelineStateDescriptorVariant))
  368. {
  369. const AZ::RHI::PipelineStateDescriptorForDispatch& lhsDesc = AZStd::get<PipelineStateDescriptorForDispatch>(m_pipelineStateDescriptorVariant);
  370. const AZ::RHI::PipelineStateDescriptorForDispatch& rhsDesc = AZStd::get<PipelineStateDescriptorForDispatch>(rhs.m_pipelineStateDescriptorVariant);
  371. return lhsDesc == rhsDesc;
  372. }
  373. else if(AZStd::get_if<AZ::RHI::PipelineStateDescriptorForDraw>(&rhs.m_pipelineStateDescriptorVariant) &&
  374. AZStd::get_if<AZ::RHI::PipelineStateDescriptorForDraw>(&m_pipelineStateDescriptorVariant))
  375. {
  376. const AZ::RHI::PipelineStateDescriptorForDraw& lhsDesc = AZStd::get<PipelineStateDescriptorForDraw>(m_pipelineStateDescriptorVariant);
  377. const AZ::RHI::PipelineStateDescriptorForDraw& rhsDesc = AZStd::get<PipelineStateDescriptorForDraw>(rhs.m_pipelineStateDescriptorVariant);
  378. return lhsDesc == rhsDesc;
  379. }
  380. else if(AZStd::get_if<AZ::RHI::PipelineStateDescriptorForRayTracing>(&rhs.m_pipelineStateDescriptorVariant) &&
  381. AZStd::get_if<AZ::RHI::PipelineStateDescriptorForRayTracing>(&m_pipelineStateDescriptorVariant))
  382. {
  383. const AZ::RHI::PipelineStateDescriptorForRayTracing& lhsDesc = AZStd::get<PipelineStateDescriptorForRayTracing>(m_pipelineStateDescriptorVariant);
  384. const AZ::RHI::PipelineStateDescriptorForRayTracing& rhsDesc = AZStd::get<PipelineStateDescriptorForRayTracing>(rhs.m_pipelineStateDescriptorVariant);
  385. return lhsDesc == rhsDesc;
  386. }
  387. return false;
  388. }
  389. }