Bläddra i källkod

add mask option and better gpu install

Daniel Gatis 3 år sedan
förälder
incheckning
722e23cc8d
7 ändrade filer med 26 tillägg och 82 borttagningar
  1. 1 1
      Dockerfile
  2. 1 1
      README.md
  3. 7 22
      rembg/bg.py
  4. 11 51
      rembg/cli.py
  5. 0 1
      requirements-cpu.txt
  6. 1 0
      requirements.txt
  7. 5 6
      setup.py

+ 1 - 1
Dockerfile

@@ -11,7 +11,7 @@ WORKDIR /rembg
 
 
 COPY . .
 COPY . .
 
 
-RUN GPU=1 pip3 install .
+RUN ["pip3", "install", ".[gpu]"]
 
 
 ENTRYPOINT ["rembg"]
 ENTRYPOINT ["rembg"]
 CMD []
 CMD []

+ 1 - 1
README.md

@@ -45,7 +45,7 @@ pip install rembg
 
 
 GPU support:
 GPU support:
 ```bash
 ```bash
-GPU=1 pip install rembg
+pip install rembg[gpu]
 ```
 ```
 
 
 ### Usage as a cli
 ### Usage as a cli

+ 7 - 22
rembg/bg.py

@@ -18,13 +18,7 @@ def alpha_matting_cutout(
     foreground_threshold: int,
     foreground_threshold: int,
     background_threshold: int,
     background_threshold: int,
     erode_structure_size: int,
     erode_structure_size: int,
-    base_size: int,
 ) -> Image:
 ) -> Image:
-    size = img.size
-
-    img.thumbnail((base_size, base_size), Image.LANCZOS)
-    mask = mask.resize(img.size, Image.LANCZOS)
-
     img = np.asarray(img)
     img = np.asarray(img)
     mask = np.asarray(mask)
     mask = np.asarray(mask)
 
 
@@ -60,45 +54,37 @@ def alpha_matting_cutout(
 
 
     cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
     cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
     cutout = Image.fromarray(cutout)
     cutout = Image.fromarray(cutout)
-    cutout = cutout.resize(size, Image.LANCZOS)
 
 
     return cutout
     return cutout
 
 
 
 
 def naive_cutout(img: Image, mask: Image) -> Image:
 def naive_cutout(img: Image, mask: Image) -> Image:
     empty = Image.new("RGBA", (img.size), 0)
     empty = Image.new("RGBA", (img.size), 0)
-    cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
+    cutout = Image.composite(img, empty, mask)
     return cutout
     return cutout
 
 
 
 
-def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Image:
-    original_width, original_height = img.size
-    width = original_width if width is None else width
-    height = original_height if height is None else height
-    return img.resize((width, height))
-
-
 def remove(
 def remove(
     data: bytes,
     data: bytes,
     alpha_matting: bool = False,
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_size: int = 10,
     alpha_matting_erode_size: int = 10,
-    alpha_matting_base_size: int = 1000,
     session: Optional[ort.InferenceSession] = None,
     session: Optional[ort.InferenceSession] = None,
-    width: Optional[int] = None,
-    height: Optional[int] = None,
+    only_mask: bool = False,
 ) -> bytes:
 ) -> bytes:
     img = Image.open(io.BytesIO(data)).convert("RGB")
     img = Image.open(io.BytesIO(data)).convert("RGB")
-    if width is not None or height is not None:
-        img = resize_image(img, width, height)
 
 
     if session is None:
     if session is None:
         session = ort_session("u2net")
         session = ort_session("u2net")
 
 
     mask = predict(session, np.array(img)).convert("L")
     mask = predict(session, np.array(img)).convert("L")
+    mask = mask.resize(img.size, Image.LANCZOS)
+
+    if only_mask:
+        cutout = mask
 
 
-    if alpha_matting:
+    elif alpha_matting:
         try:
         try:
             cutout = alpha_matting_cutout(
             cutout = alpha_matting_cutout(
                 img,
                 img,
@@ -106,7 +92,6 @@ def remove(
                 alpha_matting_foreground_threshold,
                 alpha_matting_foreground_threshold,
                 alpha_matting_background_threshold,
                 alpha_matting_background_threshold,
                 alpha_matting_erode_size,
                 alpha_matting_erode_size,
-                alpha_matting_base_size,
             )
             )
         except Exception:
         except Exception:
             cutout = naive_cutout(img, mask)
             cutout = naive_cutout(img, mask)

+ 11 - 51
rembg/cli.py

@@ -66,28 +66,11 @@ def main():
     help="erode size",
     help="erode size",
 )
 )
 @click.option(
 @click.option(
-    "-az",
-    "--alpha-matting-base-size",
-    default=1000,
-    type=int,
-    show_default=True,
-    help="image base size",
-)
[email protected](
-    "-w",
-    "--width",
-    default=None,
-    type=int,
-    show_default=True,
-    help="output image size",
-)
[email protected](
-    "-h",
-    "--height",
-    default=None,
-    type=int,
+    "-om",
+    "--only-mask",
+    is_flag=True,
     show_default=True,
     show_default=True,
-    help="output image size",
+    help="output only the mask",
 )
 )
 @click.argument(
 @click.argument(
     "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
     "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
@@ -143,28 +126,11 @@ def i(model: str, input: IO, output: IO, **kwargs):
     help="erode size",
     help="erode size",
 )
 )
 @click.option(
 @click.option(
-    "-az",
-    "--alpha-matting-base-size",
-    default=1000,
-    type=int,
-    show_default=True,
-    help="image base size",
-)
[email protected](
-    "-w",
-    "--width",
-    default=None,
-    type=int,
-    show_default=True,
-    help="output image size",
-)
[email protected](
-    "-h",
-    "--height",
-    default=None,
-    type=int,
+    "-om",
+    "--only-mask",
+    is_flag=True,
     show_default=True,
     show_default=True,
-    help="output image size",
+    help="output only the mask",
 )
 )
 @click.argument(
 @click.argument(
     "input",
     "input",
@@ -240,18 +206,14 @@ def s(port: int, log_level: str):
             af: int = Query(240, ge=0),
             af: int = Query(240, ge=0),
             ab: int = Query(10, ge=0),
             ab: int = Query(10, ge=0),
             ae: int = Query(10, ge=0),
             ae: int = Query(10, ge=0),
-            az: int = Query(1000, ge=0),
-            width: Optional[int] = Query(None, gt=0),
-            height: Optional[int] = Query(None, gt=0),
+            om: bool = Query(False),
         ):
         ):
             self.model = model
             self.model = model
