Daniel Gatis 2 years ago
parent
revision
a54000d507

+ 33 - 0
README.md

@@ -265,6 +265,39 @@ The available models are:
 -   u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation.
 -   u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for cloth parsing from human portraits. Clothes are parsed into three categories: upper body, lower body and full body.
 -   silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb.
+-   isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A pre-trained model for general use cases.
+
+### Some differences between the models' results
+
+<table>
+    <tr>
+        <th>original</th>
+        <th>u2net</th>
+        <th>u2netp</th>
+        <th>u2net_human_seg</th>
+        <th>u2net_cloth_seg</th>
+        <th>silueta</th>
+        <th>isnet-general-use</th>
+    </tr>
+    <tr>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/car-1.jpg" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2netp.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net_human_seg.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.u2net_cloth_seg.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.silueta.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/car-1.isnet-general-use.png" width="100" /></td>
+    </tr>
+    <tr>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/fixtures/cloth-1.jpg" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2netp.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net_human_seg.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.u2net_cloth_seg.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.silueta.png" width="100" /></td>
+        <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/tests/results/cloth-1.isnet-general-use.png" width="100" /></td>
+    </tr>
+</table>
+
 
 ### How to train your own model
 

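For reference, a minimal sketch of selecting the newly added model through the Python API, using the same `remove` and `new_session` entry points exercised by the updated tests below (file names are placeholders):

```python
from pathlib import Path

from rembg import new_session, remove

# Create a session for the newly added DIS model and reuse it across calls.
session = new_session("isnet-general-use")

input_bytes = Path("car-1.jpg").read_bytes()
output_bytes = remove(input_bytes, session=session)

Path("car-1.out.png").write_bytes(output_bytes)
```
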
+ 3 - 2
rembg/cli.py

@@ -34,7 +34,7 @@ def main() -> None:
     "--model",
     default="u2net",
     type=click.Choice(
-        ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
+        ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
     ),
     show_default=True,
     show_choices=True,
@@ -103,7 +103,7 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
     "--model",
     default="u2net",
     type=click.Choice(
-        ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
+        ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]
     ),
     show_default=True,
     show_choices=True,
@@ -311,6 +311,7 @@ def s(port: int, log_level: str, threads: int) -> None:
         u2net_human_seg = "u2net_human_seg"
         u2net_cloth_seg = "u2net_cloth_seg"
         silueta = "silueta"
+        isnet_general_use = "isnet-general-use"
 
     class CommonQueryParams:
         def __init__(

+ 3 - 3
rembg/session_base.py

@@ -7,18 +7,18 @@ from PIL.Image import Image as PILImage
 
 
 class BaseSession:
-    def __init__(self, model_name: str, inner_session: ort.InferenceSession, output_size: Tuple[int, int] = (320, 320)):
+    def __init__(self, model_name: str, inner_session: ort.InferenceSession):
         self.model_name = model_name
         self.inner_session = inner_session
-        self.output_size = output_size
 
     def normalize(
         self,
         img: PILImage,
         mean: Tuple[float, float, float],
         std: Tuple[float, float, float],
+        size: Tuple[int, int],
     ) -> Dict[str, np.ndarray]:
-        im = img.convert("RGB").resize(self.output_size, Image.LANCZOS)
+        im = img.convert("RGB").resize(size, Image.LANCZOS)
 
         im_ary = np.array(im)
         im_ary = im_ary / np.max(im_ary)

+ 1 - 1
rembg/session_cloth.py

@@ -56,7 +56,7 @@ pallete3 = [
 class ClothSession(BaseSession):
     def predict(self, img: PILImage) -> List[PILImage]:
         ort_outs = self.inner_session.run(
-            None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
+            None, self.normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768))
         )
 
         pred = ort_outs

+ 30 - 0
rembg/session_dis.py

@@ -0,0 +1,30 @@
+from typing import List
+
+import numpy as np
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .session_base import BaseSession
+
+
+class DisSession(BaseSession):
+    def predict(self, img: PILImage) -> List[PILImage]:
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(
+                img, (0.485, 0.456, 0.406), (1., 1., 1.), (1024, 1024)
+            ),
+        )
+
+        pred = ort_outs[0][:, 0, :, :]
+
+        ma = np.max(pred)
+        mi = np.min(pred)
+
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
+
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.LANCZOS)
+
+        return [mask]

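The new `DisSession` rescales the raw prediction to [0, 1] by its min and max, converts it to an 8-bit grayscale mask, and resizes it back to the input size. As a side note, a minimal sketch of how such a mask could be composited onto the original image with plain PIL; the actual compositing is handled elsewhere in rembg, so this is only illustrative:

```python
from PIL import Image

from rembg import new_session


def cutout(path: str) -> Image.Image:
    """Apply the predicted mask as the alpha channel of the original image."""
    img = Image.open(path).convert("RGB")
    session = new_session("isnet-general-use")
    mask = session.predict(img)[0]  # DisSession returns a single "L"-mode mask
    out = img.convert("RGBA")
    out.putalpha(mask)              # mask is already resized to img.size
    return out
```
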
+ 5 - 8
rembg/session_factory.py

@@ -11,12 +11,10 @@ import pooch
 from .session_base import BaseSession
 from .session_cloth import ClothSession
 from .session_simple import SimpleSession
+from .session_dis import DisSession
 
 
-def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
-    # Set output size if not set ( because isnet hat a different size )
-    output_size = output_size or (320, 320)
-
+def new_session(model_name: str = "u2net") -> BaseSession:
     session_class: Type[BaseSession]
     md5 = "60024c5c889badc19c04ad937298a77b"
     url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
@@ -44,8 +42,8 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
         session_class = SimpleSession
     elif model_name == "isnet-general-use":
         md5 = "fc16ebd8b0c10d971d3513d564d01e29"
-        url = "https://github.com/Flippchen/rembg/releases/download/test/isnet-general-use.onnx"
-        session_class = SimpleSession
+        url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
+        session_class = DisSession
 
     u2net_home = os.getenv(
         "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
@@ -74,6 +72,5 @@ def new_session(model_name: str = "u2net", output_size=None) -> BaseSession:
             str(full_path),
             providers=ort.get_available_providers(),
             sess_options=sess_opts,
-        ),
-        output_size=output_size
+        )
     )

+ 2 - 9
rembg/session_simple.py

@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List
 
 import numpy as np
 from PIL import Image
@@ -9,17 +9,10 @@ from .session_base import BaseSession
 
 class SimpleSession(BaseSession):
     def predict(self, img: PILImage) -> List[PILImage]:
-        if self.model_name == "isnet-general-use":
-            mean = (0.5, 0.5, 0.5)
-            std = (1., 1., 1.)
-        else:
-            mean = (0.485, 0.456, 0.406)
-            std = (0.229, 0.224, 0.225)
-
         ort_outs = self.inner_session.run(
             None,
             self.normalize(
-                img, mean, std
+                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
             ),
         )
 

binary
tests/fixtures/car-1.jpg


binary
tests/fixtures/cloth-1.jpg


binary
tests/results/car-1.isnet-general-use.png


binary
tests/results/car-1.silueta.png


binary
tests/results/car-1.u2net.png


binary
tests/results/car-1.u2net_cloth_seg.png


binary
tests/results/car-1.u2net_human_seg.png


binary
tests/results/car-1.u2netp.png


binary
tests/results/cloth-1.isnet-general-use.png


binary
tests/results/cloth-1.silueta.png


binary
tests/results/cloth-1.u2net.png


binary
tests/results/cloth-1.u2net_cloth_seg.png


binary
tests/results/cloth-1.u2net_human_seg.png


binary
tests/results/cloth-1.u2netp.png


+ 26 - 8
tests/test_remove.py

@@ -1,20 +1,38 @@
 from io import BytesIO
 from pathlib import Path
 
-from imagehash import average_hash
+from imagehash import phash as hash_img
 from PIL import Image
 
 from rembg import remove
+from rembg import new_session
 
 here = Path(__file__).parent.resolve()
 
-
 def test_remove():
-    image = Path(here / ".." / "examples" / "animal-1.jpg").read_bytes()
-    expected = Path(here / ".." / "examples" / "animal-1.out.png").read_bytes()
-    actual = remove(image)
+    for model in ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]:
+        for picture in ["car-1", "cloth-1"]:
+            image_path = Path(here / "fixtures" / f"{picture}.jpg")
+            expected_path = Path(here / "results" / f"{picture}.{model}.png")
+
+            image = image_path.read_bytes()
+            expected = expected_path.read_bytes()
+
+            actual = remove(image, session=new_session(model))
+
+            # Uncomment to update the expected results
+            # f = open(expected_path, "wb")
+            # f.write(actual)
+            # f.close()
+
+            actual_hash = hash_img(Image.open(BytesIO(actual)))
+            expected_hash = hash_img(Image.open(BytesIO(expected)))
 
-    actual_hash = average_hash(Image.open(BytesIO(actual)))
-    expected_hash = average_hash(Image.open(BytesIO(expected)))
+            print(f"image_path: {image_path}")
+            print(f"expected_path: {expected_path}")
+            print(f"actual_hash: {actual_hash}")
+            print(f"expected_hash: {expected_hash}")
+            print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
+            print("---\n")
 
-    assert actual_hash == expected_hash
+            assert actual_hash == expected_hash
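
As a possible follow-up (not part of this commit), the nested loops could be expressed with `pytest.mark.parametrize`, so each model/picture combination reports as its own test case. A sketch, assuming pytest is the test runner:

```python
from io import BytesIO
from pathlib import Path

import pytest
from imagehash import phash as hash_img
from PIL import Image

from rembg import new_session, remove

here = Path(__file__).parent.resolve()

MODELS = ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use"]


@pytest.mark.parametrize("picture", ["car-1", "cloth-1"])
@pytest.mark.parametrize("model", MODELS)
def test_remove(model: str, picture: str) -> None:
    image = (here / "fixtures" / f"{picture}.jpg").read_bytes()
    expected = (here / "results" / f"{picture}.{model}.png").read_bytes()

    actual = remove(image, session=new_session(model))

    # Perceptual hashes tolerate small pixel-level differences across runtimes.
    assert hash_img(Image.open(BytesIO(actual))) == hash_img(Image.open(BytesIO(expected)))
```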