Prechádzať zdrojové kódy

Merge pull request #761 from blacknoize404/main

Adding dis_custom.py
Daniel Gatis 3 mesiacov pred
rodič
commit
d092c7c7bd

+ 8 - 0
rembg/sessions/__init__.py

@@ -38,6 +38,10 @@ from .dis_anime import DisSession
 
 sessions[DisSession.name()] = DisSession
 
+from .dis_custom import DisCustomSession
+
+sessions[DisCustomSession.name()] = DisCustomSession
+
 from .dis_general_use import DisSession as DisSessionGeneralUse
 
 sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
@@ -74,5 +78,9 @@ from .bria_rmbg import BriaRmBgSession
 
 sessions[BriaRmBgSession.name()] = BriaRmBgSession
 
+from .ben_custom import BenCustomSession
+
+sessions[BenCustomSession.name()] = BenCustomSession
+
 sessions_names = list(sessions.keys())
 sessions_class = list(sessions.values())

+ 96 - 0
rembg/sessions/ben_custom.py

@@ -0,0 +1,96 @@
+import os
+from typing import List
+import onnxruntime as ort
+import numpy as np
+
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .base import BaseSession
+
+import numpy as np
+from PIL import Image
+
+
+class BenCustomSession(BaseSession):
+    """This is a class representing a custom session for the Ben model."""
+
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
+        """
+        Initialize a new BenCustomSession object.
+
+        Parameters:
+            model_name (str): The name of the model.
+            sess_opts: The session options.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        super().__init__(model_name, sess_opts, *args, **kwargs)
+
+    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the mask image for the input image.
+
+        This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
+
+        Parameters:
+            img (PILImage): The input image.
+
+        Returns:
+            List[PILImage]: A list of PILImage objects representing the generated mask image.
+        """
+
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
+        )
+
+        pred = ort_outs[0][:, 0, :, :]
+
+        ma = np.max(pred)
+        mi = np.min(pred)
+
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
+
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)
+
+
+        return [mask]
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Download the model files.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The absolute path to the model files.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        return os.path.abspath(os.path.expanduser(model_path))
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Get the name of the model.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the model.
+        """
+        return "ben_custom"

+ 91 - 0
rembg/sessions/dis_custom.py

@@ -0,0 +1,91 @@
+import os
+from typing import List
+import onnxruntime as ort
+import numpy as np
+
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .base import BaseSession
+
+class DisCustomSession(BaseSession):
+    """This is a class representing a custom session for the Dis model."""
+
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
+        """
+        Initialize a new DisCustomSession object.
+
+        Parameters:
+            model_name (str): The name of the model.
+            sess_opts: The session options.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        super().__init__(model_name, sess_opts, *args, **kwargs)
+
+
+    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the mask image for the input image.
+
+        This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
+
+        Parameters:
+            img (PILImage): The input image.
+
+        Returns:
+            List[PILImage]: A list of PILImage objects representing the generated mask image.
+        """
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
+        )
+
+        pred = ort_outs[0][:, 0, :, :]
+
+        ma = np.max(pred)
+        mi = np.min(pred)
+
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
+
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)
+
+        return [mask]
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Download the model files.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The absolute path to the model files.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        return os.path.abspath(os.path.expanduser(model_path))
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Get the name of the model.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the model.
+        """
+        return "dis_custom"

+ 1 - 1
rembg/sessions/u2net_custom.py

@@ -78,7 +78,7 @@ class U2netCustomSession(BaseSession):
         """
         model_path = kwargs.get("model_path")
         if model_path is None:
-            return
+            raise ValueError("model_path is required")
 
         return os.path.abspath(os.path.expanduser(model_path))