EBusSharedDispatchMutexTests.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/EBus/EBus.h>
  9. #include <AzCore/EBus/EBusSharedDispatchTraits.h>
  10. #include <AzCore/std/parallel/mutex.h>
  11. #include <AzCore/std/parallel/semaphore.h>
  12. #include <AzCore/std/parallel/thread.h>
  13. #include <AzCore/UnitTest/TestTypes.h>
  14. #include <Tests/AZTestShared/Utils/Utils.h>
  15. #include <gtest/gtest.h>
  16. namespace UnitTest
  17. {
  18. // Test EBus that uses the EBusSharedDispatchMutex.
  19. class SharedDispatchRequests : public AZ::EBusSharedDispatchTraits<SharedDispatchRequests>
  20. {
  21. public:
  22. static const AZ::EBusAddressPolicy AddressPolicy = AZ::EBusAddressPolicy::Single;
  23. static const AZ::EBusHandlerPolicy HandlerPolicy = AZ::EBusHandlerPolicy::Single;
  24. // Custom disconnect policy is used here to verify that disconnects do not occur while dispatches are in progress.
  25. template<class Bus>
  26. struct ConnectionPolicy : public AZ::EBusConnectionPolicy<Bus>
  27. {
  28. static void Disconnect(
  29. typename Bus::Context& context,
  30. typename Bus::HandlerNode& handler,
  31. typename Bus::BusPtr& busPtr)
  32. {
  33. EXPECT_EQ(m_totalRecursiveQueriesInProgress, 0);
  34. AZ::EBusConnectionPolicy<Bus>::Disconnect(context, handler, busPtr);
  35. }
  36. };
  37. // Provide a test EBus call that can be run in parallel.
  38. virtual void RecursiveQuery(int32_t numRecursions = 5) = 0;
  39. // These are static and defined on the EBus so that we can check the values from Disconnect.
  40. static AZStd::atomic_int m_totalRecursiveQueriesInProgress;
  41. static AZStd::atomic_int m_totalRecursiveQueriesCompleted;
  42. };
  43. using SharedDispatchRequestBus = AZ::EBus<SharedDispatchRequests>;
  44. AZStd::atomic_int SharedDispatchRequests::m_totalRecursiveQueriesInProgress = 0;
  45. AZStd::atomic_int SharedDispatchRequests::m_totalRecursiveQueriesCompleted = 0;
  46. // Test EBus handler that provides recursion and synchronization to test out the features of the EBusSharedDispatchMutex.
  47. class SharedDispatchRequestHandler : public SharedDispatchRequestBus::Handler
  48. {
  49. public:
  50. AZ_CLASS_ALLOCATOR(SharedDispatchRequestHandler, AZ::SystemAllocator);
  51. AZStd::semaphore m_querySemaphore;
  52. AZStd::semaphore m_syncSemaphore;
  53. AZStd::semaphore m_disconnectSemaphore;
  54. AZStd::atomic_int m_numDisconnects = 0;
  55. SharedDispatchRequestHandler()
  56. {
  57. // Reinitialize these for every test.
  58. m_totalRecursiveQueriesInProgress = 0;
  59. m_totalRecursiveQueriesCompleted = 0;
  60. }
  61. ~SharedDispatchRequestHandler() override
  62. {
  63. SharedDispatchRequestBus::Handler::BusDisconnect();
  64. }
  65. void Connect()
  66. {
  67. SharedDispatchRequestBus::Handler::BusConnect();
  68. }
  69. void Disconnect()
  70. {
  71. // Signal that the thread is running and has at least made it this far.
  72. m_disconnectSemaphore.release();
  73. SharedDispatchRequestBus::Handler::BusDisconnect();
  74. m_numDisconnects++;
  75. }
  76. void RecursiveQuery(int32_t numRecursions = 5) override
  77. {
  78. if (numRecursions <= 0)
  79. {
  80. // At the end of the recursion, signal the syncSemaphore that we've reached the end of the recursion.
  81. // We'll use this as a way to guarantee that all our threads have reached this point at the same time.
  82. m_syncSemaphore.release();
  83. // Block on the querySemaphore. This won't get released until every thread has released the syncSemaphore.
  84. m_querySemaphore.acquire();
  85. // Track that we've completed the query successfully.
  86. m_totalRecursiveQueriesCompleted++;
  87. return;
  88. }
  89. // Recursively call the EBus a fixed number of times, and keep track of how many times we've successfully recursed.
  90. m_totalRecursiveQueriesInProgress++;
  91. SharedDispatchRequestBus::Broadcast(&SharedDispatchRequestBus::Events::RecursiveQuery, numRecursions - 1);
  92. m_totalRecursiveQueriesInProgress--;
  93. }
  94. };
  95. class EBusSharedDispatchMutexTestFixture
  96. : public LeakDetectionFixture
  97. {
  98. public:
  99. EBusSharedDispatchMutexTestFixture()
  100. {
  101. SharedDispatchRequestBus::GetOrCreateContext();
  102. }
  103. };
  104. TEST_F(EBusSharedDispatchMutexTestFixture, RecursiveBusCallsOnSingleThreadWorks)
  105. {
  106. // Verify that multiple nested bus calls to the same bus on the same thread works without deadlocks.
  107. constexpr int32_t TotalRecursiveQueries = 10;
  108. SharedDispatchRequestHandler handler;
  109. handler.Connect();
  110. // This is a single-threaded test, so we don't need the recursive query to block before returning.
  111. handler.m_querySemaphore.release();
  112. SharedDispatchRequestBus::Broadcast(&SharedDispatchRequestBus::Events::RecursiveQuery, TotalRecursiveQueries);
  113. EXPECT_EQ(handler.m_totalRecursiveQueriesInProgress, 0);
  114. EXPECT_EQ(handler.m_totalRecursiveQueriesCompleted, 1);
  115. // Not strictly needed, but since we're doing a release() in RecursiveQuery, this keeps the semaphore acquire/release calls
  116. // balanced for the test.
  117. handler.m_syncSemaphore.acquire();
  118. handler.Disconnect();
  119. }
  120. TEST_F(EBusSharedDispatchMutexTestFixture, RecursiveBusCallsOnMultipleThreadsWork)
  121. {
  122. // Verify that multiple dispatched events run in parallel without deadlocks, even if each thread has recursively called
  123. // events on the same bus.
  124. const int32_t TotalRecursiveQueries = 10;
  125. SharedDispatchRequestHandler handler;
  126. handler.Connect();
  127. constexpr size_t ThreadCount = 4;
  128. AZStd::thread threads[ThreadCount];
  129. // Each thread will trigger the RecursiveQuery call. This call has semaphores in it so that we can guarantee that
  130. // every thread has reached the same state at the same time.
  131. for (AZStd::thread& thread : threads)
  132. {
  133. thread = AZStd::thread(
  134. [TotalRecursiveQueries]()
  135. {
  136. SharedDispatchRequestBus::Broadcast(&SharedDispatchRequestBus::Events::RecursiveQuery, TotalRecursiveQueries);
  137. });
  138. }
  139. // Wait for all the threads to reach the point where they're blocking. This will occur once they've each successfully called
  140. // down through the RecursiveQuery multiple times and are ready to finish.
  141. for (size_t threadNum = 0; threadNum < ThreadCount; threadNum++)
  142. {
  143. handler.m_syncSemaphore.acquire();
  144. }
  145. // Before unblocking the threads, verify that we've got the total number of expected recursions in progress
  146. // and that none of the calls have completed.
  147. EXPECT_EQ(handler.m_totalRecursiveQueriesInProgress, TotalRecursiveQueries * ThreadCount);
  148. EXPECT_EQ(handler.m_totalRecursiveQueriesCompleted, 0);
  149. // Unblock all the threads.
  150. for (size_t threadNum = 0; threadNum < ThreadCount; threadNum++)
  151. {
  152. handler.m_querySemaphore.release();
  153. }
  154. // Wait for the threads to finish.
  155. for (AZStd::thread& thread : threads)
  156. {
  157. thread.join();
  158. }
  159. // Verify that we ended up with the correct number of completed recursive calls and that none are still in progress.
  160. EXPECT_EQ(handler.m_totalRecursiveQueriesInProgress, 0);
  161. EXPECT_EQ(handler.m_totalRecursiveQueriesCompleted, ThreadCount);
  162. handler.Disconnect();
  163. }
  164. TEST_F(EBusSharedDispatchMutexTestFixture, DispatchCallsBlockDisconnectFromRunning)
  165. {
  166. // Verify that BusConnect / BusDisconnect cannot run in parallel with event dispatches.
  167. // We can't easily test BusConnect running in parallel, because by definition no dispatches can successfully occur before
  168. // the handler is connected. However, we can test Disconnect by doing the following:
  169. // - Run multiple dispatches in parallel and block them mid-dispatch
  170. // - Run Disconnect() on a thread
  171. // - Unblock the dispatches
  172. // - Wait for the dispatches and disconnect to complete.
  173. // The Disconnect() logic will verify that the number of running dispatches is 0. If the dispatches successfully blocked the
  174. // disconnect, the Disconnect() won't be able to execute until all the dispatches have completed. If they don't block the
  175. // disconnect, then there will be dispatches running at the same time and the verification will fail.
  176. const int32_t TotalRecursiveQueries = 5;
  177. SharedDispatchRequestHandler handler;
  178. handler.Connect();
  179. constexpr size_t ThreadCount = 4;
  180. AZStd::thread threads[ThreadCount];
  181. AZStd::thread disconnectThread;
  182. // Each thread will trigger the RecursiveQuery call. This call has semaphores in it so that we can guarantee that
  183. // every thread has reached the same state at the same time.
  184. for (AZStd::thread& thread : threads)
  185. {
  186. thread = AZStd::thread(
  187. [TotalRecursiveQueries]()
  188. {
  189. SharedDispatchRequestBus::Broadcast(&SharedDispatchRequestBus::Events::RecursiveQuery, TotalRecursiveQueries);
  190. });
  191. }
  192. // Wait for all the threads to reach the point where they're blocking. This will occur once they've each successfully called
  193. // down through the RecursiveQuery multiple times and are ready to finish.
  194. for (size_t threadNum = 0; threadNum < ThreadCount; threadNum++)
  195. {
  196. handler.m_syncSemaphore.acquire();
  197. }
  198. disconnectThread = AZStd::thread(
  199. [&handler]()
  200. {
  201. handler.Disconnect();
  202. }
  203. );
  204. // Wait for the disconnect thread to start running. At this point, no disconnects should have occurred, because it's blocked
  205. // waiting on the dispatches to finish.
  206. handler.m_disconnectSemaphore.acquire();
  207. EXPECT_EQ(handler.m_numDisconnects, 0);
  208. // Unblock all the dispatch threads.
  209. for (size_t threadNum = 0; threadNum < ThreadCount; threadNum++)
  210. {
  211. handler.m_querySemaphore.release();
  212. }
  213. // Wait for the dispatch threads to finish.
  214. for (AZStd::thread& thread : threads)
  215. {
  216. thread.join();
  217. }
  218. // Wait for the disconnect thread to finish.
  219. disconnectThread.join();
  220. // Verify that the disconnect finished. Our disconnect logic will verify that no dispatches were running during the disconnect.
  221. EXPECT_EQ(handler.m_numDisconnects, 1);
  222. }
  223. } // namespace UnitTest