app.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  import asyncio
  import os
  from operator import itemgetter
  from random import randint, sample
  from urllib.parse import parse_qs

  import asyncpg
  import jinja2
  from fastapi import FastAPI
  from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse
  # Returns the row's "randomnumber" (first column) and "id"; $1 = row id.
  READ_ROW_SQL = 'SELECT "randomnumber", "id" FROM "world" WHERE id = $1'
  # $1 = new "randomnumber" value, $2 = id of the row to update.
  WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'
  # Extra (id, message) fortune appended to every /fortunes response.
  # A tuple rather than a list: the same object is shared by all requests,
  # so it must not be mutable.
  ADDITIONAL_ROW = (0, 'Additional fortune added at request time.')
  # https://www.starlette.io/responses/#custom-json-serialization
  try:
      import orjson
  except ImportError:
      class CustomJSONResponse(JSONResponse):
          """orjson is not installed: keep Starlette's default JSON rendering."""
  else:
      class CustomJSONResponse(JSONResponse):
          """JSONResponse whose body is serialized by orjson."""

          def render(self, content):
              # orjson.dumps already produces bytes, the type render returns.
              return orjson.dumps(content)
  async def setup_database():
      """Create the asyncpg connection pool and store it in ``connection_pool``.

      User name and password are read from the PGUSER / PGPASS environment
      variables, with the benchmark defaults as fallback. Database name, host
      and port are fixed.
      """
      global connection_pool
      pool_options = {
          'user': os.getenv('PGUSER', 'benchmarkdbuser'),
          'password': os.getenv('PGPASS', 'benchmarkdbpass'),
          'database': 'hello_world',
          'host': 'tfb-database',
          'port': 5432,
      }
      connection_pool = await asyncpg.create_pool(**pool_options)
  def load_fortunes_template():
      """Read ``templates/fortune.html`` and compile it into a Jinja2 template.

      The path is relative to the current working directory.

      Returns:
          A ``jinja2.Template`` with HTML autoescaping enabled.
      """
      path = os.path.join('templates', 'fortune.html')
      # Explicit encoding: the platform default encoding is not guaranteed
      # to be UTF-8.
      with open(path, 'r', encoding='utf-8') as template_file:
          template_text = template_file.read()
      # The /fortunes route renders rows fetched from the database through
      # this template, so escape by default. Explicit ``|e`` filters remain
      # correct: escape() returns Markup, which autoescaping leaves alone.
      return jinja2.Template(template_text, autoescape=True)
  def get_num_queries(queries):
      """Convert the raw ``queries`` parameter to an int clamped to [1, 500].

      Any value ``int()`` rejects with ValueError or TypeError (for example
      ``None`` or a non-numeric string) yields 1.
      """
      try:
          requested = int(queries)
      except (ValueError, TypeError):
          return 1
      return max(1, min(requested, 500))
  # Set by setup_database() when the application starts; None until then.
  connection_pool = None
  # Sort key for fortune rows: their second column (the message text).
  sort_fortunes_key = itemgetter(1)
  template = load_fortunes_template()
  app = FastAPI()


  @app.on_event('startup')
  async def _open_connection_pool():
      """Create the asyncpg pool on the event loop that serves requests.

      Creating the pool at import time with asyncio.get_event_loop() ties it
      to whatever loop exists at import. That call is deprecated when no loop
      is running. If the ASGI server later starts its own loop (for example
      via asyncio.run), requests would use the pool from a foreign loop.
      """
      await setup_database()


  @app.on_event('shutdown')
  async def _close_connection_pool():
      """Close the pool, if it was created, so its connections are released."""
      if connection_pool is not None:
          await connection_pool.close()
  @app.get('/json')
  async def json_serialization():
      """Return the constant greeting object as JSON."""
      greeting = {'message': 'Hello, world!'}
      return CustomJSONResponse(greeting)
  @app.get('/db')
  async def single_database_query():
      """Look up one random world row and return its id and random number."""
      world_id = randint(1, 10000)
      async with connection_pool.acquire() as conn:
          # fetchval yields the first selected column, "randomnumber".
          random_number = await conn.fetchval(READ_ROW_SQL, world_id)
      payload = {'id': world_id, 'randomNumber': random_number}
      return CustomJSONResponse(payload)
  @app.get('/queries')
  async def multiple_database_queries(queries=None):
      """Look up N random world rows, N clamped by get_num_queries.

      All lookups go through one prepared statement on a single pooled
      connection, one after another. The rows are returned as a JSON list.
      """
      count = get_num_queries(queries)
      world_ids = [randint(1, 10000) for _ in range(count)]
      async with connection_pool.acquire() as conn:
          select_world = await conn.prepare(READ_ROW_SQL)
          results = [
              {'id': world_id, 'randomNumber': await select_world.fetchval(world_id)}
              for world_id in world_ids
          ]
      return CustomJSONResponse(results)
  @app.get('/fortunes')
  async def fortunes():
      """Render all Fortune rows plus ADDITIONAL_ROW, sorted by message, as HTML."""
      async with connection_pool.acquire() as connection:
          # Name the columns explicitly. sort_fortunes_key (itemgetter(1)) and
          # ADDITIONAL_ROW are positional, so the row layout must be
          # (id, message) regardless of the table's column order.
          rows = await connection.fetch('SELECT id, message FROM Fortune')
      rows.append(ADDITIONAL_ROW)
      rows.sort(key=sort_fortunes_key)
      content = template.render(fortunes=rows)
      return HTMLResponse(content)
  @app.get('/updates')
  async def database_updates(queries=None):
      """Read and then overwrite N random world rows (N clamped by get_num_queries).

      Each chosen row is first read with READ_ROW_SQL. All rows are then
      updated in one executemany batch. The response lists the new
      (id, randomNumber) pairs as JSON, ordered by id.
      """
      num_queries = get_num_queries(queries)
      # Use distinct ids in ascending order:
      # * A repeated id would be updated twice in the batch. The later write
      #   wins, so the response would disagree with what is stored.
      # * Locking rows in a consistent order keeps concurrent /updates
      #   requests from deadlocking. This matters if executemany runs the
      #   batch atomically, as asyncpg documents since 0.22 (confirm the
      #   installed version).
      row_ids = sorted(sample(range(1, 10001), num_queries))
      updates = [(row_id, randint(1, 10000)) for row_id in row_ids]
      worlds = [{'id': row_id, 'randomNumber': number} for row_id, number in updates]
      async with connection_pool.acquire() as connection:
          statement = await connection.prepare(READ_ROW_SQL)
          for row_id, _ in updates:
              await statement.fetchval(row_id)
          # WRITE_ROW_SQL binds $1 = new randomnumber and $2 = id, so each
          # argument tuple must be (number, row_id), not (row_id, number).
          await connection.executemany(
              WRITE_ROW_SQL, [(number, row_id) for row_id, number in updates]
          )
      return CustomJSONResponse(worlds)
  @app.get('/plaintext')
  async def plaintext():
      """Return the fixed plain-text greeting."""
      body = b'Hello, world!'
      return PlainTextResponse(body)