瀏覽代碼

Add a custom u2net session (#482)

Daniel Gatis 2 年之前
父節點
當前提交
c0b08f831b

+ 4 - 0
README.md

@@ -164,6 +164,10 @@ Passing extras parameters
 rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
 ```
 
+```
+rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
+```
+
 ### rembg `p`
 
 Used when input and output are folders.

+ 1 - 1
rembg/commands/b_command.py

@@ -107,7 +107,7 @@ def rs_command(
     except Exception:
         pass
 
-    session = new_session(model)
+    session = new_session(model, **kwargs)
     bytes_per_img = image_width * image_height * 3
 
     if output_specifier:

+ 1 - 1
rembg/commands/i_command.py

@@ -90,4 +90,4 @@ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
     except Exception:
         pass
 
-    output.write(remove(input.read(), session=new_session(model), **kwargs))
+    output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))

+ 1 - 1
rembg/commands/p_command.py

@@ -122,7 +122,7 @@ def p_command(
     except Exception:
         pass
 
-    session = new_session(model)
+    session = new_session(model, **kwargs)
 
     def process(each_input: pathlib.Path) -> None:
         try:

+ 13 - 16
rembg/commands/s_command.py

@@ -186,7 +186,9 @@ def s_command(port: int, log_level: str, threads: int) -> None:
         return Response(
             remove(
                 content,
-                session=sessions.setdefault(commons.model, new_session(commons.model)),
+                session=sessions.setdefault(
+                    commons.model, new_session(commons.model, **kwargs)
+                ),
                 alpha_matting=commons.a,
                 alpha_matting_foreground_threshold=commons.af,
                 alpha_matting_background_threshold=commons.ab,
@@ -245,12 +247,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
         return await asyncify(im_without_bg)(file, commons)  # type: ignore
 
     def gr_app(app):
-        def inference(input_path, model):
+        def inference(input_path, model, cmd_args):
             output_path = "output.png"
+
+            kwargs = {}
+            if cmd_args:
+                kwargs.update(json.loads(cmd_args))
+            kwargs["session"] = new_session(model, **kwargs)
+
             with open(input_path, "rb") as i:
                 with open(output_path, "wb") as o:
                     input = i.read()
-                    output = remove(input, session=new_session(model))
+                    output = remove(input, **kwargs)
                     o.write(output)
             return os.path.join(output_path)
 
@@ -258,19 +266,8 @@ def s_command(port: int, log_level: str, threads: int) -> None:
             inference,
             [
                 gr.components.Image(type="filepath", label="Input"),
-                gr.components.Dropdown(
-                    [
-                        "u2net",
-                        "u2netp",
-                        "u2net_human_seg",
-                        "u2net_cloth_seg",
-                        "silueta",
-                        "isnet-general-use",
-                        "isnet-anime",
-                    ],
-                    value="u2net",
-                    label="Models",
-                ),
+                gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
+                gr.components.Textbox(label="Arguments"),
             ],
             gr.components.Image(type="filepath", label="Output"),
         )

+ 1 - 1
rembg/sessions/base.py

@@ -29,7 +29,7 @@ class BaseSession:
             self.providers.extend(_providers)
 
         self.inner_session = ort.InferenceSession(
-            str(self.__class__.download_models()),
+            str(self.__class__.download_models(*args, **kwargs)),
             providers=self.providers,
             sess_options=sess_opts,
         )

+ 2 - 2
rembg/sessions/dis_anime.py

@@ -31,7 +31,7 @@ class DisSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname = f"{cls.name()}.onnx"
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
             None
@@ -42,7 +42,7 @@ class DisSession(BaseSession):
             progressbar=True,
         )
 
-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
 
     @classmethod
     def name(cls, *args, **kwargs):

+ 2 - 2
rembg/sessions/dis_general_use.py

@@ -31,7 +31,7 @@ class DisSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname = f"{cls.name()}.onnx"
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
             None
@@ -42,7 +42,7 @@ class DisSession(BaseSession):
             progressbar=True,
         )
 
-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
 
     @classmethod
     def name(cls, *args, **kwargs):

+ 4 - 4
rembg/sessions/sam.py

@@ -136,8 +136,8 @@ class SamSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname_encoder = f"{cls.name()}_encoder.onnx"
-        fname_decoder = f"{cls.name()}_decoder.onnx"
+        fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
+        fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
 
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
@@ -160,8 +160,8 @@ class SamSession(BaseSession):
         )
 
         return (
-            os.path.join(cls.u2net_home(), fname_encoder),
-            os.path.join(cls.u2net_home(), fname_decoder),
+            os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
+            os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
         )
 
     @classmethod

+ 1 - 1
rembg/sessions/silueta.py

@@ -44,7 +44,7 @@ class SiluetaSession(BaseSession):
             progressbar=True,
         )
 
-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
 
     @classmethod
     def name(cls, *args, **kwargs):

+ 2 - 2
rembg/sessions/u2net.py

@@ -33,7 +33,7 @@ class U2netSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname = f"{cls.name()}.onnx"
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
             None
@@ -44,7 +44,7 @@ class U2netSession(BaseSession):
             progressbar=True,
         )
 
-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
 
     @classmethod
     def name(cls, *args, **kwargs):

+ 2 - 2
rembg/sessions/u2net_cloth_seg.py

@@ -94,7 +94,7 @@ class Unet2ClothSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname = f"{cls.name()}.onnx"
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
             None
@@ -105,7 +105,7 @@ class Unet2ClothSession(BaseSession):
             progressbar=True,
         )
 
-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
 
     @classmethod
     def name(cls, *args, **kwargs):

+ 45 - 0
rembg/sessions/u2net_custom.py

@@ -0,0 +1,45 @@
+import os
+from typing import List
+
+import numpy as np
+import pooch
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .base import BaseSession
+
+
+class U2netCustomSession(BaseSession):
+    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(
+                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
+            ),
+        )
+
+        pred = ort_outs[0][:, 0, :, :]
+
+        ma = np.max(pred)
+        mi = np.min(pred)
+
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
+
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.LANCZOS)
+
+        return [mask]
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        model_path = kwargs.get("model_path")
+
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        return os.path.abspath(os.path.expanduser(model_path))
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        return "u2net_custom"

+ 2 - 2
rembg/sessions/u2net_human_seg.py

@@ -33,7 +33,7 @@ class U2netHumanSegSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname = f"{cls.name()}.onnx"
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
             None
@@ -44,7 +44,7 @@ class U2netHumanSegSession(BaseSession):
             progressbar=True,
         )
 
-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
 
     @classmethod
     def name(cls, *args, **kwargs):

+ 2 - 2
rembg/sessions/u2netp.py

@@ -33,7 +33,7 @@ class U2netpSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname = f"{cls.name()}.onnx"
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
             None
@@ -44,7 +44,7 @@ class U2netpSession(BaseSession):
             progressbar=True,
         )
 
-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
 
     @classmethod
     def name(cls, *args, **kwargs):