Browse Source

add silueta model

Daniel Gatis 2 years ago
parent
commit
23189bf6c8
3 changed files with 14 additions and 2 deletions
  1. 1 0
      README.md
  2. 7 2
      rembg/cli.py
  3. 6 0
      rembg/session_factory.py

+ 1 - 0
README.md

@@ -191,6 +191,7 @@ The available models are:
 -   u2netp ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A lightweight version of u2net model.
 -   u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation.
 -   u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for Cloths Parsing from human portrait. Here clothes are parsed into 3 category: Upper body, Lower body and Full body.
+-   silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb.
 
 #### How to train your own model
 

+ 7 - 2
rembg/cli.py

@@ -33,7 +33,9 @@ def main() -> None:
     "-m",
     "--model",
     default="u2net",
-    type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
+    type=click.Choice(
+        ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
+    ),
     show_default=True,
     show_choices=True,
     help="model name",
@@ -100,7 +102,9 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
     "-m",
     "--model",
     default="u2net",
-    type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
+    type=click.Choice(
+        ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
+    ),
     show_default=True,
     show_choices=True,
     help="model name",
@@ -306,6 +310,7 @@ def s(port: int, log_level: str, threads: int) -> None:
         u2netp = "u2netp"
         u2net_human_seg = "u2net_human_seg"
         u2net_cloth_seg = "u2net_cloth_seg"
+        silueta = "silueta"
 
     class CommonQueryParams:
         def __init__(

+ 6 - 0
rembg/session_factory.py

@@ -34,6 +34,12 @@ def new_session(model_name: str) -> BaseSession:
         md5 = "2434d1f3cb744e0e49386c906e5a08bb"
         url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
         session_class = ClothSession
+    elif model_name == "silueta":
+        md5 = "55e59e0d8062d2f5d013f4725ee84782"
+        url = (
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
+        )
+        session_class = SimpleSession
     else:
         assert AssertionError(
             "Choose between u2net, u2netp, u2net_human_seg or u2net_cloth_seg"