Daniel Gatis 3 years ago
parent
commit
1f6cce322f
5 changed files with 75 additions and 75 deletions
  1. 2 1
      .github/workflows/lint_python.yml
  2. 5 3
      rembg/bg.py
  3. 42 51
      rembg/cli.py
  4. 13 10
      rembg/detect.py
  5. 13 10
      rembg/server.py

+ 2 - 1
.github/workflows/lint_python.yml

@@ -9,7 +9,8 @@ jobs:
       - uses: actions/checkout@v2
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
       - uses: actions/setup-python@v2
       - run: pip install --upgrade pip wheel
       - run: pip install --upgrade pip wheel
-      - run: pip install bandit black flake8 flake8-bugbear flake8-comprehensions isort safety
+      - run: pip install bandit black flake8 flake8-bugbear flake8-comprehensions isort safety mypy
+      - run: mypy --install-types --non-interactive --ignore-missing-imports ./rembg
       - run: bandit --recursive --skip B101,B104,B310,B311,B303 --exclude ./rembg/_version.py ./rembg
       - run: bandit --recursive --skip B101,B104,B310,B311,B303 --exclude ./rembg/_version.py ./rembg
       - run: black --force-exclude rembg/_version.py --check --diff ./rembg
       - run: black --force-exclude rembg/_version.py --check --diff ./rembg
       - run: flake8 ./rembg --count --ignore=B008,E203,E266,E731,F401,F811,F841,W503 --max-complexity=15 --max-line-length=120 --show-source --statistics --exclude ./rembg/_version.py
       - run: flake8 ./rembg --count --ignore=B008,E203,E266,E731,F401,F811,F841,W503 --max-complexity=15 --max-line-length=120 --show-source --statistics --exclude ./rembg/_version.py

+ 5 - 3
rembg/bg.py

@@ -35,7 +35,9 @@ def alpha_matting_cutout(
     # erode foreground/background
     # erode foreground/background
     structure = None
     structure = None
     if erode_structure_size > 0:
     if erode_structure_size > 0:
-        structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
+        structure = np.ones(
+            (erode_structure_size, erode_structure_size), dtype=np.uint8
+        )
 
 
     is_foreground = binary_erosion(is_foreground, structure=structure)
     is_foreground = binary_erosion(is_foreground, structure=structure)
     is_background = binary_erosion(is_background, structure=structure, border_value=1)
     is_background = binary_erosion(is_background, structure=structure, border_value=1)
@@ -82,12 +84,12 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima
 
 
 def remove(
 def remove(
     data: bytes,
     data: bytes,
-    session: Optional[ort.InferenceSession] = None,
     alpha_matting: bool = False,
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_structure_size: int = 10,
     alpha_matting_erode_structure_size: int = 10,
     alpha_matting_base_size: int = 1000,
     alpha_matting_base_size: int = 1000,
+    session: Optional[ort.InferenceSession] = None,
     width: Optional[int] = None,
     width: Optional[int] = None,
     height: Optional[int] = None,
     height: Optional[int] = None,
 ) -> bytes:
 ) -> bytes:
@@ -96,7 +98,7 @@ def remove(
         img = resize_image(img, width, height)
         img = resize_image(img, width, height)
 
 
     if session is None:
     if session is None:
-        session = ort_session(session)
+        session = ort_session("u2net")
 
 
     mask = predict(session, np.array(img)).convert("L")
     mask = predict(session, np.array(img)).convert("L")
 
 

+ 42 - 51
rembg/cli.py

@@ -2,22 +2,18 @@ import argparse
 import glob
 import glob
 import os
 import os
 from distutils.util import strtobool
 from distutils.util import strtobool
+from typing import BinaryIO
+import sys
+from pathlib import Path
 
 
 import filetype
 import filetype
 from tqdm import tqdm
 from tqdm import tqdm
+import onnxruntime as ort
 
 
 from .bg import remove
 from .bg import remove
 from .detect import ort_session
 from .detect import ort_session
 
 
-sessions = {}
-
-
-def read(i):
-    i.buffer.read() if hasattr(i, "buffer") else i.read()
-
-
-def write(o, d):
-    o.buffer.write(d) if hasattr(o, "buffer") else o.write(d)
+sessions: dict[str, ort.InferenceSession] = {}
 
 
 
 
 def main():
 def main():
@@ -81,23 +77,22 @@ def main():
         help="An input folder and an output folder.",
         help="An input folder and an output folder.",
     )
     )
 
 
