Parcourir la source

[aiohttp] Use SQLAlchemy directly (rather than aiopg) (#6248)

* Update aiohttp test.

* Update aiohttp to use async sqlalchemy.

* Update main.py

* Update drivername

* Stop cheating in ORM updates.

* Remove unused import.

Co-authored-by: Sam Bull <[email protected]>
Sam Bull il y a 4 ans
Parent
commit
fe99c809fc

+ 2 - 2
frameworks/Python/aiohttp/README.md

@@ -12,7 +12,7 @@ All test implementations are located within ([./app](app)).
 
 ## Description
 
-aiohttp with [aiopg + sqlalchemy](http://aiopg.readthedocs.io/en/stable/sa.html) and 
+aiohttp with [sqlalchemy](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html) and
 separately [asyncpg](https://magicstack.github.io/asyncpg/current/) for database access.
  
 [uvloop](https://github.com/MagicStack/uvloop) is used for a more performant event loop.
@@ -22,7 +22,7 @@ separately [asyncpg](https://magicstack.github.io/asyncpg/current/) for database
 PostgreSQL.
 
 Two variants:
-* ORM using [aiopg + sqlalchemy](http://aiopg.readthedocs.io/en/stable/sa.html)
+* ORM using [sqlalchemy](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html)
 * RAW using [asyncpg](https://magicstack.github.io/asyncpg/current/)
 
 **To enabled "RAW" mode set the following environment variable:**

+ 7 - 8
frameworks/Python/aiohttp/app/main.py

@@ -3,11 +3,12 @@ import multiprocessing
 from pathlib import Path
 
 import aiohttp_jinja2
-import aiopg.sa
 import asyncpg
 import jinja2
 from aiohttp import web
 from sqlalchemy.engine.url import URL
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+from sqlalchemy.orm import sessionmaker
 
 from .views import (
     json,
@@ -32,13 +33,13 @@ def pg_dsn() -> str:
     """
     :return: DSN url suitable for sqlalchemy and aiopg.
     """
-    return str(URL(
+    return str(URL.create(
         database='hello_world',
         password=os.getenv('PGPASS', 'benchmarkdbpass'),
         host='tfb-database',
         port='5432',
         username=os.getenv('PGUSER', 'benchmarkdbuser'),
-        drivername='postgres',
+        drivername='postgresql',
     ))
 
 
@@ -52,16 +53,14 @@ async def db_ctx(app: web.Application):
     min_size = max(int(max_size / 2), 1)
     print(f'connection pool: min size: {min_size}, max size: {max_size}, orm: {CONNECTION_ORM}')
     if CONNECTION_ORM:
-        app['pg'] = await aiopg.sa.create_engine(dsn=dsn, minsize=min_size, maxsize=max_size, loop=app.loop)
+        engine = create_async_engine(dsn, future=True)
+        app['db_session'] = sessionmaker(engine, class_=AsyncSession)
     else:
         app['pg'] = await asyncpg.create_pool(dsn=dsn, min_size=min_size, max_size=max_size, loop=app.loop)
 
     yield
 
-    if CONNECTION_ORM:
-        app['pg'].close()
-        await app['pg'].wait_closed()
-    else:
+    if not CONNECTION_ORM:
         await app['pg'].close()
 
 

+ 25 - 27
frameworks/Python/aiohttp/app/views.py

@@ -8,7 +8,7 @@ import ujson
 
 from sqlalchemy import select
 
-from .models import sa_fortunes, sa_worlds, Fortune
+from .models import sa_fortunes, sa_worlds, Fortune, World
 
 json_response = partial(json_response, dumps=ujson.dumps)
 
@@ -37,10 +37,11 @@ async def single_database_query_orm(request):
     Test 2 ORM
     """
     id_ = randint(1, 10000)
-    async with request.app['pg'].acquire() as conn:
-        cur = await conn.execute(select([sa_worlds.c.randomnumber]).where(sa_worlds.c.id == id_))
-        r = await cur.first()
-    return json_response({'id': id_, 'randomNumber': r[0]})
+    async with request.app['db_session']() as sess:
+        # TODO(SA1.4.0b2): sess.scalar()
+        ret = await sess.execute(select(World.randomnumber).filter_by(id=id_))
+        num = ret.scalar()
+    return json_response({'id': id_, 'randomNumber': num})
 
 
 async def single_database_query_raw(request):
@@ -64,11 +65,12 @@ async def multiple_database_queries_orm(request):
     ids.sort()
 
     result = []
-    async with request.app['pg'].acquire() as conn:
+    async with request.app['db_session']() as sess:
         for id_ in ids:
-            cur = await conn.execute(select([sa_worlds.c.randomnumber]).where(sa_worlds.c.id == id_))
-            r = await cur.first()
-            result.append({'id': id_, 'randomNumber': r[0]})
+            # TODO(SA1.4.0b2): sess.scalar()
+            ret = await sess.execute(select(World.randomnumber).filter_by(id=id_))
+            num = ret.scalar()
+            result.append({'id': id_, 'randomNumber': num})
     return json_response(result)
 
 
@@ -97,9 +99,9 @@ async def fortunes(request):
     """
     Test 4 ORM
     """
-    async with request.app['pg'].acquire() as conn:
-        cur = await conn.execute(select([sa_fortunes.c.id, sa_fortunes.c.message]))
-        fortunes = list(await cur.fetchall())
+    async with request.app['db_session']() as sess:
+        ret = await sess.execute(select(Fortune.id, Fortune.message))
+        fortunes = ret.all()
     fortunes.append(Fortune(id=0, message='Additional fortune added at request time.'))
     fortunes.sort(key=attrgetter('message'))
     return {'fortunes': fortunes}
@@ -127,21 +129,17 @@ async def updates(request):
     ids = [randint(1, 10000) for _ in range(num_queries)]
     ids.sort()
 
-    async with request.app['pg'].acquire() as conn:
-        for id_ in ids:
-            cur = await conn.execute(
-                select([sa_worlds.c.randomnumber])
-                .where(sa_worlds.c.id == id_)
-            )
-            # the result of this is a dict with the previous random number `randomnumber` which we don't actually use
-            await cur.first()
-            rand_new = randint(1, 10000)
-            await conn.execute(
-                sa_worlds.update()
-                .where(sa_worlds.c.id == id_)
-                .values(randomnumber=rand_new)
-            )
-            result.append({'id': id_, 'randomNumber': rand_new})
+    # TODO(SA1.4.0b2): async with request.app['db_session'].begin() as sess:
+    async with request.app['db_session']() as sess:
+        async with sess.begin():
+            for id_ in ids:
+                rand_new = randint(1, 10000)
+                # TODO(SA1.4.0b2): world = await sess.get(World, id_)
+                ret = await sess.execute(select(World).filter_by(id=id_))
+                world = ret.scalar()
+                world.randomnumber = rand_new
+
+                result.append({'id': id_, 'randomNumber': rand_new})
     return json_response(result)
 
 async def updates_raw(request):

+ 1 - 2
frameworks/Python/aiohttp/requirements.txt

@@ -1,10 +1,9 @@
 aiohttp==3.7.3
 aiohttp-jinja2==1.4.2
-aiopg==1.0.0
 asyncpg==0.21.0
 cchardet==2.1.7
 gunicorn==20.0.4
 psycopg2==2.8.6
-SQLAlchemy==1.3.16
+SQLAlchemy==1.4.0b1
 ujson==2.0.3
 uvloop==0.14.0