-            self.width = width
-            self.height = height
             self.a = a
             self.a = a
             self.af = af
             self.af = af
             self.ab = ab
             self.ab = ab
             self.ae = ae
             self.ae = ae
-            self.az = az
+            self.om = om
 
 
     def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
     def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
         return Response(
         return Response(
@@ -260,13 +222,11 @@ def s(port: int, log_level: str):
                 session=sessions.setdefault(
                 session=sessions.setdefault(
                     commons.model.value, ort_session(commons.model.value)
                     commons.model.value, ort_session(commons.model.value)
                 ),
                 ),
-                width=commons.width,
-                height=commons.height,
                 alpha_matting=commons.a,
                 alpha_matting=commons.a,
                 alpha_matting_foreground_threshold=commons.af,
                 alpha_matting_foreground_threshold=commons.af,
                 alpha_matting_background_threshold=commons.ab,
                 alpha_matting_background_threshold=commons.ab,
                 alpha_matting_erode_size=commons.ae,
                 alpha_matting_erode_size=commons.ae,
-                alpha_matting_base_size=commons.az,
+                only_mask=commons.om,
             ),
             ),
             media_type="image/png",
             media_type="image/png",
         )
         )

+ 0 - 1
requirements-cpu.txt

@@ -1 +0,0 @@
-onnxruntime==1.10.0

+ 1 - 0
requirements.txt

@@ -5,6 +5,7 @@ fastapi==0.72.0
 filetype==1.0.9
 filetype==1.0.9
 gdown==4.2.0
 gdown==4.2.0
 numpy==1.21.5
 numpy==1.21.5
+onnxruntime==1.10.0
 pillow==9.0.0
 pillow==9.0.0
 pymatting==1.1.5
 pymatting==1.1.5
 python-multipart==0.0.5
 python-multipart==0.0.5

+ 5 - 6
setup.py

@@ -14,12 +14,8 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
 with open("requirements.txt") as f:
 with open("requirements.txt") as f:
     requireds = f.read().splitlines()
     requireds = f.read().splitlines()
 
 
-if os.getenv("GPU") is None:
-    with open("requirements-cpu.txt") as f:
-        requireds += f.read().splitlines()
-else:
-    with open("requirements-gpu.txt") as f:
-        requireds += f.read().splitlines()
+with open("requirements-gpu.txt") as f:
+    gpu_requireds = f.read().splitlines()
 
 
 setup(
 setup(
     name="rembg",
     name="rembg",
@@ -42,6 +38,9 @@ setup(
             "rembg=rembg.cli:main",
             "rembg=rembg.cli:main",
         ],
         ],
     },
     },
+    extras_require={
+        'gpu': gpu_requireds,
+    },
     version=versioneer.get_version(),
     version=versioneer.get_version(),
     cmdclass=versioneer.get_cmdclass(),
     cmdclass=versioneer.get_cmdclass(),
 )
 )