db.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import asyncio
  2. from contextlib import asynccontextmanager
  3. import asyncpg
  4. import os
  class Connection(asyncpg.Connection):
      """asyncpg connection class whose ``reset()`` is a no-op.

      NOTE(review): this module's ``Pool`` never calls ``reset()``, and
      ``init_db`` passes ``asyncpg.Connection`` explicitly, so this class is
      only used when ``Pool`` is built without a ``connection_class``. The
      override presumably exists to skip reset work when a pool releases a
      connection — confirm before relying on it.
      """

      async def reset(self, *, timeout=None):
          """Do nothing; ``timeout`` is accepted only for signature compatibility."""
          return None
  class Pool:
      """Fixed-size pool of asyncpg connections handed out in LIFO order.

      Construction is two-phase: ``Pool(...)`` only creates the idle queue;
      awaiting the instance (``pool = await Pool(...)``) opens ``max_size``
      connections and returns the pool.
      """

      def __init__(self, connect_url, max_size=10, connection_class=None):
          """Create an empty pool.

          Args:
              connect_url: DSN handed to ``asyncpg.connect``.
              max_size: number of connections opened when the pool is awaited;
                  must be at least 1.
              connection_class: class passed as ``connection_class`` to
                  ``asyncpg.connect``; falls back to this module's
                  ``Connection`` when falsy.

          Raises:
              ValueError: if ``max_size`` is less than 1.
          """
          # asyncio.LifoQueue treats maxsize <= 0 as "unbounded": the pool would
          # open zero connections and every acquire() would block forever.
          if max_size < 1:
              raise ValueError(f"max_size must be at least 1, got {max_size!r}")
          self._connect_url = connect_url
          self._connection_class = connection_class or Connection
          self._queue = asyncio.LifoQueue(max_size)

      def __await__(self):
          # Makes ``await Pool(...)`` run the async initialisation below.
          return self._async_init__().__await__()

      async def _async_init__(self):
          """Open ``max_size`` connections, enqueue them, and return ``self``.

          If a connection attempt fails (or the task is cancelled), every
          connection opened so far is terminated before the exception
          propagates. Otherwise those connections would be unreachable, because
          the caller never receives the pool object.
          """
          opened = []
          try:
              for _ in range(self._queue.maxsize):
                  opened.append(
                      await asyncpg.connect(
                          self._connect_url,
                          connection_class=self._connection_class,
                      )
                  )
          except BaseException:
              for conn in opened:
                  # terminate() is synchronous, so the cleanup cannot block or
                  # be interrupted while the original error is pending.
                  conn.terminate()
              raise
          # Enqueue in creation order, as before, so LIFO hands out the most
          # recently opened connection first.
          for conn in opened:
              self._queue.put_nowait(conn)
          return self

      @asynccontextmanager
      async def acquire(self):
          """Check out an idle connection for the duration of an ``async with``.

          Waits until a connection is idle. The connection goes back to the
          pool when the block exits, including when it exits with an exception.
          NOTE(review): no reset is performed on release, so session state
          (e.g. an open transaction) carries over to the next user — confirm
          this is acceptable.
          """
          conn = await self._queue.get()
          try:
              yield conn
          finally:
              self._queue.put_nowait(conn)

      async def close(self):
          """Close all ``max_size`` connections.

          Waits for checked-out connections to be released first, so this does
          not complete while any connection is still in use; callers should
          apply a timeout if that can happen.
          """
          for _ in range(self._queue.maxsize):
              conn = await self._queue.get()
              await conn.close()
  async def init_db(app):
      """Open the database connection pool and store it on ``app.db_pool``.

      The user name comes from ``PGUSER`` and the password from ``PSPASS``,
      falling back to the benchmark defaults. Both are percent-encoded before
      being placed in the DSN, so characters such as ``@``, ``:`` or ``/`` in
      a credential cannot corrupt the URL. The defaults encode to themselves,
      so the default DSN is unchanged.

      NOTE(review): ``PSPASS`` is not the libpq-standard ``PGPASSWORD``; it may
      be a typo — confirm what the deployment sets before renaming it.
      """
      from urllib.parse import quote

      user = quote(os.getenv("PGUSER", "benchmarkdbuser"), safe="")
      password = quote(os.getenv("PSPASS", "benchmarkdbpass"), safe="")
      app.db_pool = await Pool(
          "postgresql://%s:%s@tfb-database:5432/hello_world" % (user, password),
          connection_class=asyncpg.Connection,
      )
  async def close_db(app):
      """Close ``app.db_pool``, giving up after one second.

      ``asyncio.wait_for`` raises ``asyncio.TimeoutError`` if the pool's
      ``close()`` has not finished in time.
      """
      closing = app.db_pool.close()
      await asyncio.wait_for(closing, 1)