-    ap.add_argument(
-        "-o",
-        "--output",
-        nargs="?",
-        default="-",
-        type=argparse.FileType("wb"),
-        help="Path to the output png image.",
-    )
-
     ap.add_argument(
     ap.add_argument(
         "input",
         "input",
-        nargs="?",
-        default="-",
+        nargs=(None if sys.stdin.isatty() else "?"),
+        default=(None if sys.stdin.isatty() else sys.stdin.buffer),
         type=argparse.FileType("rb"),
         type=argparse.FileType("rb"),
         help="Path to the input image.",
         help="Path to the input image.",
     )
     )
 
 
+    ap.add_argument(
+        "output",
+        nargs=(None if sys.stdin.isatty() else "?"),
+        default=(None if sys.stdin.isatty() else sys.stdout.buffer),
+        type=argparse.FileType("wb"),
+        help="Path to the output png image.",
+    )
+
     args = ap.parse_args()
     args = ap.parse_args()
     session = sessions.setdefault(args.model, ort_session(args.model))
     session = sessions.setdefault(args.model, ort_session(args.model))
 
 
@@ -110,54 +105,50 @@ def main():
         if not os.path.exists(output_path):
         if not os.path.exists(output_path):
             os.makedirs(output_path)
             os.makedirs(output_path)
 
 
-        files = set()
+        input_files = set()
 
 
-        for path in input_paths:
+        for input_path in input_paths:
             if os.path.isfile(path):
             if os.path.isfile(path):
-                files.add(path)
+                input_files.add(path)
             else:
             else:
-                input_paths += set(glob.glob(path + "/*"))
+                input_paths += set(glob.glob(input_path + "/*"))
 
 
-        for fi in tqdm(files):
-            fi_type = filetype.guess(fi)
+        for input_file in tqdm(input_files):
+            input_file_type = filetype.guess(input_file)
 
 
-            if fi_type is None:
+            if input_file_type is None:
                 continue
                 continue
-            elif fi_type.mime.find("image") < 0:
+
+            if input_file_type.mime.find("image") < 0:
                 continue
                 continue
 
 
-            with open(fi, "rb") as input:
-                with open(
-                    os.path.join(
-                        output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"
-                    ),
-                    "wb",
-                ) as output:
-                    write(
-                        output,
-                        remove(
-                            read(input),
-                            session=session,
-                            alpha_matting=args.alpha_matting,
-                            alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
-                            alpha_matting_background_threshold=args.alpha_matting_background_threshold,
-                            alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
-                            alpha_matting_base_size=args.alpha_matting_base_size,
-                        ),
-                    )
+            out_file = os.path.join(
+                output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png"
+            )
+
+            Path(out_file).write_bytes(
+                remove(
+                    Path(input_file).read_bytes(),
+                    session=session,
+                    alpha_matting=args.alpha_matting,
+                    alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
+                    alpha_matting_background_threshold=args.alpha_matting_background_threshold,
+                    alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
+                    alpha_matting_base_size=args.alpha_matting_base_size,
+                )
+            )
 
 
     else:
     else:
-        write(
-            args.output,
+        args.output.write(
             remove(
             remove(
-                read(args.input),
+                args.input.read(),
                 session=session,
                 session=session,
                 alpha_matting=args.alpha_matting,
                 alpha_matting=args.alpha_matting,
                 alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
                 alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
                 alpha_matting_background_threshold=args.alpha_matting_background_threshold,
                 alpha_matting_background_threshold=args.alpha_matting_background_threshold,
                 alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
                 alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
                 alpha_matting_base_size=args.alpha_matting_base_size,
                 alpha_matting_base_size=args.alpha_matting_base_size,
-            ),
+            )
         )
         )
 
 
 
 

+ 13 - 10
rembg/detect.py

@@ -1,5 +1,6 @@
 import os
 import os
 import sys
 import sys
+from contextlib import redirect_stdout
 
 
 import gdown
 import gdown
 import numpy as np
 import numpy as np
@@ -26,11 +27,13 @@ def ort_session(model_name: str) -> ort.InferenceSession:
     else:
     else:
         assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
         assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
 
 
-    gdown.cached_download(url, path, md5=md5, quiet=True)
+    with redirect_stdout(sys.stderr):
+        gdown.cached_download(url, path, md5=md5)
+
     return ort.InferenceSession(path)
     return ort.InferenceSession(path)
 
 
 
 
-def norm_pred(d: np.array) -> np.array:
+def norm_pred(d: np.ndarray) -> np.ndarray:
     ma = np.max(d)
     ma = np.max(d)
     mi = np.min(d)
     mi = np.min(d)
     dn = (d - mi) / (ma - mi)
     dn = (d - mi) / (ma - mi)
