app.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import asyncio
  2. from typing import Optional, Any
  3. import asyncpg
  4. import os
  5. import jinja2
  6. from asyncpg import Pool
  7. from starlite import Starlite, get, MediaType
  8. from random import randint
  9. from operator import itemgetter
  10. READ_ROW_SQL = 'SELECT "randomnumber", "id" FROM "world" WHERE id = $1'
  11. WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'
  12. ADDITIONAL_ROW = [0, 'Additional fortune added at request time.']
  13. connection_pool: Pool
  14. async def setup_database():
  15. global connection_pool
  16. connection_pool = await asyncpg.create_pool(
  17. user=os.getenv('PGUSER', 'benchmarkdbuser'),
  18. password=os.getenv('PGPASS', 'benchmarkdbpass'),
  19. database='hello_world',
  20. host='tfb-database',
  21. port=5432
  22. )
  23. def load_fortunes_template():
  24. path = os.path.join('templates', 'fortune.html')
  25. with open(path, 'r') as template_file:
  26. template_text = template_file.read()
  27. return jinja2.Template(template_text)
  28. def get_num_queries(queries: Any):
  29. if queries:
  30. try:
  31. query_count = int(queries)
  32. except (ValueError, TypeError):
  33. return 1
  34. if query_count < 1:
  35. return 1
  36. if query_count > 500:
  37. return 500
  38. return query_count
  39. return 1
  40. sort_fortunes_key = itemgetter(1)
  41. template = load_fortunes_template()
  42. @get(path='/json')
  43. async def json_serialization() -> dict[str, str]:
  44. return {'message': 'Hello, world!'}
  45. @get(path='/db')
  46. async def single_database_query() -> dict[str, int]:
  47. row_id = randint(1, 10000)
  48. async with connection_pool.acquire() as connection:
  49. number = await connection.fetchval(READ_ROW_SQL, row_id)
  50. return {'id': row_id, 'randomNumber': number}
  51. @get(path='/queries')
  52. async def multiple_database_queries(queries: Any = None) -> list[dict[str, int]]:
  53. num_queries = get_num_queries(queries)
  54. row_ids = [randint(1, 10000) for _ in range(num_queries)]
  55. worlds = []
  56. async with connection_pool.acquire() as connection:
  57. statement = await connection.prepare(READ_ROW_SQL)
  58. for row_id in row_ids:
  59. number = await statement.fetchval(row_id)
  60. worlds.append({'id': row_id, 'randomNumber': number})
  61. return worlds
  62. @get(path='/fortunes', media_type=MediaType.HTML)
  63. async def render_fortunes_template() -> str:
  64. async with connection_pool.acquire() as connection:
  65. fortunes = await connection.fetch('SELECT * FROM Fortune')
  66. fortunes.append(ADDITIONAL_ROW)
  67. fortunes.sort(key=sort_fortunes_key)
  68. return template.render(fortunes=fortunes)
  69. @get(path='/updates')
  70. async def database_updates(queries: Any = None) -> list[dict[str, int]]:
  71. num_queries = get_num_queries(queries)
  72. updates = [(randint(1, 10000), randint(1, 10000)) for _ in range(num_queries)]
  73. worlds = [{'id': row_id, 'randomNumber': number} for row_id, number in updates]
  74. async with connection_pool.acquire() as connection:
  75. statement = await connection.prepare(READ_ROW_SQL)
  76. for row_id, number in updates:
  77. await statement.fetchval(row_id)
  78. await connection.executemany(WRITE_ROW_SQL, updates)
  79. return worlds
  80. @get(path='/plaintext', media_type=MediaType.TEXT)
  81. async def plaintext() -> bytes:
  82. return b'Hello, world!'
  83. app = Starlite(
  84. route_handlers=[
  85. json_serialization,
  86. single_database_query,
  87. multiple_database_queries,
  88. render_fortunes_template,
  89. database_updates,
  90. plaintext
  91. ],
  92. on_startup=[setup_database]
  93. )