123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- import multiprocessing
- import os
- import pathlib
- from operator import itemgetter
- from random import randint, sample
- from typing import Annotated, AsyncIterable, Optional
- import asyncpg # type: ignore
- import jinja2 # type: ignore
- import uvicorn # type: ignore
- from pydantic import BaseModel, Field
- from starlette.responses import HTMLResponse, PlainTextResponse
- from xpresso import App, Depends, Path, Response, FromQuery
# Parameterized statements for the "world" table.
# READ_ROW_SQL lists "randomnumber" first, so a single-value fetch yields it.
READ_ROW_SQL = 'SELECT "randomnumber", "id" FROM "world" WHERE id = $1'
# WRITE_ROW_SQL binds $1 = new randomnumber and $2 = row id (in that order).
WRITE_ROW_SQL = 'UPDATE "world" SET "randomnumber"=$1 WHERE id=$2'

# Extra (id, text) row merged into the fortunes list on every request.
ADDITIONAL_ROW = (0, "Additional fortune added at request time.")

# Sort key for fortune rows: the element at index 1.
sort_fortunes_key = itemgetter(1)
# Directory containing this module; the fortunes template lives beneath it.
app_dir = pathlib.Path(__file__).parent
# Decode explicitly: open()'s default text encoding is locale-dependent, so the
# same template file could otherwise decode differently across environments.
# NOTE(review): assumes templates/fortune.html is saved as UTF-8 — confirm.
with (app_dir / "templates" / "fortune.html").open(encoding="utf-8") as template_file:
    template = jinja2.Template(template_file.read())
async def get_db_pool() -> AsyncIterable[asyncpg.Pool]:
    """Yield an asyncpg connection pool, closing it when the dependency ends.

    Connection settings are read from the PGUSER, PGPASS, PGDB, PGHOST and
    PGPORT environment variables, each falling back to the benchmark default.
    """
    async with asyncpg.create_pool(  # type: ignore
        user=os.getenv("PGUSER", "benchmarkdbuser"),
        password=os.getenv("PGPASS", "benchmarkdbpass"),
        database=os.getenv("PGDB", "hello_world"),
        host=os.getenv("PGHOST", "tfb-database"),
        # Previously hard-coded to 5432; now overridable like the other
        # settings, with the same default.
        port=int(os.getenv("PGPORT", "5432")),
    ) as pool:
        yield pool


# Handler parameter type that injects the pool. scope="app" ties the
# dependency to the application rather than a single request (xpresso), so
# presumably one pool is shared by all requests — confirm against xpresso docs.
DBPool = Annotated[asyncpg.Pool, Depends(get_db_pool, scope="app")]
def get_num_queries(queries: Optional[str]) -> int:
    """Turn the raw ``queries`` parameter into a row count between 1 and 500.

    Missing, empty, or non-integer input counts as 1; out-of-range integers
    are clamped to the nearest bound.
    """
    if queries:
        try:
            requested = int(queries)
        except (ValueError, TypeError):
            requested = 1
    else:
        requested = 1
    # Clamp into the inclusive range [1, 500].
    return min(max(requested, 1), 500)
class Greeting(BaseModel):
    """Response model for the ``/json`` endpoint."""

    # The greeting text placed in the JSON body.
    message: str
def json_serialization() -> Greeting:
    """Handle ``GET /json`` by returning a fixed greeting model."""
    greeting = Greeting(message="Hello, world!")
    return greeting
def plaintext() -> Response:
    """Handle ``GET /plaintext`` with a constant plain-text body."""
    body = b"Hello, world!"
    return PlainTextResponse(body)
class QueryResult(BaseModel):
    """One ``world`` row as returned by the database endpoints."""

    id: int
    # camelCase is intentional: the field name becomes the JSON key.
    randomNumber: int
async def single_database_query(pool: DBPool) -> QueryResult:
    """Handle ``GET /db``: look up one random ``world`` row.

    READ_ROW_SQL selects "randomnumber" as its first column, so ``fetchval``
    (which returns the first column by default) yields that value.
    """
    target_id = randint(1, 10000)
    connection: "asyncpg.Connection"
    async with pool.acquire() as connection:  # type: ignore
        random_number: int = await connection.fetchval(READ_ROW_SQL, target_id)  # type: ignore
    return QueryResult.construct(id=target_id, randomNumber=random_number)
