Daniel Gatis 3 years ago
parent
commit
1f6cce322f
5 changed files with 75 additions and 75 deletions
  1. .github/workflows/lint_python.yml (+2 -1)
  2. rembg/bg.py (+5 -3)
  3. rembg/cli.py (+42 -51)
  4. rembg/detect.py (+13 -10)
  5. rembg/server.py (+13 -10)

+ 2 - 1
.github/workflows/lint_python.yml

@@ -9,7 +9,8 @@ jobs:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
       - run: pip install --upgrade pip wheel
-      - run: pip install bandit black flake8 flake8-bugbear flake8-comprehensions isort safety
+      - run: pip install bandit black flake8 flake8-bugbear flake8-comprehensions isort safety mypy
+      - run: mypy --install-types --non-interactive --ignore-missing-imports ./rembg
       - run: bandit --recursive --skip B101,B104,B310,B311,B303 --exclude ./rembg/_version.py ./rembg
       - run: black --force-exclude rembg/_version.py --check --diff ./rembg
       - run: flake8 ./rembg --count --ignore=B008,E203,E266,E731,F401,F811,F841,W503 --max-complexity=15 --max-line-length=120 --show-source --statistics --exclude ./rembg/_version.py

+ 5 - 3
rembg/bg.py

@@ -35,7 +35,9 @@ def alpha_matting_cutout(
     # erode foreground/background
     structure = None
     if erode_structure_size > 0:
-        structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
+        structure = np.ones(
+            (erode_structure_size, erode_structure_size), dtype=np.uint8
+        )
 
     is_foreground = binary_erosion(is_foreground, structure=structure)
     is_background = binary_erosion(is_background, structure=structure, border_value=1)
@@ -82,12 +84,12 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima
 
 def remove(
     data: bytes,
-    session: Optional[ort.InferenceSession] = None,
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_structure_size: int = 10,
     alpha_matting_base_size: int = 1000,
+    session: Optional[ort.InferenceSession] = None,
     width: Optional[int] = None,
     height: Optional[int] = None,
 ) -> bytes:
@@ -96,7 +98,7 @@ def remove(
         img = resize_image(img, width, height)
 
     if session is None:
-        session = ort_session(session)
+        session = ort_session("u2net")
 
     mask = predict(session, np.array(img)).convert("L")
 

+ 42 - 51
rembg/cli.py

@@ -2,22 +2,18 @@ import argparse
 import glob
 import os
 from distutils.util import strtobool
+from typing import BinaryIO
+import sys
+from pathlib import Path
 
 import filetype
 from tqdm import tqdm
+import onnxruntime as ort
 
 from .bg import remove
 from .detect import ort_session
 
-sessions = {}
-
-
-def read(i):
-    i.buffer.read() if hasattr(i, "buffer") else i.read()
-
-
-def write(o, d):
-    o.buffer.write(d) if hasattr(o, "buffer") else o.write(d)
+sessions: dict[str, ort.InferenceSession] = {}
 
 
 def main():
@@ -81,23 +77,22 @@ def main():
         help="An input folder and an output folder.",
     )
 
-    ap.add_argument(
-        "-o",
-        "--output",
-        nargs="?",
-        default="-",
-        type=argparse.FileType("wb"),
-        help="Path to the output png image.",
-    )
-
     ap.add_argument(
         "input",
-        nargs="?",
-        default="-",
+        nargs=(None if sys.stdin.isatty() else "?"),
+        default=(None if sys.stdin.isatty() else sys.stdin.buffer),
         type=argparse.FileType("rb"),
         help="Path to the input image.",
     )
 
+    ap.add_argument(
+        "output",
+        nargs=(None if sys.stdin.isatty() else "?"),
+        default=(None if sys.stdin.isatty() else sys.stdout.buffer),
+        type=argparse.FileType("wb"),
+        help="Path to the output png image.",
+    )
+
     args = ap.parse_args()
     session = sessions.setdefault(args.model, ort_session(args.model))
 
@@ -110,54 +105,50 @@ def main():
         if not os.path.exists(output_path):
             os.makedirs(output_path)
 
-        files = set()
+        input_files = set()
 
-        for path in input_paths:
+        for input_path in input_paths:
             if os.path.isfile(path):
-                files.add(path)
+                input_files.add(path)
             else:
-                input_paths += set(glob.glob(path + "/*"))
+                input_paths += set(glob.glob(input_path + "/*"))
 
-        for fi in tqdm(files):
-            fi_type = filetype.guess(fi)
+        for input_file in tqdm(input_files):
+            input_file_type = filetype.guess(input_file)
 
-            if fi_type is None:
+            if input_file_type is None:
                 continue
-            elif fi_type.mime.find("image") < 0:
+
+            if input_file_type.mime.find("image") < 0:
                 continue
 
-            with open(fi, "rb") as input:
-                with open(
-                    os.path.join(
-                        output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"
-                    ),
-                    "wb",
-                ) as output:
-                    write(
-                        output,
-                        remove(
-                            read(input),
-                            session=session,
-                            alpha_matting=args.alpha_matting,
-                            alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
-                            alpha_matting_background_threshold=args.alpha_matting_background_threshold,
-                            alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
-                            alpha_matting_base_size=args.alpha_matting_base_size,
-                        ),
-                    )
+            out_file = os.path.join(
+                output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png"
+            )
+
+            Path(out_file).write_bytes(
+                remove(
+                    Path(input_file).read_bytes(),
+                    session=session,
+                    alpha_matting=args.alpha_matting,
+                    alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
+                    alpha_matting_background_threshold=args.alpha_matting_background_threshold,
+                    alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
+                    alpha_matting_base_size=args.alpha_matting_base_size,
+                )
+            )
 
     else:
-        write(
-            args.output,
+        args.output.write(
             remove(
-                read(args.input),
+                args.input.read(),
                 session=session,
                 alpha_matting=args.alpha_matting,
                 alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
                 alpha_matting_background_threshold=args.alpha_matting_background_threshold,
                 alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
                 alpha_matting_base_size=args.alpha_matting_base_size,
-            ),
+            )
         )
 
 

+ 13 - 10
rembg/detect.py

@@ -1,5 +1,6 @@
 import os
 import sys
+from contextlib import redirect_stdout
 
 import gdown
 import numpy as np
@@ -26,11 +27,13 @@ def ort_session(model_name: str) -> ort.InferenceSession:
     else:
         assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
 
-    gdown.cached_download(url, path, md5=md5, quiet=True)
+    with redirect_stdout(sys.stderr):
+        gdown.cached_download(url, path, md5=md5)
+
     return ort.InferenceSession(path)
 
 
-def norm_pred(d: np.array) -> np.array:
+def norm_pred(d: np.ndarray) -> np.ndarray:
     ma = np.max(d)
     mi = np.min(d)
     dn = (d - mi) / (ma - mi)
@@ -93,8 +96,8 @@ def color(sample: dict) -> dict:
     return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
 
 
-def preprocess(image: np.array) -> dict:
-    label_3 = np.zeros(image.shape)
+def preprocess(im_array: np.ndarray) -> dict:
+    label_3 = np.zeros(im_array.shape)
     label = np.zeros(label_3.shape[0:2])
 
     if 3 == len(label_3.shape):
@@ -102,21 +105,21 @@ def preprocess(image: np.array) -> dict:
     elif 2 == len(label_3.shape):
         label = label_3
 
-    if 3 == len(image.shape) and 2 == len(label.shape):
+    if 3 == len(im_array.shape) and 2 == len(label.shape):
         label = label[:, :, np.newaxis]
-    elif 2 == len(image.shape) and 2 == len(label.shape):
-        image = image[:, :, np.newaxis]
+    elif 2 == len(im_array.shape) and 2 == len(label.shape):
+        im_array = im_array[:, :, np.newaxis]
         label = label[:, :, np.newaxis]
 
-    sample = {"imidx": np.array([0]), "image": image, "label": label}
+    sample = {"imidx": np.array([0]), "image": im_array, "label": label}
     sample = rescale(sample, 320)
     sample = color(sample)
 
     return sample
 
 
-def predict(ort_session: ort.InferenceSession, item: np.array) -> Image:
-    sample = preprocess(item)
+def predict(ort_session: ort.InferenceSession, im_array: np.ndarray) -> Image:
+    sample = preprocess(im_array)
     inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)
 
     ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}

