app_orm.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import logging
  2. import multiprocessing
  3. import os
  4. from operator import attrgetter
  5. from random import randint, sample
  6. from sqlalchemy import Column, Integer, String, select
  7. from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
  8. from sqlalchemy.ext.declarative import declarative_base
  9. from sqlalchemy.orm import sessionmaker
  10. from sqlalchemy.orm.attributes import flag_modified
  11. from fastapi import FastAPI, Request
  12. from fastapi.responses import PlainTextResponse, UJSONResponse
  13. from fastapi.templating import Jinja2Templates
  14. logger = logging.getLogger(__name__)
  15. Base = declarative_base()
  16. class World(Base):
  17. __tablename__ = "world"
  18. id = Column(Integer, primary_key=True)
  19. randomnumber = Column(Integer)
  20. def __json__(self):
  21. return {"id": self.id, "randomnumber": self.randomnumber}
  22. sa_data = World.__table__
  23. class Fortune(Base):
  24. __tablename__ = "fortune"
  25. id = Column(Integer, primary_key=True)
  26. message = Column(String)
  27. sa_fortunes = Fortune.__table__
  28. ADDITIONAL_FORTUNE = Fortune(
  29. id=0, message="Additional fortune added at request time."
  30. )
  31. sort_fortunes_key = attrgetter("message")
  32. template_path = os.path.join(
  33. os.path.dirname(os.path.realpath(__file__)), "templates"
  34. )
  35. templates = Jinja2Templates(directory=template_path)
  36. app = FastAPI()
  37. async def setup_database():
  38. dsn = "postgresql+asyncpg://%s:%s@tfb-database:5432/hello_world" % (
  39. os.getenv("PGPASS", "benchmarkdbuser"),
  40. os.getenv("PGPASS", "benchmarkdbpass"),
  41. )
  42. engine = create_async_engine(
  43. dsn,
  44. future=True,
  45. connect_args={
  46. "ssl": False # NEEDED FOR NGINX-UNIT OTHERWISE IT FAILS
  47. },
  48. )
  49. return sessionmaker(engine, class_=AsyncSession)
  50. def get_num_queries(queries):
  51. try:
  52. query_count = int(queries)
  53. except (ValueError, TypeError):
  54. return 1
  55. if query_count < 1:
  56. return 1
  57. if query_count > 500:
  58. return 500
  59. return query_count
  60. @app.on_event("startup")
  61. async def startup_event():
  62. app.state.db_session = await setup_database()
  63. @app.on_event("shutdown")
  64. async def shutdown_event():
  65. await app.state.db_session.close()
  66. @app.get("/json")
  67. async def json_serialization():
  68. return UJSONResponse({"message": "Hello, world!"})
  69. @app.get("/db")
  70. async def single_database_query():
  71. id_ = randint(1, 10000)
  72. async with app.state.db_session() as sess:
  73. result = await sess.get(World, id_)
  74. return UJSONResponse(result.__json__())
  75. @app.get("/queries")
  76. async def multiple_database_queries(queries=None):
  77. num_queries = get_num_queries(queries)
  78. data = []
  79. async with app.state.db_session() as sess:
  80. for id_ in sample(range(1, 10001), num_queries):
  81. result = await sess.get(World, id_)
  82. data.append(result.__json__())
  83. return UJSONResponse(data)
  84. @app.get("/fortunes")
  85. async def fortunes(request: Request):
  86. async with app.state.db_session() as sess:
  87. ret = await sess.execute(select(Fortune.id, Fortune.message))
  88. data = ret.all()
  89. data.append(ADDITIONAL_FORTUNE)
  90. data.sort(key=sort_fortunes_key)
  91. return templates.TemplateResponse(
  92. "fortune.jinja", {"request": request, "fortunes": data}
  93. )
  94. @app.get("/updates")
  95. async def database_updates(queries=None):
  96. num_queries = get_num_queries(queries)
  97. ids = sorted(sample(range(1, 10000 + 1), num_queries))
  98. data = []
  99. async with app.state.db_session.begin() as sess:
  100. for id_ in ids:
  101. world = await sess.get(World, id_, populate_existing=True)
  102. world.randomnumber = randint(1, 10000)
  103. # force sqlalchemy to UPDATE entry even if the value has not changed
  104. # doesn't make sense in a real application, added only for pass `tfb verify`
  105. flag_modified(world, "randomnumber")
  106. data.append(world.__json__())
  107. return UJSONResponse(data)
  108. @app.get("/plaintext")
  109. async def plaintext():
  110. return PlainTextResponse(b"Hello, world!")