# NOTE(review): QueryCount is not referenced by any handler shown here; the
# /queries and /updates handlers take ``queries`` as ``str | None`` and parse
# it leniently with get_num_queries instead. The gt/le constraints are numeric
# bounds attached to a ``str`` type — confirm whether this alias is dead code
# (or was meant to be ``int``) before relying on it.
QueryCount = Annotated[str, Field(gt=0, le=500)]
async def multiple_database_queries(
    pool: DBPool,
    queries: FromQuery[str | None] = None,
) -> list[QueryResult]:
    """Handle ``GET /queries``: read several distinct random ``world`` rows.

    ``queries`` is parsed by ``get_num_queries`` (non-numeric -> 1, clamped
    to 1..500). Row ids are drawn without replacement from 1..10000
    inclusive, the same range ``randint(1, 10000)`` uses for ``/db``.
    """
    num_queries = get_num_queries(queries)
    # range() excludes its stop value; 10001 keeps id 10000 selectable
    # (previously range(1, 10000) could never pick it).
    row_ids = sample(range(1, 10001), num_queries)
    connection: "asyncpg.Connection"
    async with pool.acquire() as connection:  # type: ignore
        # Prepare once, then reuse the statement for every lookup.
        statement = await connection.prepare(READ_ROW_SQL)  # type: ignore
        return [
            QueryResult.construct(
                id=row_id,
                randomNumber=await statement.fetchval(row_id),  # type: ignore
            )
            for row_id in row_ids
        ]
async def fortunes(pool: DBPool) -> Response:
    """Handle ``GET /fortunes``: render every fortune plus one extra as HTML.

    The fetched rows and ``ADDITIONAL_ROW`` are ordered with
    ``sort_fortunes_key`` (index 1, presumably the message text) before the
    template renders them.
    """
    connection: "asyncpg.Connection"
    async with pool.acquire() as connection:  # type: ignore
        rows = await connection.fetch("SELECT * FROM Fortune")  # type: ignore
    # Stable sort; the extra row is placed after the DB rows before sorting.
    ordered = sorted([*rows, ADDITIONAL_ROW], key=sort_fortunes_key)
    html = template.render(fortunes=ordered)  # type: ignore
    return HTMLResponse(html)
async def database_updates(
    pool: DBPool,
    queries: FromQuery[str | None] = None,
) -> list[QueryResult]:
    """Handle ``GET /updates``: read random ``world`` rows, then rewrite them.

    ``queries`` is parsed by ``get_num_queries`` (clamped to 1..500). Each
    selected row is read, then assigned a new random number in 1..10000, and
    the new (id, randomNumber) pairs are returned.
    """
    num_queries = get_num_queries(queries)
    # range() excludes its stop value; 10001 keeps id 10000 selectable.
    # Sorting gives every request the same row-lock order: if executemany runs
    # atomically (asyncpg >= 0.22 — confirm the pinned version), concurrent
    # requests locking rows in different orders could otherwise deadlock.
    row_ids = sorted(sample(range(1, 10001), num_queries))
    updates = [(row_id, randint(1, 10000)) for row_id in row_ids]
    connection: "asyncpg.Connection"
    async with pool.acquire() as connection:  # type: ignore
        statement = await connection.prepare(READ_ROW_SQL)  # type: ignore
        for row_id, _ in updates:
            # The read value is discarded; presumably the row must be read
            # before it is written — confirm against the benchmark rules.
            await statement.fetchval(row_id)  # type: ignore
        # WRITE_ROW_SQL binds $1 = randomnumber and $2 = id, so the tuples must
        # be (number, row_id). Passing (row_id, number) wrote the wrong rows.
        await connection.executemany(  # type: ignore
            WRITE_ROW_SQL,
            [(number, row_id) for row_id, number in updates],
        )
    return [QueryResult.construct(id=row_id, randomNumber=number) for row_id, number in updates]
# URL table: each path is served by the GET handler defined above.
routes = (
    # Endpoints that never touch the database.
    Path("/json", get=json_serialization),
    Path("/plaintext", get=plaintext),
    # Endpoints backed by the pooled database connection.
    Path("/db", get=single_database_query),
    Path("/queries", get=multiple_database_queries),
    Path("/fortunes", get=fortunes),
    Path("/updates", get=database_updates),
)

# ASGI application object; the launcher below serves it as "main:app".
app = App(routes=routes)
if __name__ == "__main__":
    # One worker process per CPU, except on Travis CI, which is capped at two.
    on_travis = os.environ.get("TRAVIS") == "true"
    workers = 2 if on_travis else multiprocessing.cpu_count()
    uvicorn.run(  # type: ignore
        "main:app",
        host="0.0.0.0",
        port=8080,
        workers=workers,
        log_level="error",
        loop="uvloop",
        http="httptools",
    )
|