app_orm.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import logging
  2. import multiprocessing
  3. import os
  4. from contextlib import asynccontextmanager
  5. from operator import attrgetter
  6. from pathlib import Path
  7. from random import randint, sample
  8. from typing import Any
  9. import orjson
  10. from litestar import Litestar, MediaType, Request, get, Response
  11. from litestar.contrib.jinja import JinjaTemplateEngine
  12. from litestar.response import Template
  13. from litestar.template import TemplateConfig
  14. from sqlalchemy import Column, Integer, String, select
  15. from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
  16. from sqlalchemy.ext.declarative import declarative_base
  17. from sqlalchemy.orm.attributes import flag_modified
  18. logger = logging.getLogger(__name__)
  19. Base = declarative_base()
  20. class World(Base):
  21. __tablename__ = "world"
  22. id = Column(Integer, primary_key=True)
  23. randomnumber = Column(Integer)
  24. def __json__(self):
  25. return {"id": self.id, "randomnumber": self.randomnumber}
  26. sa_data = World.__table__
  27. class Fortune(Base):
  28. __tablename__ = "fortune"
  29. id = Column(Integer, primary_key=True)
  30. message = Column(String)
  31. sa_fortunes = Fortune.__table__
  32. ADDITIONAL_FORTUNE = Fortune(id=0, message="Additional fortune added at request time.")
  33. MAX_POOL_SIZE = 1000 // multiprocessing.cpu_count()
  34. sort_fortunes_key = attrgetter("message")
  35. template_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "templates")
  36. async def setup_database():
  37. dsn = "postgresql+asyncpg://%s:%s@tfb-database:5432/hello_world" % (
  38. os.getenv("PGPASS", "benchmarkdbuser"),
  39. os.getenv("PGPASS", "benchmarkdbpass"),
  40. )
  41. engine = create_async_engine(
  42. dsn,
  43. future=True,
  44. pool_size=MAX_POOL_SIZE,
  45. connect_args={
  46. "ssl": False # NEEDED FOR NGINX-UNIT OTHERWISE IT FAILS
  47. },
  48. )
  49. return async_sessionmaker(engine)
  50. @asynccontextmanager
  51. async def lifespan(app: Litestar):
  52. # Set up the database connection pool
  53. app.state.db_session = await setup_database()
  54. yield
  55. # Close the database connection pool
  56. await app.state.db_session().close()
  57. def get_num_queries(queries):
  58. try:
  59. query_count = int(queries)
  60. except (ValueError, TypeError):
  61. return 1
  62. if query_count < 1:
  63. return 1
  64. if query_count > 500:
  65. return 500
  66. return query_count
  67. @get("/json")
  68. async def json_serialization() -> Response:
  69. return Response(
  70. content=orjson.dumps({"message": "Hello, world!"}),
  71. media_type=MediaType.JSON,
  72. )
  73. @get("/db")
  74. async def single_database_query() -> Response:
  75. id_ = randint(1, 10000)
  76. async with app.state.db_session() as sess:
  77. result = await sess.get(World, id_)
  78. return Response(
  79. content=orjson.dumps(result.__json__()),
  80. media_type=MediaType.JSON,
  81. )
  82. @get("/queries")
  83. async def multiple_database_queries(queries: Any = None) -> Response:
  84. num_queries = get_num_queries(queries)
  85. data = []
  86. async with app.state.db_session() as sess:
  87. for id_ in sample(range(1, 10001), num_queries):
  88. result = await sess.get(World, id_)
  89. data.append(result.__json__())
  90. return Response(
  91. content=orjson.dumps(data),
  92. media_type=MediaType.JSON,
  93. )
  94. @get("/fortunes", media_type=MediaType.HTML)
  95. async def fortunes(request: Request) -> Template:
  96. async with app.state.db_session() as sess:
  97. ret = await sess.execute(select(Fortune.id, Fortune.message))
  98. data = ret.all()
  99. data.append(ADDITIONAL_FORTUNE)
  100. data.sort(key=sort_fortunes_key)
  101. return Template(
  102. "fortune.jinja",
  103. context={"request": request, "fortunes": data},
  104. media_type=MediaType.HTML,
  105. )
  106. @get("/updates")
  107. async def database_updates(queries: Any = None) -> Response:
  108. num_queries = get_num_queries(queries)
  109. ids = sorted(sample(range(1, 10000 + 1), num_queries))
  110. data = []
  111. async with app.state.db_session.begin() as sess:
  112. for id_ in ids:
  113. world = await sess.get(World, id_, populate_existing=True)
  114. world.randomnumber = randint(1, 10000)
  115. # force sqlalchemy to UPDATE entry even if the value has not changed
  116. # doesn't make sense in a real application, added only for pass `tfb verify`
  117. flag_modified(world, "randomnumber")
  118. data.append(world.__json__())
  119. return Response(
  120. content=orjson.dumps(data),
  121. media_type=MediaType.JSON,
  122. )
  123. @get("/plaintext", media_type=MediaType.TEXT)
  124. async def plaintext() -> bytes:
  125. return b"Hello, world!"
  126. app = Litestar(
  127. template_config=TemplateConfig(
  128. directory=Path("templates"),
  129. engine=JinjaTemplateEngine,
  130. ),
  131. lifespan=[lifespan],
  132. route_handlers=[
  133. json_serialization,
  134. single_database_query,
  135. multiple_database_queries,
  136. fortunes,
  137. database_updates,
  138. plaintext,
  139. ],
  140. )