Browse Source

add more models

Daniel Gatis 3 years ago
parent
commit
738cf156af
10 changed files with 266 additions and 183 deletions
  1. Dockerfile (+1 -0)
  2. README.md (+10 -0)
  3. rembg/bg.py (+47 -25)
  4. rembg/cli.py (+10 -10)
  5. rembg/detect.py (+0 -147)
  6. rembg/session_base.py (+39 -0)
  7. rembg/session_cloth.py (+65 -0)
  8. rembg/session_factory.py (+63 -0)
  9. rembg/session_simple.py (+29 -0)
  10. requirements.txt (+2 -1)

+ 1 - 0
Dockerfile

@@ -18,6 +18,7 @@ RUN mkdir -p ~/.u2net
 RUN gdown https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR -O ~/.u2net/u2netp.onnx
 RUN gdown https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab -O ~/.u2net/u2net.onnx
 RUN gdown https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j -O ~/.u2net/u2net_human_seg.onnx
+RUN gdown https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz -O ~/.u2net/u2net_cloth_seg.onnx
 
 ENTRYPOINT ["rembg"]
 CMD ["--help"]

+ 10 - 0
README.md

@@ -138,6 +138,16 @@ Try this:
 cat in.png | docker run -i --rm danielgatis/rembg i > out.png
 ```
 
+### Models
+
+All models are downloaded and saved to the `.u2net` directory in the user's home folder.
+
+The available models are:
+
+- u2net ([download](https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for general use cases.
+- u2netp ([download](https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR), [source](https://github.com/xuebinqin/U-2-Net)): A lightweight version of the u2net model.
+- u2net_human_seg ([download](https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation.
+- u2net_cloth_seg ([download](https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for cloth parsing from human portraits. Clothes are parsed into three categories: upper body, lower body, and full body.
 ### Advance usage
 
 Sometimes it is possible to achieve better results by turning on alpha matting. Example:
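
For library users, a minimal sketch of driving the same model choice from Python, assuming the session API introduced in this commit (`new_session` from `rembg.session_factory`, `remove` from `rembg.bg`) and that `remove` mirrors its input type, returning encoded bytes for a bytes input:

```python
from rembg.bg import remove
from rembg.session_factory import new_session

# The model file is downloaded to ~/.u2net on first use.
session = new_session("u2net_cloth_seg")

with open("in.png", "rb") as f:
    data = f.read()

# For the cloth model, the result stacks the upper/lower/full-body
# cutouts vertically into a single image.
with open("out.png", "wb") as f:
    f.write(remove(data, session=session))
```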

+ 47 - 25
rembg/bg.py

@@ -1,9 +1,8 @@
 import io
 from enum import Enum
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import numpy as np
-import onnxruntime as ort
 from PIL import Image
 from PIL.Image import Image as PILImage
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
@@ -11,7 +10,8 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 from pymatting.util.util import stack_images
 from scipy.ndimage.morphology import binary_erosion
 
-from .detect import ort_session, predict
+from .session_base import BaseSession
+from .session_factory import new_session
 
 
 class ReturnType(Enum):
@@ -65,13 +65,27 @@ def naive_cutout(img: Image, mask: Image) -> Image:
     return cutout
 
 
+def get_concat_v_multi(imgs: List[Image]) -> Image:
+    pivot = imgs.pop(0)
+    for im in imgs:
+        pivot = get_concat_v(pivot, im)
+    return pivot
+
+
+def get_concat_v(img1: Image, img2: Image) -> Image:
+    dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
+    dst.paste(img1, (0, 0))
+    dst.paste(img2, (0, img1.height))
+    return dst
+
+
 def remove(
     data: Union[bytes, PILImage, np.ndarray],
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_size: int = 10,
-    session: Optional[ort.InferenceSession] = None,
+    session: Optional[BaseSession] = None,
     only_mask: bool = False,
 ) -> Union[bytes, PILImage, np.ndarray]:
 
@@ -88,27 +102,35 @@ def remove(
         raise ValueError("Input type {} is not supported.".format(type(data)))
 
     if session is None:
-        session = ort_session("u2net")
-
-    img = img.convert("RGB")
-
-    mask = predict(session, np.array(img))
-    mask = mask.convert("L")
-    mask = mask.resize(img.size, Image.LANCZOS)
-
-    if only_mask:
-        cutout = mask
-
-    elif alpha_matting:
-        cutout = alpha_matting_cutout(
-            img,
-            mask,
-            alpha_matting_foreground_threshold,
-            alpha_matting_background_threshold,
-            alpha_matting_erode_size,
-        )
-    else:
-        cutout = naive_cutout(img, mask)
+        session = new_session("u2net")
+
+    masks = session.predict(img)
+    cutouts = []
+
+    for mask in masks:
+        if only_mask:
+            cutout = mask
+
+        elif alpha_matting:
+            try:
+                cutout = alpha_matting_cutout(
+                    img,
+                    mask,
+                    alpha_matting_foreground_threshold,
+                    alpha_matting_background_threshold,
+                    alpha_matting_erode_size,
+                )
+            except Exception:
+                cutout = naive_cutout(img, mask)
+
+        else:
+            cutout = naive_cutout(img, mask)
+
+        cutouts.append(cutout)
+
+    cutout = img
+    if len(cutouts) > 0:
+        cutout = get_concat_v_multi(cutouts)
 
     if ReturnType.PILLOW == return_type:
         return cutout
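
The behavioral change worth calling out: `session.predict` now returns a list of masks (one for the simple models, three for the cloth model), and the per-mask cutouts are stacked vertically. A standalone sketch of the stacking helpers, with blank dummy images in place of real cutouts:

```python
from PIL import Image

def get_concat_v(img1: Image.Image, img2: Image.Image) -> Image.Image:
    # Place img2 directly below img1 on a shared RGBA canvas.
    dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
    dst.paste(img1, (0, 0))
    dst.paste(img2, (0, img1.height))
    return dst

# Three 100x50 cutouts (upper, lower, full body) -> one 100x150 strip.
cutouts = [Image.new("RGBA", (100, 50)) for _ in range(3)]
strip = cutouts[0]
for c in cutouts[1:]:
    strip = get_concat_v(strip, c)
print(strip.size)  # (100, 150)
```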

+ 10 - 10
rembg/cli.py

@@ -2,13 +2,11 @@ import pathlib
 import sys
 import time
 from enum import Enum
-from typing import IO, Optional, cast
+from typing import IO, cast
 
 import aiohttp
 import click
 import filetype
-import onnxruntime as ort
-import requests
 import uvicorn
 from asyncer import asyncify
 from fastapi import Depends, FastAPI, File, Query
@@ -19,7 +17,8 @@ from watchdog.observers import Observer
 
 from . import _version
 from .bg import remove
-from .detect import ort_session
+from .session_base import BaseSession
+from .session_factory import new_session
 
 
 @click.group()
@@ -33,7 +32,7 @@ def main() -> None:
     "-m",
     "--model",
     default="u2net",
-    type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
+    type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
     show_default=True,
     show_choices=True,
     help="model name",
@@ -85,7 +84,7 @@ def main() -> None:
     type=click.File("wb", lazy=True),
 )
 def i(model: str, input: IO, output: IO, **kwargs) -> None:
-    output.write(remove(input.read(), session=ort_session(model), **kwargs))
+    output.write(remove(input.read(), session=new_session(model), **kwargs))
 
 
 @main.command(help="for a folder as input")
@@ -93,7 +92,7 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
     "-m",
     "--model",
     default="u2net",
-    type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
+    type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
     show_default=True,
     show_choices=True,
     help="model name",
@@ -167,7 +166,7 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
 def p(
     model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
 ) -> None:
-    session = ort_session(model)
+    session = new_session(model)
 
     def process(each_input: pathlib.Path) -> None:
         try:
@@ -244,7 +243,7 @@ def p(
     help="log level",
 )
 def s(port: int, log_level: str) -> None:
-    sessions: dict[str, ort.InferenceSession] = {}
+    sessions: dict[str, BaseSession] = {}
     tags_metadata = [
         {
             "name": "Background Removal",
@@ -275,6 +274,7 @@ def s(port: int, log_level: str) -> None:
         u2net = "u2net"
         u2netp = "u2netp"
         u2net_human_seg = "u2net_human_seg"
+        u2net_cloth_seg = "u2net_cloth_seg"
 
     class CommonQueryParams:
         def __init__(
@@ -307,7 +307,7 @@ def s(port: int, log_level: str) -> None:
             remove(
                 content,
                 session=sessions.setdefault(
-                    commons.model.value, ort_session(commons.model.value)
+                    commons.model.value, new_session(commons.model.value)
                 ),
                 alpha_matting=commons.a,
                 alpha_matting_foreground_threshold=commons.af,
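
All three entry points (`i` for single files, `p` for folders, and the HTTP server `s`) now accept the new model name. A hedged CLI sketch, using stdin/stdout redirection as in the README's existing example (folder paths are placeholders):

```
rembg i -m u2net_cloth_seg < in.png > out.png
rembg p -m u2net_cloth_seg input-folder output-folder
```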

+ 0 - 147
rembg/detect.py

@@ -1,147 +0,0 @@
-import hashlib
-import os
-import sys
-from contextlib import redirect_stdout
-from pathlib import Path
-
-import gdown
-import numpy as np
-import onnxruntime as ort
-from PIL import Image
-from skimage import transform
-
-
-def ort_session(model_name: str) -> ort.InferenceSession:
-    if model_name == "u2netp":
-        md5 = "8e83ca70e441ab06c318d82300c84806"
-        url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
-    elif model_name == "u2net":
-        md5 = "60024c5c889badc19c04ad937298a77b"
-        url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
-    elif model_name == "u2net_human_seg":
-        md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
-        url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
-    else:
-        assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
-
-    home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
-    path = Path(home).expanduser() / f"{model_name}.onnx"
-    path.parents[0].mkdir(parents=True, exist_ok=True)
-
-    if not path.exists():
-        with redirect_stdout(sys.stderr):
-            gdown.download(url, str(path), use_cookies=False)
-    else:
-        hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
-        if hashing.hexdigest() != md5:
-            with redirect_stdout(sys.stderr):
-                gdown.download(url, str(path), use_cookies=False)
-
-    sess_opts = ort.SessionOptions()
-
-    if "OMP_NUM_THREADS" in os.environ:
-        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
-        sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
-
-    return ort.InferenceSession(
-        str(path), providers=ort.get_available_providers(), sess_options=sess_opts
-    )
-
-
-def norm_pred(d: np.ndarray) -> np.ndarray:
-    ma = np.max(d)
-    mi = np.min(d)
-    return (d - mi) / (ma - mi)
-
-
-def rescale(sample: dict, output_size: int) -> dict:
-    imidx, image, label = sample["imidx"], sample["image"], sample["label"]
-
-    h, w = image.shape[:2]
-
-    if isinstance(output_size, int):
-        if h > w:
-            new_h, new_w = output_size * h / w, output_size
-        else:
-            new_h, new_w = output_size, output_size * w / h
-    else:
-        new_h, new_w = output_size
-
-    new_h, new_w = int(new_h), int(new_w)
-
-    img = transform.resize(image, (output_size, output_size), mode="constant")
-    lbl = transform.resize(
-        label,
-        (output_size, output_size),
-        mode="constant",
-        order=0,
-        preserve_range=True,
-    )
-
-    return {"imidx": imidx, "image": img, "label": lbl}
-
-
-def color(sample: dict) -> dict:
-    imidx, image, label = sample["imidx"], sample["image"], sample["label"]
-
-    tmpLbl = np.zeros(label.shape)
-
-    if np.max(label) < 1e-6:
-        label = label
-    else:
-        label = label / np.max(label)
-
-    tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
-    image = image / np.max(image)
-
-    if image.shape[2] == 1:
-        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
-        tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
-        tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
-    else:
-        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
-        tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
-        tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
-
-    tmpLbl[:, :, 0] = label[:, :, 0]
-    tmpImg = tmpImg.transpose((2, 0, 1))
-    tmpLbl = label.transpose((2, 0, 1))
-
-    return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
-
-
-def preprocess(im_array: np.ndarray) -> dict:
-    label_3 = np.zeros(im_array.shape)
-    label = np.zeros(label_3.shape[0:2])
-
-    if 3 == len(label_3.shape):
-        label = label_3[:, :, 0]
-    elif 2 == len(label_3.shape):
-        label = label_3
-
-    if 3 == len(im_array.shape) and 2 == len(label.shape):
-        label = label[:, :, np.newaxis]
-    elif 2 == len(im_array.shape) and 2 == len(label.shape):
-        im_array = im_array[:, :, np.newaxis]
-        label = label[:, :, np.newaxis]
-
-    sample = {"imidx": np.array([0]), "image": im_array, "label": label}
-    sample = rescale(sample, 320)
-    sample = color(sample)
-
-    return sample
-
-
-def predict(ort_session: ort.InferenceSession, im_array: np.ndarray) -> Image:
-    sample = preprocess(im_array)
-    inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)
-
-    ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}
-    ort_outs = ort_session.run(None, ort_inputs)
-
-    d1 = ort_outs[0]
-    pred = d1[:, 0, :, :]
-    predict = np.squeeze(norm_pred(pred))
-    img = Image.fromarray(predict * 255).convert("RGB")
-
-    return img

+ 39 - 0
rembg/session_base.py

@@ -0,0 +1,39 @@
+from typing import Dict, List, Tuple
+
+import numpy as np
+import onnxruntime as ort
+from PIL import Image
+
+
+class BaseSession:
+    def __init__(self, model_name: str, inner_session: ort.InferenceSession):
+        self.model_name = model_name
+        self.inner_session = inner_session
+
+    def normalize(
+        self,
+        img: Image,
+        mean: Tuple[float, float, float],
+        std: Tuple[float, float, float],
+        size: Tuple[int, int],
+    ) -> Dict[str, np.ndarray]:
+        im = img.convert("RGB").resize(size, Image.LANCZOS)
+
+        im_ary = np.array(im)
+        im_ary = im_ary / np.max(im_ary)
+
+        tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
+        tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
+        tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
+        tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
+
+        tmpImg = tmpImg.transpose((2, 0, 1))
+
+        return {
+            self.inner_session.get_inputs()[0]
+            .name: np.expand_dims(tmpImg, 0)
+            .astype(np.float32)
+        }
+
+    def predict(self, im: Image) -> List[Image]:
+        raise NotImplementedError
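
`normalize` captures the preprocessing contract shared by the subclasses: resize, scale to [0, 1] by the max pixel value, standardize per channel, reorder HWC to CHW, and key the batch by the ONNX graph's first input name. A self-contained sketch of the same steps (the input name `"input.1"` is a stand-in for whatever the real graph declares):

```python
import numpy as np
from PIL import Image

def normalize(img, mean, std, size, input_name="input.1"):
    im = img.convert("RGB").resize(size, Image.LANCZOS)
    ary = np.array(im).astype("float64")
    ary = ary / ary.max()                         # scale to [0, 1]
    ary = (ary - np.array(mean)) / np.array(std)  # per-channel standardize
    ary = ary.transpose((2, 0, 1))                # HWC -> CHW
    return {input_name: np.expand_dims(ary, 0).astype(np.float32)}

batch = normalize(Image.new("RGB", (512, 512), "gray"),
                  (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320))
print(batch["input.1"].shape)  # (1, 3, 320, 320)
```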

+ 65 - 0
rembg/session_cloth.py

@@ -0,0 +1,65 @@
+from typing import List
+
+import numpy as np
+from PIL import Image
+from scipy.special import log_softmax
+
+from .session_base import BaseSession
+
+# fmt: off
+pallete1 = [
+      0,   0,   0, # background
+    255, 255, 255, # upper body
+      0,   0,   0, # lower body
+      0,   0,   0, # full body
+]
+
+pallete2 = [
+      0,   0,   0, # background
+      0,   0,   0, # upper body
+    255, 255, 255, # lower body
+      0,   0,   0, # full body
+]
+
+pallete3 = [
+      0,   0,   0, # background
+      0,   0,   0, # upper body
+      0,   0,   0, # lower body
+    255, 255, 255, # full body
+]
+# fmt: on
+
+
+class ClothSession(BaseSession):
+    def predict(self, img: Image) -> List[Image]:
+        ort_outs = self.inner_session.run(
+            None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
+        )
+
+        pred = ort_outs
+        pred = log_softmax(pred[0], 1)
+        pred = np.argmax(pred, axis=1, keepdims=True)
+        pred = np.squeeze(pred, 0)
+        pred = np.squeeze(pred, 0)
+
+        mask = Image.fromarray(pred.astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.LANCZOS)
+
+        masks = []
+
+        mask1 = mask.copy()
+        mask1.putpalette(pallete1)
+        mask1 = mask1.convert("RGB").convert("L")
+        masks.append(mask1)
+
+        mask2 = mask.copy()
+        mask2.putpalette(pallete2)
+        mask2 = mask2.convert("RGB").convert("L")
+        masks.append(mask2)
+
+        mask3 = mask.copy()
+        mask3.putpalette(pallete3)
+        mask3 = mask3.convert("RGB").convert("L")
+        masks.append(mask3)
+
+        return masks
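
The three palettes are a compact way to split the 4-class argmax map into binary masks: each palette paints exactly one class white, so `putpalette` followed by `convert("RGB").convert("L")` leaves 255 where that class was predicted and 0 elsewhere. A standalone sketch of the trick on a toy index map:

```python
import numpy as np
from PIL import Image

# Toy 4x4 class map: 0=background, 1=upper body, 2=lower body, 3=full body.
classes = np.array([[0, 1, 1, 0],
                    [0, 1, 1, 0],
                    [0, 2, 2, 0],
                    [3, 2, 2, 3]], dtype="uint8")
mask = Image.fromarray(classes, mode="L")

upper = [0, 0, 0,  255, 255, 255,  0, 0, 0,  0, 0, 0]  # only class 1 -> white
m = mask.copy()
m.putpalette(upper)                # attaches the palette, mode becomes "P"
m = m.convert("RGB").convert("L")  # white -> 255, everything else -> 0
print((np.array(m) == 255).astype(int))
```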

+ 63 - 0
rembg/session_factory.py

@@ -0,0 +1,63 @@
+import hashlib
+import os
+import sys
+from contextlib import redirect_stdout
+from pathlib import Path
+from typing import Type
+
+import gdown
+import onnxruntime as ort
+
+from .session_base import BaseSession
+from .session_cloth import ClothSession
+from .session_simple import SimpleSession
+
+
+def new_session(model_name: str) -> BaseSession:
+    session_class: Type[BaseSession]
+
+    if model_name == "u2netp":
+        md5 = "8e83ca70e441ab06c318d82300c84806"
+        url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
+        session_class = SimpleSession
+    elif model_name == "u2net":
+        md5 = "60024c5c889badc19c04ad937298a77b"
+        url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
+        session_class = SimpleSession
+    elif model_name == "u2net_human_seg":
+        md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
+        url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
+        session_class = SimpleSession
+    elif model_name == "u2net_cloth_seg":
+        md5 = "2434d1f3cb744e0e49386c906e5a08bb"
+        url = "https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz"
+        session_class = ClothSession
+    else:
+        raise ValueError(
+            "Choose between u2net, u2netp, u2net_human_seg or u2net_cloth_seg"
+        )
+
+    home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
+    path = Path(home).expanduser() / f"{model_name}.onnx"
+    path.parents[0].mkdir(parents=True, exist_ok=True)
+
+    if not path.exists():
+        with redirect_stdout(sys.stderr):
+            gdown.download(url, str(path), use_cookies=False)
+    else:
+        hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
+        if hashing.hexdigest() != md5:
+            with redirect_stdout(sys.stderr):
+                gdown.download(url, str(path), use_cookies=False)
+
+    sess_opts = ort.SessionOptions()
+
+    if "OMP_NUM_THREADS" in os.environ:
+        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+
+    return session_class(
+        model_name,
+        ort.InferenceSession(
+            str(path), providers=ort.get_available_providers(), sess_options=sess_opts
+        ),
+    )
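
A short usage sketch for the factory; model files are cached under `~/.u2net` (or `$U2NET_HOME` when set) and re-downloaded only when the MD5 check fails:

```python
from rembg.session_factory import new_session

# First call downloads u2netp.onnx and verifies its checksum;
# subsequent calls reuse the cached file.
session = new_session("u2netp")
print(type(session).__name__)  # SimpleSession
print(session.model_name)      # u2netp
```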

+ 29 - 0
rembg/session_simple.py

@@ -0,0 +1,29 @@
+from typing import List
+
+import numpy as np
+from PIL import Image
+
+from .session_base import BaseSession
+
+
+class SimpleSession(BaseSession):
+    def predict(self, img: Image) -> List[Image]:
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(
+                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
+            ),
+        )
+
+        pred = ort_outs[0][:, 0, :, :]
+
+        ma = np.max(pred)
+        mi = np.min(pred)
+
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
+
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.LANCZOS)
+
+        return [mask]
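
The min-max rescale maps the raw network output into [0, 1] before the 8-bit conversion; a tiny worked example of those two lines:

```python
import numpy as np

pred = np.array([[-2.0, 0.0], [1.0, 6.0]])
pred = (pred - pred.min()) / (pred.max() - pred.min())
print((pred * 255).astype("uint8"))
# [[  0  63]
#  [ 95 255]]
```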

+ 2 - 1
requirements.txt

@@ -4,13 +4,14 @@ click==8.0.3
 fastapi==0.72.0
 filetype==1.0.9
 gdown==4.4.0
-numpy==1.21.5
+numpy==1.22.3
 onnxruntime==1.10.0
 pillow==9.0.1
 pymatting==1.1.5
 python-multipart==0.0.5
 scikit-image==0.19.1
 scipy==1.7.3
+scipy==1.8.0
 tqdm==4.62.3
 uvicorn==0.17.0
 watchdog==2.1.7