Browse Source

refact server

Daniel Gatis 3 years ago
parent
commit
4d90e8e92e
4 changed files with 59 additions and 65 deletions
  1. 1 1
      .github/workflows/lint_python.yml
  2. 3 2
      rembg/bg.py
  3. 12 7
      rembg/cli.py
  4. 43 55
      rembg/server.py

+ 1 - 1
.github/workflows/lint_python.yml

@@ -10,7 +10,7 @@ jobs:
       - uses: actions/setup-python@v2
       - run: pip install --upgrade pip wheel
       - run: pip install bandit black flake8 flake8-bugbear flake8-comprehensions isort safety
-      - run: bandit --recursive --skip B101,B104,B310,B311,B303 --exclude ./rembg/_version.py ./rembg
+      - run: bandit --recursive --skip B008,B101,B104,B310,B311,B303 --exclude ./rembg/_version.py ./rembg
       - run: black --force-exclude rembg/_version.py --check --diff ./rembg
       - run: flake8 ./rembg --count --ignore=E203,E266,E731,F401,F811,F841,W503 --max-complexity=15 --max-line-length=120 --show-source --statistics --exclude ./rembg/_version.py
       - run: isort --check-only --profile black ./rembg

+ 3 - 2
rembg/bg.py

@@ -94,7 +94,7 @@ def remove(
         img = resize_image(img, width, height)
 
     if session is None:
-        session = ort_session(model_name)
+        session = ort_session(session)
 
     mask = predict(session, np.array(img)).convert("L")
 
@@ -115,5 +115,6 @@ def remove(
 
     bio = io.BytesIO()
     cutout.save(bio, "PNG")
+    bio.seek(0)
 
-    return bio.getbuffer()
+    return bio.read()

+ 12 - 7
rembg/cli.py

@@ -12,6 +12,14 @@ from .detect import ort_session
 sessions = {}
 
 
+def read(i):
+    i.buffer.read() if hasattr(i, "buffer") else i.read()
+
+
+def write(o, d):
+    o.buffer.write(d) if hasattr(o, "buffer") else o.write(d)
+
+
 def main():
     ap = argparse.ArgumentParser()
 
@@ -91,9 +99,6 @@ def main():
     )
 
     args = ap.parse_args()
-
-    r = lambda i: i.buffer.read() if hasattr(i, "buffer") else i.read()
-    w = lambda o, data: o.buffer.write(data) if hasattr(o, "buffer") else o.write(data)
     session = sessions.setdefault(args.model, ort_session(args.model))
 
     if args.path:
@@ -128,10 +133,10 @@ def main():
                     ),
                     "wb",
                 ) as output:
-                    w(
+                    write(
                         output,
                         remove(
-                            r(input),
+                            read(input),
                             session=session,
                             alpha_matting=args.alpha_matting,
                             alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
@@ -142,10 +147,10 @@ def main():
                     )
 
     else:
-        w(
+        write(
             args.output,
             remove(
-                r(args.input),
+                read(args.input),
                 session=session,
                 alpha_matting=args.alpha_matting,
                 alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,

+ 43 - 55
rembg/server.py

@@ -1,13 +1,12 @@
 import argparse
 from enum import Enum
-from io import BytesIO
 from typing import Optional
 
 import requests
 import uvicorn
-from fastapi import FastAPI, Form, Query, UploadFile, File
+from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
 from PIL import Image
-from starlette.responses import StreamingResponse
+from starlette.responses import Response
 
 from .bg import remove
 from .detect import ort_session
@@ -22,64 +21,53 @@ class ModelType(str, Enum):
     u2net_human_seg = "u2net_human_seg"
 
 
[email protected]("/")
-def get_index(
-    url: str,
-    model: Optional[ModelType] = ModelType.u2net,
-    width: Optional[int] = Query(None, gt=0),
-    height: Optional[int] = Query(None, gt=0),
-    a: Optional[bool] = Query(False),
-    af: Optional[int] = Query(240, ge=0),
-    ab: Optional[int] = Query(10, ge=0),
-    ae: Optional[int] = Query(10, ge=0),
-    az: Optional[int] = Query(1000, ge=0),
-):
-    return StreamingResponse(
-        BytesIO(
-            remove(
-                requests.get(url).content,
-                session=sessions.setdefault(model, ort_session(model)),
-                width=width,
-                height=height,
-                alpha_matting=a,
-                alpha_matting_foreground_threshold=af,
-                alpha_matting_background_threshold=ab,
-                alpha_matting_erode_structure_size=ae,
-                alpha_matting_base_size=az,
-            )
+class CommonQueryParams:
+    def __init__(
+        self,
+        model: Optional[ModelType] = ModelType.u2net,
+        width: Optional[int] = Query(None, gt=0),
+        height: Optional[int] = Query(None, gt=0),
+        a: Optional[bool] = Query(False),
+        af: Optional[int] = Query(240, ge=0),
+        ab: Optional[int] = Query(10, ge=0),
+        ae: Optional[int] = Query(10, ge=0),
+        az: Optional[int] = Query(1000, ge=0),
+    ):
+        self.model = model
+        self.width = width
+        self.height = height
+        self.a = a
+        self.af = af
+        self.ab = ab
+        self.ae = ae
+        self.az = az
+
+
+def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
+    return Response(
+        remove(
+            content,
+            session=sessions.setdefault(commons.model, ort_session(commons.model)),
+            width=commons.width,
+            height=commons.height,
+            alpha_matting=commons.a,
+            alpha_matting_foreground_threshold=commons.af,
+            alpha_matting_background_threshold=commons.ab,
+            alpha_matting_erode_structure_size=commons.ae,
+            alpha_matting_base_size=commons.az,
         ),
         media_type="image/png",
     )
 
 
[email protected]("/")
+def get_index(url: str, commons: CommonQueryParams = Depends()):
+    return im_without_bg(requests.get(url).content, commons)
+
+
 @app.post("/")
-def post_index(
-    file: UploadFile = File(...),
-    model: Optional[ModelType] = ModelType.u2net,
-    width: Optional[int] = Query(None, gt=0),
-    height: Optional[int] = Query(None, gt=0),
-    a: Optional[bool] = Query(False),
-    af: Optional[int] = Query(240, ge=0),
-    ab: Optional[int] = Query(10, ge=0),
-    ae: Optional[int] = Query(10, ge=0),
-    az: Optional[int] = Query(1000, ge=0),
-):
-    return StreamingResponse(
-        BytesIO(
-            remove(
-                file.read(),
-                session=sessions.setdefault(model, ort_session(model)),
-                width=width,
-                height=height,
-                alpha_matting=a,
-                alpha_matting_foreground_threshold=af,
-                alpha_matting_background_threshold=ab,
-                alpha_matting_erode_structure_size=ae,
-                alpha_matting_base_size=az,
-            )
-        ),
-        media_type="image/png",
-    )
+def post_index(file: UploadFile = File(...), commons: CommonQueryParams = Depends()):
+    return im_without_bg(file.read(), commons)
 
 
 def main():