Daniel Gatis il y a 3 ans
Parent
commit
385d34da4a
6 fichiers modifiés avec 278 ajouts et 264 suppressions
  1. 6 6
      README.md
  2. 3 7
      rembg/bg.py
  3. 268 140
      rembg/cli.py
  4. 0 110
      rembg/server.py
  5. 1 0
      requirements.txt
  6. 0 1
      setup.py

+ 6 - 6
README.md

@@ -53,24 +53,24 @@ GPU=1 pip install rembg
 
 Remove the background from a remote image
 ```bash
-curl -s http://input.png | rembg > output.png
+curl -s http://input.png | rembg i > output.png
 ```
 
 Remove the background from a local file
 ```bash
-rembg -o path/to/output.png path/to/input.png
+rembg i path/to/input.png path/to/output.png
 ```
 
 Remove the background from all images in a folder
 ```bash
-rembg -p path/to/input path/to/output
+rembg p path/to/input path/to/output
 ```
 
 ### Usage as a server
 
 Start the server
 ```bash
-rembg-server
+rembg s
 ```
 
 Open your browser to
@@ -140,14 +140,14 @@ docker build . -t rembg
 Then run with:
 
 ```
-docker run --rm -i rembg <in.png >out.png
+docker run --rm -i rembg i in.png out.png
 ```
 
 ### Advance usage
 
 Sometimes it is possible to achieve better results by turning on alpha matting. Example:
 ```bash
-curl -s http://input.png | rembg -a -ae 15 > output.png
+curl -s http://input.png | rembg i -a -ae 15 > output.png
 ```
 
 <table>

+ 3 - 7
rembg/bg.py

@@ -75,11 +75,7 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima
     original_width, original_height = img.size
     width = original_width if width is None else width
     height = original_height if height is None else height
