ConnectionPool.hpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. /* (c) ZeroTier, Inc.
  2. * See LICENSE.txt in nonfree/
  3. */
  4. #ifndef ZT_CONNECTION_POOL_H_
  5. #define ZT_CONNECTION_POOL_H_
  6. #ifndef _DEBUG
  7. #define _DEBUG(x)
  8. #endif
  9. #include "../../node/Metrics.hpp"
  10. #include "opentelemetry/trace/provider.h"
  11. #include <deque>
  12. #include <exception>
  13. #include <memory>
  14. #include <mutex>
  15. #include <set>
  16. namespace ZeroTier {
  /**
   * Exception type signalling that no connection could be allocated.
   *
   * Derives from std::exception so it can be caught either by its own type or generically.
   */
  struct ConnectionUnavailable : std::exception {
      /**
       * Describe the failure.
       *
       * Declared noexcept (rather than the deprecated dynamic specification throw()) and marked override so the
       * compiler verifies it really overrides std::exception::what().
       *
       * @return pointer to a static, NUL-terminated message
       */
      char const* what() const noexcept override
      {
          return "Unable to allocate connection";
      }
  };
  /**
   * Polymorphic base class for connection objects.
   *
   * It declares nothing but a virtual destructor, so derived connections are destroyed correctly when released
   * through a pointer to this base.
   */
  class Connection {
    public:
      virtual ~Connection() = default;
  };
  27. class ConnectionFactory {
  28. public:
  29. virtual ~ConnectionFactory() {};
  30. virtual std::shared_ptr<Connection> create() = 0;
  31. };
  /**
   * Snapshot of a connection pool's occupancy.
   *
   * Both counters default to zero so a default-constructed object never exposes indeterminate values. The struct
   * remains an aggregate (C++14 and later), so brace initialization such as ConnectionPoolStats{ 3, 5 } still works.
   */
  struct ConnectionPoolStats {
      /** Number of connections sitting idle in the pool. */
      size_t pool_size = 0;
      /** Number of connections currently recorded as borrowed. */
      size_t borrowed_size = 0;
  };
  36. template <class T> class ConnectionPool {
  37. public:
  38. ConnectionPool(size_t max_pool_size, size_t min_pool_size, std::shared_ptr<ConnectionFactory> factory) : m_maxPoolSize(max_pool_size), m_minPoolSize(min_pool_size), m_factory(factory)
  39. {
  40. Metrics::max_pool_size += max_pool_size;
  41. Metrics::min_pool_size += min_pool_size;
  42. while (m_pool.size() < m_minPoolSize) {
  43. m_pool.push_back(m_factory->create());
  44. Metrics::pool_avail++;
  45. }
  46. };
  47. ConnectionPoolStats get_stats()
  48. {
  49. std::unique_lock<std::mutex> lock(m_poolMutex);
  50. ConnectionPoolStats stats;
  51. stats.pool_size = m_pool.size();
  52. stats.borrowed_size = m_borrowed.size();
  53. return stats;
  54. };
  55. ~ConnectionPool() {};
  56. /**
  57. * Borrow
  58. *
  59. * Borrow a connection for temporary use
  60. *
  61. * When done, either (a) call unborrow() to return it, or (b) (if it's bad) just let it go out of scope. This will cause it to automatically be replaced.
  62. * @retval a shared_ptr to the connection object
  63. */
  64. std::shared_ptr<T> borrow()
  65. {
  66. auto provider = opentelemetry::trace::Provider::GetTracerProvider();
  67. auto tracer = provider->GetTracer("connection_pool");
  68. auto span = tracer->StartSpan("connection_pool::borrow");
  69. auto scope = tracer->WithActiveSpan(span);
  70. std::unique_lock<std::mutex> l(m_poolMutex);
  71. while ((m_pool.size() + m_borrowed.size()) < m_minPoolSize) {
  72. std::shared_ptr<Connection> conn = m_factory->create();
  73. m_pool.push_back(conn);
  74. Metrics::pool_avail++;
  75. }
  76. if (m_pool.size() == 0) {
  77. if ((m_pool.size() + m_borrowed.size()) < m_maxPoolSize) {
  78. try {
  79. std::shared_ptr<Connection> conn = m_factory->create();
  80. m_borrowed.insert(conn);
  81. Metrics::pool_in_use++;
  82. return std::static_pointer_cast<T>(conn);
  83. }
  84. catch (std::exception& e) {
  85. span->SetStatus(opentelemetry::trace::StatusCode::kError, e.what());
  86. Metrics::pool_errors++;
  87. throw ConnectionUnavailable();
  88. }
  89. }
  90. else {
  91. for (auto it = m_borrowed.begin(); it != m_borrowed.end(); ++it) {
  92. if ((*it).unique()) {
  93. // This connection has been abandoned! Destroy it and create a new connection
  94. try {
  95. // If we are able to create a new connection, return it
  96. _DEBUG("Creating new connection to replace discarded connection");
  97. std::shared_ptr<Connection> conn = m_factory->create();
  98. m_borrowed.erase(it);
  99. m_borrowed.insert(conn);
  100. return std::static_pointer_cast<T>(conn);
  101. }
  102. catch (std::exception& e) {
  103. span->SetStatus(opentelemetry::trace::StatusCode::kError, e.what());
  104. // Error creating a replacement connection
  105. Metrics::pool_errors++;
  106. throw ConnectionUnavailable();
  107. }
  108. }
  109. }
  110. span->SetStatus(opentelemetry::trace::StatusCode::kError, "No available connections in pool");
  111. // Nothing available
  112. Metrics::pool_errors++;
  113. throw ConnectionUnavailable();
  114. }
  115. }
  116. // Take one off the front
  117. std::shared_ptr<Connection> conn = m_pool.front();
  118. m_pool.pop_front();
  119. Metrics::pool_avail--;
  120. // Add it to the borrowed list
  121. m_borrowed.insert(conn);
  122. Metrics::pool_in_use++;
  123. return std::static_pointer_cast<T>(conn);
  124. };
  125. /**
  126. * Unborrow a connection
  127. *
  128. * Only call this if you are returning a working connection. If the connection was bad, just let it go out of scope (so the connection manager can replace it).
  129. * @param the connection
  130. */
  131. void unborrow(std::shared_ptr<T> conn)
  132. {
  133. auto provider = opentelemetry::trace::Provider::GetTracerProvider();
  134. auto tracer = provider->GetTracer("connection_pool");
  135. auto span = tracer->StartSpan("connection_pool::unborrow");
  136. auto scope = tracer->WithActiveSpan(span);
  137. // Lock
  138. std::unique_lock<std::mutex> lock(m_poolMutex);
  139. m_borrowed.erase(conn);
  140. Metrics::pool_in_use--;
  141. if ((m_pool.size() + m_borrowed.size()) < m_maxPoolSize) {
  142. Metrics::pool_avail++;
  143. m_pool.push_back(conn);
  144. }
  145. };
  146. protected:
  147. size_t m_maxPoolSize;
  148. size_t m_minPoolSize;
  149. std::shared_ptr<ConnectionFactory> m_factory;
  150. std::deque<std::shared_ptr<Connection> > m_pool;
  151. std::set<std::shared_ptr<Connection> > m_borrowed;
  152. std::mutex m_poolMutex;
  153. };
  154. } // namespace ZeroTier
  155. #endif