浏览代码

refact server

Daniel Gatis 3 年之前
父节点
当前提交
77b348d097
共有 5 个文件被更改,包括 111 次插入98 次删除
  1. 6 13
      rembg/bg.py
  2. 6 2
      rembg/cli.py
  3. 4 9
      rembg/detect.py
  4. 87 68
      rembg/server.py
  5. 8 6
      requirements.txt

+ 6 - 13
rembg/bg.py

@@ -7,7 +7,7 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 from pymatting.util.util import stack_images
 from scipy.ndimage.morphology import binary_erosion
 
-from .detect import load_model, predict
+from .detect import ort_session, predict
 
 
 def alpha_matting_cutout(
@@ -67,15 +67,6 @@ def naive_cutout(img, mask):
     return cutout
 
 
-def get_model(model_name):
-    if model_name == "u2netp":
-        return load_model(model_name="u2netp")
-    if model_name == "u2net_human_seg":
-        return load_model(model_name="u2net_human_seg")
-    else:
-        return load_model(model_name="u2net")
-
-
 def resize_image(img, width, height):
     original_width, original_height = img.size
     width = original_width if width is None else width
@@ -89,7 +80,7 @@ def resize_image(img, width, height):
 
 def remove(
     data,
-    model_name="u2net",
+    session=None,
     alpha_matting=False,
     alpha_matting_foreground_threshold=240,
     alpha_matting_background_threshold=10,
@@ -102,8 +93,10 @@ def remove(
     if width is not None or height is not None:
         img = resize_image(img, width, height)
 
-    model = get_model(model_name)
-    mask = predict(model, np.array(img)).convert("L")
+    if session is None:
+        session = ort_session(model_name)
+
+    mask = predict(session, np.array(img)).convert("L")
 
     if alpha_matting:
         try:

+ 6 - 2
rembg/cli.py

@@ -7,6 +7,9 @@ import filetype
 from tqdm import tqdm
 
 from .bg import remove
+from .detect import ort_session
+
+sessions = {}
 
 
 def main():
@@ -91,6 +94,7 @@ def main():
 
     r = lambda i: i.buffer.read() if hasattr(i, "buffer") else i.read()
     w = lambda o, data: o.buffer.write(data) if hasattr(o, "buffer") else o.write(data)
+    session = sessions.setdefault(args.model, ort_session(args.model))
 
     if args.path:
         full_paths = [os.path.abspath(path) for path in args.path]
@@ -128,7 +132,7 @@ def main():
                         output,
                         remove(
                             r(input),
-                            model_name=args.model,
+                            session=session,
                             alpha_matting=args.alpha_matting,
                             alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
                             alpha_matting_background_threshold=args.alpha_matting_background_threshold,
@@ -142,7 +146,7 @@ def main():
             args.output,
             remove(
                 r(args.input),
-                model_name=args.model,
+                session=session,
                 alpha_matting=args.alpha_matting,
                 alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
                 alpha_matting_background_threshold=args.alpha_matting_background_threshold,

+ 4 - 9
rembg/detect.py

@@ -7,10 +7,8 @@ import onnxruntime as ort
 from PIL import Image
 from skimage import transform
 
-SESSIONS = {}
 
-
-def load_model(model_name: str = "u2net"):
+def ort_session(model_name):
     path = os.environ.get(
         "U2NETP_PATH",
         os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
@@ -26,13 +24,10 @@ def load_model(model_name: str = "u2net"):
         md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
         url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
     else:
-        print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
-
-    if SESSIONS.get(md5) is None:
-        gdown.cached_download(url, path, md5=md5, quiet=False)
-        SESSIONS[md5] = ort.InferenceSession(path)
+        assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
 
-    return SESSIONS[md5]
+    gdown.cached_download(url, path, md5=md5, quiet=True)
+    return ort.InferenceSession(path)
 
 
 def norm_pred(d):

+ 87 - 68
rembg/server.py

@@ -1,76 +1,85 @@
 import argparse
+from enum import Enum
 from io import BytesIO
-from urllib.parse import quote, unquote_plus
-from urllib.request import urlopen
+from typing import Optional
 
-from flask import Flask, request, send_file
-from waitress import serve
+import requests
+import uvicorn
+from fastapi import FastAPI, Form, Query, UploadFile
+from PIL import Image
+from starlette.responses import StreamingResponse
 
 from .bg import remove
+from .detect import ort_session
+
+sessions = {}
+app = FastAPI()
+
+
+class ModelType(str, Enum):
+    u2net = "u2net"
+    u2netp = "u2netp"
+    u2net_human_seg = "u2net_human_seg"
+
+
[email protected]("/")
+def get_index(
+    url: str,
+    model: Optional[ModelType] = ModelType.u2net,
+    width: Optional[int] = Query(None, gt=0),
+    height: Optional[int] = Query(None, gt=0),
+    a: Optional[bool] = Query(False),
+    af: Optional[int] = Query(240, ge=0),
+    ab: Optional[int] = Query(10, ge=0),
+    ae: Optional[int] = Query(10, ge=0),
+    az: Optional[int] = Query(1000, ge=0),
+):
+    return StreamingResponse(
+        BytesIO(
+            remove(
+                requests.get(url).content,
+                session=sessions.setdefault(model, ort_session(model)),
+                width=width,
+                height=height,
+                alpha_matting=a,
+                alpha_matting_foreground_threshold=af,
+                alpha_matting_background_threshold=ab,
+                alpha_matting_erode_structure_size=ae,
+                alpha_matting_base_size=az,
+            )
+        ),
+        media_type="image/png",
+    )
+
 
-app = Flask(__name__)
-
-
[email protected]("/", methods=["GET", "POST"])
-def index():
-    file_content = ""
-
-    if request.method == "POST":
-        if "file" not in request.files:
-            return {"error": "missing post form param 'file'"}, 400
-
-        file_content = request.files["file"].read()
-
-    if request.method == "GET":
-        url = request.args.get("url", type=str)
-        if url is None:
-            return {"error": "missing query param 'url'"}, 400
-
-        url = unquote_plus(url)
-        if " " in url:
-            url = quote(url, safe="/:")
-
-        file_content = urlopen(url).read()
-
-    if file_content == "":
-        return {"error": "File content is empty"}, 400
-
-    alpha_matting = "a" in request.values
-    af = request.values.get("af", type=int, default=240)
-    ab = request.values.get("ab", type=int, default=10)
-    ae = request.values.get("ae", type=int, default=10)
-    az = request.values.get("az", type=int, default=1000)
-    width = request.args.get("width", type=int)
-    height = request.args.get("height", type=int)
-
-    model = request.values.get("model", type=str, default="u2net")
-    model_choices = ["u2net", "u2netp", "u2net_human_seg"]
-
-    if model not in model_choices:
-        return {
-            "error": f"invalid query param 'model'. Available options are {model_choices}"
-        }, 400
-
-    try:
-        return send_file(
-            BytesIO(
-                remove(
-                    file_content,
-                    width=width,
-                    height=height,
-                    model_name=model,
-                    alpha_matting=alpha_matting,
-                    alpha_matting_foreground_threshold=af,
-                    alpha_matting_background_threshold=ab,
-                    alpha_matting_erode_structure_size=ae,
-                    alpha_matting_base_size=az,
-                )
-            ),
-            mimetype="image/png",
-        )
-    except Exception as e:
-        app.logger.exception(e, exc_info=True)
-        return {"error": "oops, something went wrong!"}, 500
[email protected]("/")
+def post_index(
+    file: UploadFile = File(...),
+    model: Optional[ModelType] = ModelType.u2net,
+    width: Optional[int] = Query(None, gt=0),
+    height: Optional[int] = Query(None, gt=0),
+    a: Optional[bool] = Query(False),
+    af: Optional[int] = Query(240, ge=0),
+    ab: Optional[int] = Query(10, ge=0),
+    ae: Optional[int] = Query(10, ge=0),
+    az: Optional[int] = Query(1000, ge=0),
+):
+    return StreamingResponse(
+        BytesIO(
+            remove(
+                file.read(),
+                session=sessions.setdefault(model, ort_session(model)),
+                width=width,
+                height=height,
+                alpha_matting=a,
+                alpha_matting_foreground_threshold=af,
+                alpha_matting_background_threshold=ab,
+                alpha_matting_erode_structure_size=ae,
+                alpha_matting_base_size=az,
+            )
+        ),
+        media_type="image/png",
+    )
 
 
 def main():
@@ -92,8 +101,18 @@ def main():
         help="The port to bind to.",
     )
 
+    ap.add_argument(
+        "-l",
+        "--log_level",
+        default="info",
+        type=str,
+        help="The log level.",
+    )
+
     args = ap.parse_args()
-    serve(app, host=args.addr, port=args.port)
+    uvicorn.run(
+        "rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level
+    )
 
 
 if __name__ == "__main__":

+ 8 - 6
requirements.txt

@@ -1,10 +1,12 @@
-filetype==1.0.7
-flask==1.1.2
+fastapi==0.72.0
+filetype==1.0.9
 gdown==4.2.0
-numpy==1.21.0
+numpy==1.22.1
 pillow==9.0.0
 pymatting==1.1.5
+python-multipart==0.0.5
+requests==2.27.1
 scikit-image==0.19.1
-scipy==1.5.4
-tqdm==4.51.0
-waitress==1.4.4
+scipy==1.7.3
+tqdm==4.62.3
+uvicorn==0.17.0