pg.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import logging
  2. from operator import itemgetter
  3. from random import randint
  4. import asyncpg.exceptions
  5. import jinja2
  6. from aioworkers_pg.base import Connector
  7. from aioworkers.core.base import AbstractEntity
  8. from aioworkers.core.config import ValueExtractor
  9. from aioworkers.net.uri import URI
  10. READ_ROW_SQL = 'SELECT "randomnumber", "id" FROM "world" WHERE id = $1'
  11. WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'
  12. ADDITIONAL_ROW = [0, "Additional fortune added at request time."]
  13. sort_fortunes_key = itemgetter(1)
  14. logger = logging.getLogger(__name__)
  15. class PG(Connector):
  16. def set_config(self, config: ValueExtractor) -> None:
  17. cfg = config.connection
  18. dsn: URI = cfg.get_uri("dsn").with_auth(
  19. username=cfg.get("username"),
  20. password=cfg.get("password"),
  21. )
  22. super().set_config(config.new_child(dsn=dsn))
  23. class Templates(AbstractEntity):
  24. fortune: jinja2.Template
  25. def set_config(self, config):
  26. super().set_config(config)
  27. self.fortune = jinja2.Template(config.fortune)
  28. def get_num_queries(request):
  29. query_count = request.url.query.get_int("queries")
  30. if query_count is None:
  31. return 1
  32. elif query_count < 1:
  33. return 1
  34. elif query_count > 500:
  35. return 500
  36. return query_count
  37. async def single_database_query(context):
  38. row_id = randint(1, 10000)
  39. async with context.pg.pool.acquire() as connection:
  40. number = await connection.fetchval(READ_ROW_SQL, row_id)
  41. return {"id": row_id, "randomNumber": number}
  42. async def multiple_database_queries(context, request):
  43. num_queries = get_num_queries(request)
  44. row_ids = [randint(1, 10000) for _ in range(num_queries)]
  45. worlds = []
  46. async with context.pg.pool.acquire() as connection:
  47. statement = await connection.prepare(READ_ROW_SQL)
  48. for row_id in row_ids:
  49. number = await statement.fetchval(row_id)
  50. worlds.append({"id": row_id, "randomNumber": number})
  51. return worlds
  52. async def fortunes(context, request):
  53. async with context.pg.pool.acquire() as connection:
  54. fortunes = await connection.fetch("SELECT * FROM Fortune")
  55. fortunes.append(ADDITIONAL_ROW)
  56. fortunes.sort(key=sort_fortunes_key)
  57. content = context.templates.fortune.render(fortunes=fortunes)
  58. return request.response(
  59. content.encode(),
  60. headers=[
  61. ("Content-Type", "text/html; charset=utf-8"),
  62. ],
  63. )
  64. async def database_updates(context, request):
  65. num_queries = get_num_queries(request)
  66. uniq = {randint(1, 10000) for _ in range(num_queries)}
  67. while len(uniq) < num_queries:
  68. uniq.add(randint(1, 10000))
  69. updates = [
  70. (row_id, randint(1, 10000)) for row_id in uniq
  71. ]
  72. worlds = [
  73. {"id": row_id, "randomNumber": number} for row_id, number in updates
  74. ]
  75. async with context.pg.pool.acquire() as connection:
  76. statement = await connection.prepare(READ_ROW_SQL)
  77. for row_id, number in updates:
  78. await statement.fetchval(row_id)
  79. for _ in range(99):
  80. try:
  81. await connection.executemany(WRITE_ROW_SQL, updates)
  82. except asyncpg.exceptions.DeadlockDetectedError as e:
  83. logger.debug('Deadlock %s', e)
  84. else:
  85. break
  86. else:
  87. worlds.clear()
  88. return worlds