-    return (
-        img.resize((width, height))
-        if original_width != width or original_height != height
-        else img
-    )
+    return img.resize((width, height))
 
 
 def remove(
@@ -87,7 +83,7 @@ def remove(
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
-    alpha_matting_erode_structure_size: int = 10,
+    alpha_matting_erode_size: int = 10,
     alpha_matting_base_size: int = 1000,
     session: Optional[ort.InferenceSession] = None,
     width: Optional[int] = None,
@@ -109,7 +105,7 @@ def remove(
                 mask,
                 alpha_matting_foreground_threshold,
                 alpha_matting_background_threshold,
-                alpha_matting_erode_structure_size,
+                alpha_matting_erode_size,
                 alpha_matting_base_size,
             )
         except Exception:

+ 268 - 140
rembg/cli.py

@@ -1,156 +1,284 @@
-import argparse
-import glob
-import os
-from distutils.util import strtobool
-from typing import BinaryIO
+import pathlib
 import sys
-from pathlib import Path
+from enum import Enum
+from typing import IO, Optional
 
+import click
 import filetype
-from tqdm import tqdm
 import onnxruntime as ort
+import requests
+import uvicorn
+from fastapi import Depends, FastAPI, File, Query
+from starlette.responses import Response
+from tqdm import tqdm
 
 from .bg import remove
 from .detect import ort_session
 
-sessions: dict[str, ort.InferenceSession] = {}
-
 
[email protected]()
[email protected]_option()
 def main():
-    ap = argparse.ArgumentParser()
-
-    ap.add_argument(
-        "-m",
-        "--model",
-        default="u2net",
-        type=str,
-        choices=["u2net", "u2netp", "u2net_human_seg"],
-        help="The model name.",
-    )
-
-    ap.add_argument(
-        "-a",
-        "--alpha-matting",
-        nargs="?",
-        const=True,
-        default=False,
-        type=lambda x: bool(strtobool(x)),
-        help="When true use alpha matting cutout.",
-    )
-
-    ap.add_argument(
-        "-af",
-        "--alpha-matting-foreground-threshold",
-        default=240,
-        type=int,
-        help="The trimap foreground threshold.",
-    )
-
-    ap.add_argument(
-        "-ab",
-        "--alpha-matting-background-threshold",
-        default=10,
-        type=int,
-        help="The trimap background threshold.",
-    )
-
-    ap.add_argument(
-        "-ae",
-        "--alpha-matting-erode-size",
-        default=10,
-        type=int,
-        help="Size of element used for the erosion.",
-    )
-
-    ap.add_argument(
-        "-az",
-        "--alpha-matting-base-size",
-        default=1000,
-        type=int,
-        help="The image base size.",
-    )
-
-    ap.add_argument(
-        "-p",
-        "--path",
-        nargs=2,
-        help="An input folder and an output folder.",
-    )
-
-    ap.add_argument(
-        "input",
-        nargs=(None if sys.stdin.isatty() else "?"),
-        default=(None if sys.stdin.isatty() else sys.stdin.buffer),
-        type=argparse.FileType("rb"),
-        help="Path to the input image.",
-    )
-
-    ap.add_argument(
-        "output",
-        nargs=(None if sys.stdin.isatty() else "?"),
-        default=(None if sys.stdin.isatty() else sys.stdout.buffer),
-        type=argparse.FileType("wb"),
-        help="Path to the output png image.",
-    )
-
-    args = ap.parse_args()
-    session = sessions.setdefault(args.model, ort_session(args.model))
-
-    if args.path:
-        full_paths = [os.path.abspath(path) for path in args.path]
-
-        input_paths = [full_paths[0]]
-        output_path = full_paths[1]
-
-        if not os.path.exists(output_path):
-            os.makedirs(output_path)
-
-        input_files = set()
-
-        for input_path in input_paths:
-            if os.path.isfile(path):
-                input_files.add(path)
-            else:
-                input_paths += set(glob.glob(input_path + "/*"))
-
-        for input_file in tqdm(input_files):
-            input_file_type = filetype.guess(input_file)
-
-            if input_file_type is None:
-                continue
-
-            if input_file_type.mime.find("image") < 0:
-                continue
-
-            out_file = os.path.join(
-                output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png"
-            )
-
-            Path(out_file).write_bytes(
-                remove(
-                    Path(input_file).read_bytes(),
-                    session=session,
-                    alpha_matting=args.alpha_matting,
-                    alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
-                    alpha_matting_background_threshold=args.alpha_matting_background_threshold,
-                    alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
-                    alpha_matting_base_size=args.alpha_matting_base_size,
-                )
-            )
-
-    else:
-        args.output.write(
+    pass
+
+
[email protected](help="for a file as input")
[email protected](
+    "-m",
+    "--model",
+    default="u2net",
+    type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
+    show_default=True,
+    show_choices=True,
+    help="model name",
+)
[email protected](
+    "-a",
+    "--alpha-matting",
+    is_flag=True,
+    show_default=True,
+    help="use alpha matting",
+)
[email protected](
+    "-af",
+    "--alpha-matting-foreground-threshold",
+    default=240,
+    type=int,
+    show_default=True,
+    help="trimap fg threshold",
+)
[email protected](
+    "-ab",
+    "--alpha-matting-background-threshold",
+    default=10,
+    type=int,
+    show_default=True,
+    help="trimap bg threshold",
+)
[email protected](
+    "-ae",
+    "--alpha-matting-erode-size",
+    default=10,
+    type=int,
+    show_default=True,
+    help="erode size",
+)
[email protected](
+    "-az",
+    "--alpha-matting-base-size",
+    default=1000,
+    type=int,
+    show_default=True,
+    help="image base size",
+)
[email protected](
+    "-w",
+    "--width",
+    default=None,
+    type=int,
+    show_default=True,
+    help="output image size",
+)
[email protected](
+    "-h",
+    "--height",
+    default=None,
+    type=int,
+    show_default=True,
+    help="output image size",
+)
[email protected](
+    "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
+)
[email protected](
+    "output",
+    default=(None if sys.stdin.isatty() else "-"),
+    type=click.File("wb", lazy=True),
+)
+def i(model: str, input: IO, output: IO, **kwargs: dict):
+    output.write(remove(input.read(), session=ort_session(model), **kwargs))
+
+
[email protected](help="for a folder as input")
[email protected](
+    "-m",
+    "--model",
+    default="u2net",
+    type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
+    show_default=True,
+    show_choices=True,
+    help="model name",
+)
[email protected](
+    "-a",
+    "--alpha-matting",
+    is_flag=True,
+    show_default=True,
+    help="use alpha matting",
+)
[email protected](
+    "-af",
+    "--alpha-matting-foreground-threshold",
+    default=240,
+    type=int,
+    show_default=True,
+    help="trimap fg threshold",
+)
[email protected](
+    "-ab",
+    "--alpha-matting-background-threshold",
+    default=10,
+    type=int,
+    show_default=True,
+    help="trimap bg threshold",
+)
[email protected](
+    "-ae",
+    "--alpha-matting-erode-size",
+    default=10,
+    type=int,
+    show_default=True,
+    help="erode size",
+)
[email protected](
+    "-az",
+    "--alpha-matting-base-size",
+    default=1000,
+    type=int,
+    show_default=True,
+    help="image base size",
+)
[email protected](
+    "-w",
+    "--width",
+    default=None,
+    type=int,
+    show_default=True,
+    help="output image size",
+)
[email protected](
+    "-h",
+    "--height",
+    default=None,
+    type=int,
+    show_default=True,
+    help="output image size",
+)
[email protected](
+    "input",
+    type=click.Path(
+        exists=True,
+        path_type=pathlib.Path,
+        file_okay=False,
+        dir_okay=True,
+        readable=True,
+    ),
+)
[email protected](
+    "output",
+    type=click.Path(
+        exists=False,
+        path_type=pathlib.Path,
+        file_okay=False,
+        dir_okay=True,
+        writable=True,
+    ),
+)
+def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs: dict):
+    session = ort_session(model)
+    for each_input in tqdm(list(input.glob("**/*"))):
+        if each_input.is_dir():
+            continue
+
+        mimetype = filetype.guess(each_input)
+        if mimetype is None:
+            continue
+        if mimetype.mime.find("image") < 0:
+            continue
+
+        each_output = (output / each_input.name).with_suffix(".png")
+        each_output.parents[0].mkdir(parents=True, exist_ok=True)
+
+        each_output.write_bytes(
+            remove(each_input.read_bytes(), session=session, **kwargs)
+        )
+
+
[email protected](help="for a http server")
[email protected](
+    "-p",
+    "--port",
+    default=5000,
+    type=int,
+    show_default=True,
+    help="port",
+)
[email protected](
+    "-l",
+    "--log_level",
+    default="info",
+    type=str,
+    show_default=True,
+    help="log level",
+)
+def s(port: int, log_level: str):
+    sessions: dict[str, ort.InferenceSession] = {}
+    app = FastAPI()
+
+    class ModelType(str, Enum):
+        u2net = "u2net"
+        u2netp = "u2netp"
+        u2net_human_seg = "u2net_human_seg"
+
+    class CommonQueryParams:
+        def __init__(
+            self,
+            model: ModelType = Query(ModelType.u2net),
+            a: bool = Query(False),
+            af: int = Query(240, ge=0),
+            ab: int = Query(10, ge=0),
+            ae: int = Query(10, ge=0),
+            az: int = Query(1000, ge=0),
+            width: Optional[int] = Query(None, gt=0),
+            height: Optional[int] = Query(None, gt=0),
+        ):
+            self.model = model
+            self.width = width
+            self.height = height
+            self.a = a
+            self.af = af
+            self.ab = ab
+            self.ae = ae
+            self.az = az
+
+    def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
+        return Response(
             remove(
-                args.input.read(),
-                session=session,
-                alpha_matting=args.alpha_matting,
-                alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
-                alpha_matting_background_threshold=args.alpha_matting_background_threshold,
-                alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
-                alpha_matting_base_size=args.alpha_matting_base_size,
-            )
+                content,
+                session=sessions.setdefault(
+                    commons.model.value, ort_session(commons.model.value)
+                ),
+                width=commons.width,
+                height=commons.height,
+                alpha_matting=commons.a,
+                alpha_matting_foreground_threshold=commons.af,
+                alpha_matting_background_threshold=commons.ab,
+                alpha_matting_erode_size=commons.ae,
+                alpha_matting_base_size=commons.az,
+            ),
+            media_type="image/png",
         )
 
+    @app.get("/")
+    def get_index(url: str, commons: CommonQueryParams = Depends()):
+        return im_without_bg(requests.get(url).content, commons)
+
+    @app.post("/")
+    def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
+        return im_without_bg(file, commons)
+
+    uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
+
 
 if __name__ == "__main__":
     main()

+ 0 - 110
rembg/server.py

@@ -1,110 +0,0 @@
-import argparse
-from enum import Enum
-from typing import Optional
-
-import requests
-import uvicorn
-from fastapi import Depends, FastAPI, File, Form, Query, UploadFile
-from PIL import Image
-from starlette.responses import Response
-import onnxruntime as ort
-
-from .bg import remove
-from .detect import ort_session
-
-sessions: dict[str, ort.InferenceSession] = {}
-app = FastAPI()
-
-
-class ModelType(str, Enum):
-    u2net = "u2net"
-    u2netp = "u2netp"
-    u2net_human_seg = "u2net_human_seg"
-
-
-class CommonQueryParams:
-    def __init__(
-        self,
-        model: ModelType = Query(ModelType.u2net),
-        a: bool = Query(False),
-        af: int = Query(240, ge=0),
-        ab: int = Query(10, ge=0),
-        ae: int = Query(10, ge=0),
-        az: int = Query(1000, ge=0),
-        width: Optional[int] = Query(None, gt=0),
-        height: Optional[int] = Query(None, gt=0),
-    ):
-        self.model = model
-        self.width = width
-        self.height = height
-        self.a = a
-        self.af = af
-        self.ab = ab
-        self.ae = ae
-        self.az = az
-
-
-def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
-    return Response(
-        remove(
-            content,
-            session=sessions.setdefault(
-                commons.model.value, ort_session(commons.model.value)
-            ),
-            width=commons.width,
-            height=commons.height,
-            alpha_matting=commons.a,
-            alpha_matting_foreground_threshold=commons.af,
-            alpha_matting_background_threshold=commons.ab,
-            alpha_matting_erode_structure_size=commons.ae,
-            alpha_matting_base_size=commons.az,
-        ),
-        media_type="image/png",
-    )
-
-
[email protected]("/")
-def get_index(url: str, commons: CommonQueryParams = Depends()):
-    return im_without_bg(requests.get(url).content, commons)
-
-
[email protected]("/")
-def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
-    return im_without_bg(file, commons)
-
-
-def main():
-    ap = argparse.ArgumentParser()
-
-    ap.add_argument(
-        "-a",
-        "--addr",
-        default="0.0.0.0",
-        type=str,
-        help="The IP address to bind to.",
-    )
-
-    ap.add_argument(
-        "-p",
-        "--port",
-        default=5000,
-        type=int,
-        help="The port to bind to.",
-    )
-
-    ap.add_argument(
-        "-l",
-        "--log_level",
-        default="info",
-        type=str,
-        help="The log level.",
-    )
-
-    args = ap.parse_args()
-    uvicorn.run(
-        "rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level
-    )
-
-
-if __name__ == "__main__":
-    main()

+ 1 - 0
requirements.txt

@@ -10,3 +10,4 @@ scikit-image==0.19.1
 scipy==1.7.3
 tqdm==4.62.3
 uvicorn==0.17.0
+click==8.0.3

+ 0 - 1
setup.py

@@ -40,7 +40,6 @@ setup(
     entry_points={
         "console_scripts": [
             "rembg=rembg.cli:main",
-            "rembg-server=rembg.server:main",
         ],
     },
     version=versioneer.get_version(),