Browse Source

Adding dis_custom.py

BlackNoizE404 2 months ago
parent
commit
ec98f1668e
2 changed files with 96 additions and 0 deletions
  1. 4 0
      rembg/sessions/__init__.py
  2. 92 0
      rembg/sessions/dis_custom.py

+ 4 - 0
rembg/sessions/__init__.py

@@ -38,6 +38,10 @@ from .dis_anime import DisSession
 
 
 sessions[DisSession.name()] = DisSession
 sessions[DisSession.name()] = DisSession
 
 
+from .dis_custom import DisCustomSession
+
+sessions[DisCustomSession.name()] = DisCustomSession
+
 from .dis_general_use import DisSession as DisSessionGeneralUse
 from .dis_general_use import DisSession as DisSessionGeneralUse
 
 
 sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
 sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse

+ 92 - 0
rembg/sessions/dis_custom.py

@@ -0,0 +1,92 @@
+import os
+from typing import List
+import onnxruntime as ort
+import numpy as np
+
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .base import BaseSession
+
+
+class DisCustomSession(BaseSession):
+    """This is a class representing a custom session for the Dis model."""
+
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
+        """
+        Initialize a new DisCustomSession object.
+
+        Parameters:
+            model_name (str): The name of the model.
+            sess_opts: The session options.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        super().__init__(model_name, sess_opts, *args, **kwargs)
+
+
+    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the mask image for the input image.
+
+        This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
+
+        Parameters:
+            img (PILImage): The input image.
+
+        Returns:
+            List[PILImage]: A list of PILImage objects representing the generated mask image.
+        """
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
+        )
+
+        pred = ort_outs[0][:, 0, :, :]
+
+        ma = np.max(pred)
+        mi = np.min(pred)
+
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
+
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)
+
+        return [mask]
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Download the model files.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The absolute path to the model files.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            return
+
+        return os.path.abspath(os.path.expanduser(model_path))
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Get the name of the model.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the model.
+        """
+        return "dis-custom"