app.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import asyncpg
  2. import os
  3. from fastapi import FastAPI, Request
  4. from fastapi.responses import PlainTextResponse
  5. try:
  6. import orjson
  7. from fastapi.responses import ORJSONResponse as JSONResponse
  8. except ImportError:
  9. from fastapi.responses import UJSONResponse as JSONResponse
  10. from fastapi.templating import Jinja2Templates
  11. from random import randint, sample
  12. READ_ROW_SQL = 'SELECT "id", "randomnumber" FROM "world" WHERE id = $1'
  13. WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'
  14. ADDITIONAL_ROW = [0, "Additional fortune added at request time."]
  15. def get_num_queries(queries):
  16. try:
  17. query_count = int(queries)
  18. except (ValueError, TypeError):
  19. return 1
  20. if query_count < 1:
  21. return 1
  22. if query_count > 500:
  23. return 500
  24. return query_count
  25. connection_pool = None
  26. templates = Jinja2Templates(directory="templates")
  27. app = FastAPI()
  28. async def setup_database():
  29. return await asyncpg.create_pool(
  30. user=os.getenv("PGUSER", "benchmarkdbuser"),
  31. password=os.getenv("PGPASS", "benchmarkdbpass"),
  32. database="hello_world",
  33. host="tfb-database",
  34. port=5432,
  35. )
  36. @app.on_event("startup")
  37. async def startup_event():
  38. app.state.connection_pool = await setup_database()
  39. @app.on_event("shutdown")
  40. async def shutdown_event():
  41. await app.state.connection_pool.close()
  42. @app.get("/json")
  43. async def json_serialization():
  44. return JSONResponse({"message": "Hello, world!"})
  45. @app.get("/db")
  46. async def single_database_query():
  47. row_id = randint(1, 10000)
  48. async with app.state.connection_pool.acquire() as connection:
  49. number = await connection.fetchval(READ_ROW_SQL, row_id)
  50. return JSONResponse({"id": row_id, "randomNumber": number})
  51. @app.get("/queries")
  52. async def multiple_database_queries(queries = None):
  53. num_queries = get_num_queries(queries)
  54. row_ids = sample(range(1, 10000), num_queries)
  55. worlds = []
  56. async with app.state.connection_pool.acquire() as connection:
  57. statement = await connection.prepare(READ_ROW_SQL)
  58. for row_id in row_ids:
  59. number = await statement.fetchval(row_id)
  60. worlds.append({"id": row_id, "randomNumber": number})
  61. return JSONResponse(worlds)
  62. @app.get("/fortunes")
  63. async def fortunes(request: Request):
  64. async with app.state.connection_pool.acquire() as connection:
  65. fortunes = await connection.fetch("SELECT * FROM Fortune")
  66. fortunes.append(ADDITIONAL_ROW)
  67. fortunes.sort(key=lambda row: row[1])
  68. return templates.TemplateResponse("fortune.html", {"fortunes": fortunes, "request": request})
  69. @app.get("/updates")
  70. async def database_updates(queries = None):
  71. num_queries = get_num_queries(queries)
  72. # To avoid deadlock
  73. ids = sorted(sample(range(1, 10000 + 1), num_queries))
  74. numbers = sorted(sample(range(1, 10000), num_queries))
  75. updates = list(zip(ids, numbers))
  76. worlds = [
  77. {"id": row_id, "randomNumber": number} for row_id, number in updates
  78. ]
  79. async with app.state.connection_pool.acquire() as connection:
  80. statement = await connection.prepare(READ_ROW_SQL)
  81. for row_id, _ in updates:
  82. await statement.fetchval(row_id)
  83. await connection.executemany(WRITE_ROW_SQL, updates)
  84. return JSONResponse(worlds)
  85. @app.get("/plaintext")
  86. async def plaintext():
  87. return PlainTextResponse(b"Hello, world!")