app_orm.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import logging
  2. import multiprocessing
  3. import os
  4. from contextlib import asynccontextmanager
  5. from operator import attrgetter
  6. from random import randint, sample
  7. from sqlalchemy import Column, Integer, String, select
  8. from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
  9. from sqlalchemy.ext.declarative import declarative_base
  10. from sqlalchemy.orm import sessionmaker
  11. from sqlalchemy.orm.attributes import flag_modified
  12. from fastapi import FastAPI, Request
  13. from fastapi.responses import PlainTextResponse, UJSONResponse
  14. from fastapi.templating import Jinja2Templates
  15. logger = logging.getLogger(__name__)
  16. Base = declarative_base()
  17. class World(Base):
  18. __tablename__ = "world"
  19. id = Column(Integer, primary_key=True)
  20. randomnumber = Column(Integer)
  21. def __json__(self):
  22. return {"id": self.id, "randomnumber": self.randomnumber}
  23. sa_data = World.__table__
  24. class Fortune(Base):
  25. __tablename__ = "fortune"
  26. id = Column(Integer, primary_key=True)
  27. message = Column(String)
  28. sa_fortunes = Fortune.__table__
  29. ADDITIONAL_FORTUNE = Fortune(
  30. id=0, message="Additional fortune added at request time."
  31. )
  32. MAX_POOL_SIZE = 1000//multiprocessing.cpu_count()
  33. sort_fortunes_key = attrgetter("message")
  34. template_path = os.path.join(
  35. os.path.dirname(os.path.realpath(__file__)), "templates"
  36. )
  37. templates = Jinja2Templates(directory=template_path)
  38. async def setup_database():
  39. dsn = "postgresql+asyncpg://%s:%s@tfb-database:5432/hello_world" % (
  40. os.getenv("PGPASS", "benchmarkdbuser"),
  41. os.getenv("PGPASS", "benchmarkdbpass"),
  42. )
  43. engine = create_async_engine(
  44. dsn,
  45. future=True,
  46. pool_size=MAX_POOL_SIZE,
  47. connect_args={
  48. "ssl": False # NEEDED FOR NGINX-UNIT OTHERWISE IT FAILS
  49. },
  50. )
  51. return sessionmaker(engine, class_=AsyncSession)
  52. @asynccontextmanager
  53. async def lifespan(app: FastAPI):
  54. # Setup the database connection pool
  55. app.state.db_session = await setup_database()
  56. yield
  57. # Close the database connection pool
  58. await app.state.db_session.close()
  59. app = FastAPI(lifespan=lifespan)
  60. def get_num_queries(queries):
  61. try:
  62. query_count = int(queries)
  63. except (ValueError, TypeError):
  64. return 1
  65. if query_count < 1:
  66. return 1
  67. if query_count > 500:
  68. return 500
  69. return query_count
  70. @app.get("/json")
  71. async def json_serialization():
  72. return UJSONResponse({"message": "Hello, world!"})
  73. @app.get("/db")
  74. async def single_database_query():
  75. id_ = randint(1, 10000)
  76. async with app.state.db_session() as sess:
  77. result = await sess.get(World, id_)
  78. return UJSONResponse(result.__json__())
  79. @app.get("/queries")
  80. async def multiple_database_queries(queries=None):
  81. num_queries = get_num_queries(queries)
  82. data = []
  83. async with app.state.db_session() as sess:
  84. for id_ in sample(range(1, 10001), num_queries):
  85. result = await sess.get(World, id_)
  86. data.append(result.__json__())
  87. return UJSONResponse(data)
  88. @app.get("/fortunes")
  89. async def fortunes(request: Request):
  90. async with app.state.db_session() as sess:
  91. ret = await sess.execute(select(Fortune.id, Fortune.message))
  92. data = ret.all()
  93. data.append(ADDITIONAL_FORTUNE)
  94. data.sort(key=sort_fortunes_key)
  95. return templates.TemplateResponse(
  96. "fortune.jinja", {"request": request, "fortunes": data}
  97. )
  98. @app.get("/updates")
  99. async def database_updates(queries=None):
  100. num_queries = get_num_queries(queries)
  101. ids = sorted(sample(range(1, 10000 + 1), num_queries))
  102. data = []
  103. async with app.state.db_session.begin() as sess:
  104. for id_ in ids:
  105. world = await sess.get(World, id_, populate_existing=True)
  106. world.randomnumber = randint(1, 10000)
  107. # force sqlalchemy to UPDATE entry even if the value has not changed
  108. # doesn't make sense in a real application, added only for pass `tfb verify`
  109. flag_modified(world, "randomnumber")
  110. data.append(world.__json__())
  111. return UJSONResponse(data)
  112. @app.get("/plaintext")
  113. async def plaintext():
  114. return PlainTextResponse(b"Hello, world!")