+ 13 - 10
rembg/server.py

@@ -7,11 +7,12 @@ import uvicorn
 from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
 from PIL import Image
 from starlette.responses import Response
+import onnxruntime as ort
 
 from .bg import remove
 from .detect import ort_session
 
-sessions = {}
+sessions: dict[str, ort.InferenceSession] = {}
 app = FastAPI()
 
 
@@ -24,14 +25,14 @@ class ModelType(str, Enum):
 class CommonQueryParams:
     def __init__(
         self,
-        model: Optional[ModelType] = ModelType.u2net,
+        model: ModelType = Query(ModelType.u2net),
+        a: bool = Query(False),
+        af: int = Query(240, ge=0),
+        ab: int = Query(10, ge=0),
+        ae: int = Query(10, ge=0),
+        az: int = Query(1000, ge=0),
         width: Optional[int] = Query(None, gt=0),
         height: Optional[int] = Query(None, gt=0),
-        a: Optional[bool] = Query(False),
-        af: Optional[int] = Query(240, ge=0),
-        ab: Optional[int] = Query(10, ge=0),
-        ae: Optional[int] = Query(10, ge=0),
-        az: Optional[int] = Query(1000, ge=0),
     ):
         self.model = model
         self.width = width
@@ -47,7 +48,9 @@ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
     return Response(
         remove(
             content,
-            session=sessions.setdefault(commons.model, ort_session(commons.model)),
+            session=sessions.setdefault(
+                commons.model.value, ort_session(commons.model.value)
+            ),
             width=commons.width,
             height=commons.height,
             alpha_matting=commons.a,
@@ -66,8 +69,8 @@ def get_index(url: str, commons: CommonQueryParams = Depends()):
 
 
 @app.post("/")
-def post_index(file: UploadFile = File(...), commons: CommonQueryParams = Depends()):
-    return im_without_bg(file.read(), commons)
+def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
+    return im_without_bg(file, commons)
 
 
 def main():