main.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import multiprocessing
  2. import os
  3. import pathlib
  4. from operator import itemgetter
  5. from random import randint, sample
  6. from typing import Annotated, AsyncIterable, Optional
  7. import asyncpg # type: ignore
  8. import jinja2 # type: ignore
  9. import uvicorn # type: ignore
  10. from pydantic import BaseModel, Field
  11. from starlette.responses import HTMLResponse, PlainTextResponse
  12. from xpresso import App, Depends, Path, Response, FromQuery
  13. READ_ROW_SQL = 'SELECT "randomnumber", "id" FROM "world" WHERE id = $1'
  14. WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'
  15. ADDITIONAL_ROW = (0, 'Additional fortune added at request time.')
  16. sort_fortunes_key = itemgetter(1)
  17. app_dir = pathlib.Path(__file__).parent
  18. with (app_dir / "templates" / "fortune.html").open() as template_file:
  19. template = jinja2.Template(template_file.read())
  20. async def get_db_pool() -> AsyncIterable[asyncpg.Pool]:
  21. async with asyncpg.create_pool( # type: ignore
  22. user=os.getenv('PGUSER', 'benchmarkdbuser'),
  23. password=os.getenv('PGPASS', 'benchmarkdbpass'),
  24. database=os.getenv('PGDB', 'hello_world'),
  25. host=os.getenv('PGHOST', 'tfb-database'),
  26. port=5432,
  27. ) as pool:
  28. yield pool
  29. DBPool = Annotated[asyncpg.Pool, Depends(get_db_pool, scope="app")]
  30. def get_num_queries(queries: Optional[str]) -> int:
  31. if not queries:
  32. return 1
  33. try:
  34. queries_num = int(queries)
  35. except (ValueError, TypeError):
  36. return 1
  37. if queries_num < 1:
  38. return 1
  39. if queries_num > 500:
  40. return 500
  41. return queries_num
  42. class Greeting(BaseModel):
  43. message: str
  44. def json_serialization() -> Greeting:
  45. return Greeting(message="Hello, world!")
  46. def plaintext() -> Response:
  47. return PlainTextResponse(b"Hello, world!")
  48. class QueryResult(BaseModel):
  49. id: int
  50. randomNumber: int
  51. async def single_database_query(pool: DBPool) -> QueryResult:
  52. row_id = randint(1, 10000)
  53. connection: "asyncpg.Connection"
  54. async with pool.acquire() as connection: # type: ignore
  55. number: int = await connection.fetchval(READ_ROW_SQL, row_id) # type: ignore
  56. return QueryResult.construct(id=row_id, randomNumber=number)
  57. QueryCount = Annotated[str, Field(gt=0, le=500)]
  58. async def multiple_database_queries(
  59. pool: DBPool,
  60. queries: FromQuery[str | None] = None,
  61. ) -> list[QueryResult]:
  62. num_queries = get_num_queries(queries)
  63. row_ids = sample(range(1, 10000), num_queries)
  64. connection: "asyncpg.Connection"
  65. async with pool.acquire() as connection: # type: ignore
  66. statement = await connection.prepare(READ_ROW_SQL) # type: ignore
  67. return [
  68. QueryResult.construct(
  69. id=row_id,
  70. randomNumber=await statement.fetchval(row_id), # type: ignore
  71. )
  72. for row_id in row_ids
  73. ]
  74. async def fortunes(pool: DBPool) -> Response:
  75. connection: "asyncpg.Connection"
  76. async with pool.acquire() as connection: # type: ignore
  77. fortunes: "list[tuple[int, str]]" = await connection.fetch("SELECT * FROM Fortune") # type: ignore
  78. fortunes.append(ADDITIONAL_ROW)
  79. fortunes.sort(key=sort_fortunes_key)
  80. content = template.render(fortunes=fortunes) # type: ignore
  81. return HTMLResponse(content)
  82. async def database_updates(
  83. pool: DBPool,
  84. queries: FromQuery[str | None] = None,
  85. ) -> list[QueryResult]:
  86. num_queries = get_num_queries(queries)
  87. updates = [(row_id, randint(1, 10000)) for row_id in sample(range(1, 10000), num_queries)]
  88. async with pool.acquire() as connection:
  89. statement = await connection.prepare(READ_ROW_SQL)
  90. for row_id, _ in updates:
  91. await statement.fetchval(row_id)
  92. await connection.executemany(WRITE_ROW_SQL, updates) # type: ignore
  93. return [QueryResult.construct(id=row_id, randomNumber=number) for row_id, number in updates]
  94. routes = (
  95. Path("/json", get=json_serialization),
  96. Path("/plaintext", get=plaintext),
  97. Path("/db", get=single_database_query),
  98. Path("/queries", get=multiple_database_queries),
  99. Path("/fortunes", get=fortunes),
  100. Path("/updates", get=database_updates),
  101. )
  102. app = App(routes=routes)
  103. if __name__ == "__main__":
  104. workers = multiprocessing.cpu_count()
  105. if os.environ.get("TRAVIS") == "true":
  106. workers = 2
  107. uvicorn.run( # type: ignore
  108. "main:app",
  109. host="0.0.0.0",
  110. port=8080,
  111. workers=workers,
  112. log_level="error",
  113. loop="uvloop",
  114. http="httptools",
  115. )