Browse Source

Add BiRefNet-General and BiRefNet-Portrait models as available models (#665)

Dimitri Barbot 1 year ago
parent
commit
d4c40e1c3e
38 changed files with 453 additions and 1 deletions
  1. 7 0
      README.md
  2. 35 0
      rembg/sessions/__init__.py
  3. 52 0
      rembg/sessions/birefnet_cod.py
  4. 52 0
      rembg/sessions/birefnet_dis.py
  5. 91 0
      rembg/sessions/birefnet_general.py
  6. 52 0
      rembg/sessions/birefnet_general_lite.py
  7. 52 0
      rembg/sessions/birefnet_hrsod.py
  8. 52 0
      rembg/sessions/birefnet_massive.py
  9. 52 0
      rembg/sessions/birefnet_portrait.py
  10. BIN
      tests/results/anime-girl-1.birefnet-cod.png
  11. BIN
      tests/results/anime-girl-1.birefnet-dis.png
  12. BIN
      tests/results/anime-girl-1.birefnet-general-lite.png
  13. BIN
      tests/results/anime-girl-1.birefnet-general.png
  14. BIN
      tests/results/anime-girl-1.birefnet-hrsod.png
  15. BIN
      tests/results/anime-girl-1.birefnet-massive.png
  16. BIN
      tests/results/anime-girl-1.birefnet-portrait.png
  17. BIN
      tests/results/car-1.birefnet-cod.png
  18. BIN
      tests/results/car-1.birefnet-dis.png
  19. BIN
      tests/results/car-1.birefnet-general-lite.png
  20. BIN
      tests/results/car-1.birefnet-general.png
  21. BIN
      tests/results/car-1.birefnet-hrsod.png
  22. BIN
      tests/results/car-1.birefnet-massive.png
  23. BIN
      tests/results/car-1.birefnet-portrait.png
  24. BIN
      tests/results/cloth-1.birefnet-cod.png
  25. BIN
      tests/results/cloth-1.birefnet-dis.png
  26. BIN
      tests/results/cloth-1.birefnet-general-lite.png
  27. BIN
      tests/results/cloth-1.birefnet-general.png
  28. BIN
      tests/results/cloth-1.birefnet-hrsod.png
  29. BIN
      tests/results/cloth-1.birefnet-massive.png
  30. BIN
      tests/results/cloth-1.birefnet-portrait.png
  31. BIN
      tests/results/plants-1.birefnet-cod.png
  32. BIN
      tests/results/plants-1.birefnet-dis.png
  33. BIN
      tests/results/plants-1.birefnet-general-lite.png
  34. BIN
      tests/results/plants-1.birefnet-general.png
  35. BIN
      tests/results/plants-1.birefnet-hrsod.png
  36. BIN
      tests/results/plants-1.birefnet-massive.png
  37. BIN
      tests/results/plants-1.birefnet-portrait.png
  38. 8 1
      tests/test_remove.py

+ 7 - 0
README.md

@@ -334,6 +334,13 @@ The available models are:
 - isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases.
 - isnet-anime ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx), [source](https://github.com/SkyTNT/anime-segmentation)): A high-accuracy segmentation for anime character.
 - sam ([download encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [download decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): A pre-trained model for any use cases.
+- birefnet-general ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for general use cases.
+- birefnet-general-lite ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A lightweight pre-trained model for general use cases.
+- birefnet-portrait ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for human portraits.
+- birefnet-dis ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for dichotomous image segmentation (DIS).
+- birefnet-hrsod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for high-resolution salient object detection (HRSOD).
+- birefnet-cod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for concealed object detection (COD).
+- birefnet-massive ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model trained on a massive dataset.
 
 ### How to train your own model
 

+ 35 - 0
rembg/sessions/__init__.py

@@ -7,6 +7,41 @@ from .base import BaseSession
 sessions_class: List[type[BaseSession]] = []
 sessions_names: List[str] = []
 
+from .birefnet_general import BiRefNetSessionGeneral
+
+sessions_class.append(BiRefNetSessionGeneral)
+sessions_names.append(BiRefNetSessionGeneral.name())
+
+from .birefnet_general_lite import BiRefNetSessionGeneralLite
+
+sessions_class.append(BiRefNetSessionGeneralLite)
+sessions_names.append(BiRefNetSessionGeneralLite.name())
+
+from .birefnet_portrait import BiRefNetSessionPortrait
+
+sessions_class.append(BiRefNetSessionPortrait)
+sessions_names.append(BiRefNetSessionPortrait.name())
+
+from .birefnet_dis import BiRefNetSessionDIS
+
+sessions_class.append(BiRefNetSessionDIS)
+sessions_names.append(BiRefNetSessionDIS.name())
+
+from .birefnet_hrsod import BiRefNetSessionHRSOD
+
+sessions_class.append(BiRefNetSessionHRSOD)
+sessions_names.append(BiRefNetSessionHRSOD.name())
+
+from .birefnet_cod import BiRefNetSessionCOD
+
+sessions_class.append(BiRefNetSessionCOD)
+sessions_names.append(BiRefNetSessionCOD.name())
+
+from .birefnet_massive import BiRefNetSessionMassive
+
+sessions_class.append(BiRefNetSessionMassive)
+sessions_names.append(BiRefNetSessionMassive.name())
+
 from .dis_anime import DisSession
 
 sessions_class.append(DisSession)

+ 52 - 0
rembg/sessions/birefnet_cod.py

@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionCOD(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-COD session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-COD model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-COD session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-cod"

+ 52 - 0
rembg/sessions/birefnet_dis.py

@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionDIS(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-DIS session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-DIS model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:2d4d44102b446f33a4ebb2e56c051f2b"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-DIS session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-dis"

+ 91 - 0
rembg/sessions/birefnet_general.py

@@ -0,0 +1,91 @@
+import os
+from typing import List
+
+import numpy as np
+import pooch
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .base import BaseSession
+
+
+class BiRefNetSessionGeneral(BaseSession):
+    """
+    This class represents a BiRefNet-General session, which is a subclass of BaseSession.
+    """
+
+    def sigmoid(self, mat):
+        return 1 / (1 + np.exp(-mat))
+
+    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the output masks for the input image using the inner session.
+
+        Parameters:
+            img (PILImage): The input image.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            List[PILImage]: The list of output masks.
+        """
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(
+                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
+            ),
+        )
+
+        pred = self.sigmoid(ort_outs[0][:, 0, :, :])
+
+        ma = np.max(pred)
+        mi = np.min(pred)
+
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
+
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)
+
+        return [mask]
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-General model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:7a35a0141cbbc80de11d9c9a28f52697"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-General session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-general"

+ 52 - 0
rembg/sessions/birefnet_general_lite.py

@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-General-Lite session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-General-Lite model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:4fab47adc4ff364be1713e97b7e66334"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-General-Lite session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-general-lite"

+ 52 - 0
rembg/sessions/birefnet_hrsod.py

@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-HRSOD session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-HRSOD model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:c017ade5de8a50ff0fd74d790d268dda"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-HRSOD session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-hrsod"

+ 52 - 0
rembg/sessions/birefnet_massive.py

@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionMassive(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-Massive session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-Massive model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:33e726a2136a3d59eb0fdf613e31e3e9"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-Massive session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-massive"

+ 52 - 0
rembg/sessions/birefnet_portrait.py

@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionPortrait(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-Portrait session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-Portrait model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:c3a64a6abf20250d090cd055f12a3b67"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-Portrait session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-portrait"

BIN
tests/results/anime-girl-1.birefnet-cod.png


BIN
tests/results/anime-girl-1.birefnet-dis.png


BIN
tests/results/anime-girl-1.birefnet-general-lite.png


BIN
tests/results/anime-girl-1.birefnet-general.png


BIN
tests/results/anime-girl-1.birefnet-hrsod.png


BIN
tests/results/anime-girl-1.birefnet-massive.png


BIN
tests/results/anime-girl-1.birefnet-portrait.png


BIN
tests/results/car-1.birefnet-cod.png


BIN
tests/results/car-1.birefnet-dis.png


BIN
tests/results/car-1.birefnet-general-lite.png


BIN
tests/results/car-1.birefnet-general.png


BIN
tests/results/car-1.birefnet-hrsod.png


BIN
tests/results/car-1.birefnet-massive.png


BIN
tests/results/car-1.birefnet-portrait.png


BIN
tests/results/cloth-1.birefnet-cod.png


BIN
tests/results/cloth-1.birefnet-dis.png


BIN
tests/results/cloth-1.birefnet-general-lite.png


BIN
tests/results/cloth-1.birefnet-general.png


BIN
tests/results/cloth-1.birefnet-hrsod.png


BIN
tests/results/cloth-1.birefnet-massive.png


BIN
tests/results/cloth-1.birefnet-portrait.png


BIN
tests/results/plants-1.birefnet-cod.png


BIN
tests/results/plants-1.birefnet-dis.png


BIN
tests/results/plants-1.birefnet-general-lite.png


BIN
tests/results/plants-1.birefnet-general.png


BIN
tests/results/plants-1.birefnet-hrsod.png


BIN
tests/results/plants-1.birefnet-massive.png


BIN
tests/results/plants-1.birefnet-portrait.png


+ 8 - 1
tests/test_remove.py

@@ -37,7 +37,14 @@ def test_remove():
         "silueta",
         "isnet-general-use",
         "isnet-anime",
-        "sam"
+        "sam",
+        "birefnet-general",
+        "birefnet-general-lite",
+        "birefnet-portrait",
+        "birefnet-dis",
+        "birefnet-hrsod",
+        "birefnet-cod",
+        "birefnet-massive"
     ]:
         for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
             image_path = Path(here / "fixtures" / f"{picture}.jpg")