@@ -93,8 +96,8 @@ def color(sample: dict) -> dict:
     return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
     return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
 
 
 
 
-def preprocess(image: np.array) -> dict:
-    label_3 = np.zeros(image.shape)
+def preprocess(im_array: np.ndarray) -> dict:
+    label_3 = np.zeros(im_array.shape)
     label = np.zeros(label_3.shape[0:2])
     label = np.zeros(label_3.shape[0:2])
 
 
     if 3 == len(label_3.shape):
     if 3 == len(label_3.shape):
@@ -102,21 +105,21 @@ def preprocess(image: np.array) -> dict:
     elif 2 == len(label_3.shape):
     elif 2 == len(label_3.shape):
         label = label_3
         label = label_3
 
 
-    if 3 == len(image.shape) and 2 == len(label.shape):
+    if 3 == len(im_array.shape) and 2 == len(label.shape):
         label = label[:, :, np.newaxis]
         label = label[:, :, np.newaxis]
-    elif 2 == len(image.shape) and 2 == len(label.shape):
-        image = image[:, :, np.newaxis]
+    elif 2 == len(im_array.shape) and 2 == len(label.shape):
+        im_array = im_array[:, :, np.newaxis]
         label = label[:, :, np.newaxis]
         label = label[:, :, np.newaxis]
 
 
-    sample = {"imidx": np.array([0]), "image": image, "label": label}
+    sample = {"imidx": np.array([0]), "image": im_array, "label": label}
     sample = rescale(sample, 320)
     sample = rescale(sample, 320)
     sample = color(sample)
     sample = color(sample)
 
 
     return sample
     return sample
 
 
 
 
-def predict(ort_session: ort.InferenceSession, item: np.array) -> Image:
-    sample = preprocess(item)
+def predict(ort_session: ort.InferenceSession, im_array: np.ndarray) -> Image:
+    sample = preprocess(im_array)
     inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)
     inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)
 
 
     ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}
     ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}

+ 13 - 10
rembg/server.py

@@ -7,11 +7,12 @@ import uvicorn
 from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
 from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
 from PIL import Image
 from PIL import Image
 from starlette.responses import Response
 from starlette.responses import Response
+import onnxruntime as ort
 
 
 from .bg import remove
 from .bg import remove
 from .detect import ort_session
 from .detect import ort_session
 
 
-sessions = {}
+sessions: dict[str, ort.InferenceSession] = {}
 app = FastAPI()
 app = FastAPI()
 
 
 
 
@@ -24,14 +25,14 @@ class ModelType(str, Enum):
 class CommonQueryParams:
 class CommonQueryParams:
     def __init__(
     def __init__(
         self,
         self,
-        model: Optional[ModelType] = ModelType.u2net,
+        model: ModelType = Query(ModelType.u2net),
+        a: bool = Query(False),
+        af: int = Query(240, ge=0),
+        ab: int = Query(10, ge=0),
+        ae: int = Query(10, ge=0),
+        az: int = Query(1000, ge=0),
         width: Optional[int] = Query(None, gt=0),
         width: Optional[int] = Query(None, gt=0),
         height: Optional[int] = Query(None, gt=0),
         height: Optional[int] = Query(None, gt=0),
-        a: Optional[bool] = Query(False),
-        af: Optional[int] = Query(240, ge=0),
-        ab: Optional[int] = Query(10, ge=0),
-        ae: Optional[int] = Query(10, ge=0),
-        az: Optional[int] = Query(1000, ge=0),
     ):
     ):
         self.model = model
         self.model = model
         self.width = width
         self.width = width
@@ -47,7 +48,9 @@ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
     return Response(
     return Response(
         remove(
         remove(
             content,
             content,
-            session=sessions.setdefault(commons.model, ort_session(commons.model)),
+            session=sessions.setdefault(
+                commons.model.value, ort_session(commons.model.value)
+            ),
             width=commons.width,
             width=commons.width,
             height=commons.height,
             height=commons.height,
             alpha_matting=commons.a,
             alpha_matting=commons.a,
@@ -66,8 +69,8 @@ def get_index(url: str, commons: CommonQueryParams = Depends()):
 
 
 
 
 @app.post("/")
 @app.post("/")
-def post_index(file: UploadFile = File(...), commons: CommonQueryParams = Depends()):
-    return im_without_bg(file.read(), commons)
+def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
+    return im_without_bg(file, commons)
 
 
 
 
 def main():
 def main():