app.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import multiprocessing
  2. from contextlib import asynccontextmanager
  3. import asyncpg
  4. import os
  5. from fastapi import FastAPI, Request
  6. from fastapi.responses import PlainTextResponse
  7. try:
  8. import orjson
  9. from fastapi.responses import ORJSONResponse as JSONResponse
  10. except ImportError:
  11. from fastapi.responses import UJSONResponse as JSONResponse
  12. from fastapi.templating import Jinja2Templates
  13. from random import randint, sample
  14. READ_ROW_SQL = 'SELECT "id", "randomnumber" FROM "world" WHERE id = $1'
  15. WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'
  16. ADDITIONAL_ROW = [0, "Additional fortune added at request time."]
  17. MAX_POOL_SIZE = 1000//multiprocessing.cpu_count()
  18. MIN_POOL_SIZE = max(int(MAX_POOL_SIZE / 2), 1)
  19. def get_num_queries(queries):
  20. try:
  21. query_count = int(queries)
  22. except (ValueError, TypeError):
  23. return 1
  24. if query_count < 1:
  25. return 1
  26. if query_count > 500:
  27. return 500
  28. return query_count
  29. connection_pool = None
  30. templates = Jinja2Templates(directory="templates")
  31. async def setup_database():
  32. return await asyncpg.create_pool(
  33. user=os.getenv("PGUSER", "benchmarkdbuser"),
  34. password=os.getenv("PGPASS", "benchmarkdbpass"),
  35. database="hello_world",
  36. host="tfb-database",
  37. port=5432,
  38. min_size=MIN_POOL_SIZE,
  39. max_size=MAX_POOL_SIZE,
  40. )
  41. @asynccontextmanager
  42. async def lifespan(app: FastAPI):
  43. # Setup the database connection pool
  44. app.state.connection_pool = await setup_database()
  45. yield
  46. # Close the database connection pool
  47. await app.state.connection_pool.close()
  48. app = FastAPI(lifespan=lifespan)
  49. @app.get("/json")
  50. async def json_serialization():
  51. return JSONResponse({"message": "Hello, world!"})
  52. @app.get("/db")
  53. async def single_database_query():
  54. row_id = randint(1, 10000)
  55. async with app.state.connection_pool.acquire() as connection:
  56. number = await connection.fetchval(READ_ROW_SQL, row_id)
  57. return JSONResponse({"id": row_id, "randomNumber": number})
  58. @app.get("/queries")
  59. async def multiple_database_queries(queries = None):
  60. num_queries = get_num_queries(queries)
  61. row_ids = sample(range(1, 10000), num_queries)
  62. worlds = []
  63. async with app.state.connection_pool.acquire() as connection:
  64. statement = await connection.prepare(READ_ROW_SQL)
  65. for row_id in row_ids:
  66. number = await statement.fetchval(row_id)
  67. worlds.append({"id": row_id, "randomNumber": number})
  68. return JSONResponse(worlds)
  69. @app.get("/fortunes")
  70. async def fortunes(request: Request):
  71. async with app.state.connection_pool.acquire() as connection:
  72. fortunes = await connection.fetch("SELECT * FROM Fortune")
  73. fortunes.append(ADDITIONAL_ROW)
  74. fortunes.sort(key=lambda row: row[1])
  75. return templates.TemplateResponse("fortune.html", {"fortunes": fortunes, "request": request})
  76. @app.get("/updates")
  77. async def database_updates(queries = None):
  78. num_queries = get_num_queries(queries)
  79. # To avoid deadlock
  80. ids = sorted(sample(range(1, 10000 + 1), num_queries))
  81. numbers = sorted(sample(range(1, 10000), num_queries))
  82. updates = list(zip(ids, numbers))
  83. worlds = [
  84. {"id": row_id, "randomNumber": number} for row_id, number in updates
  85. ]
  86. async with app.state.connection_pool.acquire() as connection:
  87. statement = await connection.prepare(READ_ROW_SQL)
  88. for row_id, _ in updates:
  89. await statement.fetchval(row_id)
  90. await connection.executemany(WRITE_ROW_SQL, updates)
  91. return JSONResponse(worlds)
  92. @app.get("/plaintext")
  93. async def plaintext():
  94. return PlainTextResponse(b"Hello, world!")