app.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import multiprocessing
  2. import os
  3. from contextlib import asynccontextmanager
  4. from pathlib import Path
  5. from random import randint, sample
  6. from typing import Any
  7. import asyncpg
  8. import orjson
  9. from litestar import Litestar, MediaType, Request, get, Response
  10. from litestar.contrib.jinja import JinjaTemplateEngine
  11. from litestar.response import Template
  12. from litestar.template import TemplateConfig
  13. READ_ROW_SQL = 'SELECT "id", "randomnumber" FROM "world" WHERE id = $1'
  14. WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'
  15. ADDITIONAL_ROW = [0, "Additional fortune added at request time."]
  16. MAX_POOL_SIZE = 1000 // multiprocessing.cpu_count()
  17. MIN_POOL_SIZE = max(int(MAX_POOL_SIZE / 2), 1)
  18. def get_num_queries(queries):
  19. try:
  20. query_count = int(queries)
  21. except (ValueError, TypeError):
  22. return 1
  23. if query_count < 1:
  24. return 1
  25. if query_count > 500:
  26. return 500
  27. return query_count
  28. connection_pool = None
  29. async def setup_database():
  30. return await asyncpg.create_pool(
  31. user=os.getenv("PGUSER", "benchmarkdbuser"),
  32. password=os.getenv("PGPASS", "benchmarkdbpass"),
  33. database="hello_world",
  34. host="tfb-database",
  35. port=5432,
  36. min_size=MIN_POOL_SIZE,
  37. max_size=MAX_POOL_SIZE,
  38. )
  39. @asynccontextmanager
  40. async def lifespan(app: Litestar):
  41. # Set up the database connection pool
  42. app.state.connection_pool = await setup_database()
  43. yield
  44. # Close the database connection pool
  45. await app.state.connection_pool.close()
  46. @get("/json")
  47. async def json_serialization() -> Response:
  48. return Response(
  49. content=orjson.dumps({"message": "Hello, world!"}),
  50. media_type=MediaType.JSON,
  51. )
  52. @get("/db")
  53. async def single_database_query() -> Response:
  54. row_id = randint(1, 10000)
  55. async with app.state.connection_pool.acquire() as connection:
  56. number = await connection.fetchval(READ_ROW_SQL, row_id)
  57. return Response(
  58. content=orjson.dumps({"id": row_id, "randomNumber": number}),
  59. media_type=MediaType.JSON,
  60. )
  61. @get("/queries")
  62. async def multiple_database_queries(queries: Any = None) -> Response:
  63. num_queries = get_num_queries(queries)
  64. row_ids = sample(range(1, 10000), num_queries)
  65. worlds = []
  66. async with app.state.connection_pool.acquire() as connection:
  67. statement = await connection.prepare(READ_ROW_SQL)
  68. for row_id in row_ids:
  69. number = await statement.fetchval(row_id)
  70. worlds.append({"id": row_id, "randomNumber": number})
  71. return Response(
  72. content=orjson.dumps(worlds),
  73. media_type=MediaType.JSON,
  74. )
  75. @get("/fortunes", media_type=MediaType.HTML)
  76. async def fortunes(request: Request) -> Template:
  77. async with app.state.connection_pool.acquire() as connection:
  78. fortunes = await connection.fetch("SELECT * FROM Fortune")
  79. fortunes.append(ADDITIONAL_ROW)
  80. fortunes.sort(key=lambda row: row[1])
  81. return Template(
  82. "fortune.html",
  83. context={"fortunes": fortunes, "request": request},
  84. media_type=MediaType.HTML,
  85. )
  86. @get("/updates")
  87. async def database_updates(queries: Any = None) -> bytes:
  88. num_queries = get_num_queries(queries)
  89. # To avoid deadlock
  90. ids = sorted(sample(range(1, 10000 + 1), num_queries))
  91. numbers = sorted(sample(range(1, 10000), num_queries))
  92. updates = list(zip(ids, numbers, strict=False))
  93. worlds = [{"id": row_id, "randomNumber": number} for row_id, number in updates]
  94. async with app.state.connection_pool.acquire() as connection:
  95. statement = await connection.prepare(READ_ROW_SQL)
  96. for row_id, _ in updates:
  97. await statement.fetchval(row_id)
  98. await connection.executemany(WRITE_ROW_SQL, updates)
  99. return Response(
  100. content=orjson.dumps(worlds),
  101. media_type=MediaType.JSON,
  102. )
  103. @get("/plaintext", media_type=MediaType.TEXT)
  104. async def plaintext() -> bytes:
  105. return b"Hello, world!"
  106. app = Litestar(
  107. lifespan=[lifespan],
  108. template_config=TemplateConfig(
  109. directory=Path("templates"),
  110. engine=JinjaTemplateEngine,
  111. ),
  112. route_handlers=[
  113. json_serialization,
  114. single_database_query,
  115. multiple_database_queries,
  116. fortunes,
  117. database_updates,
  118. plaintext,
  119. ],
  120. )