Browse Source

Adding ben_custom.py

BlackNoizE404 3 months ago
parent
commit
d43138a2c6
2 changed files with 130 additions and 0 deletions
  1. 4 0
      rembg/sessions/__init__.py
  2. 126 0
      rembg/sessions/ben_custom.py

+ 4 - 0
rembg/sessions/__init__.py

@@ -78,5 +78,9 @@ from .bria_rmbg import BriaRmBgSession
 
 
 sessions[BriaRmBgSession.name()] = BriaRmBgSession
 sessions[BriaRmBgSession.name()] = BriaRmBgSession
 
 
+from .ben_custom import BenCustomSession
+
+sessions[BenCustomSession.name()] = BenCustomSession
+
 sessions_names = list(sessions.keys())
 sessions_names = list(sessions.keys())
 sessions_class = list(sessions.values())
 sessions_class = list(sessions.values())

+ 126 - 0
rembg/sessions/ben_custom.py

@@ -0,0 +1,126 @@
+import os
+from typing import List
+import onnxruntime as ort
+import numpy as np
+
+from PIL import Image
+from PIL.Image import Image as PILImage
+import torch
+
+from .base import BaseSession
+
+import numpy as np
+from PIL import Image
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+
+
+class BenCustomSession(BaseSession):
+    """This is a class representing a custom session for the Ben model."""
+
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
+        """
+        Initialize a new BenCustomSession object.
+
+        Parameters:
+            model_name (str): The name of the model.
+            sess_opts: The session options.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        super().__init__(model_name, sess_opts, *args, **kwargs)
+
+    def preprocess_image(self, image):
+        original_size = image.size
+        transform = transforms.Compose([
+            transforms.Resize((1024, 1024)),
+            transforms.ToTensor(),
+        ])
+
+        img_tensor = transform(image)
+
+        img_tensor = img_tensor.unsqueeze(0)
+        return img_tensor.numpy(), image, original_size
+
+    def postprocess_image(self, result_np: np.ndarray, im_size: list) -> np.ndarray:
+
+        result = torch.from_numpy(result_np)
+
+
+        if len(result.shape) == 3:
+            result = result.unsqueeze(0)
+
+
+        result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
+
+
+        ma = torch.max(result)
+        mi = torch.min(result)
+        result = (result - mi) / (ma - mi)
+
+        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
+        im_array = np.squeeze(im_array)
+        return im_array
+
+    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the mask image for the input image.
+
+        This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
+
+        Parameters:
+            img (PILImage): The input image.
+
+        Returns:
+            List[PILImage]: A list of PILImage objects representing the generated mask image.
+        """
+
+        input_data, original_image, (w, h) = self.preprocess_image(img)
+
+        input_name = self.inner_session.get_inputs()[0].name
+
+        outputs = self.inner_session.run(None, {input_name: input_data})
+
+
+        alpha = self.postprocess_image(outputs[0], im_size=[w, h])
+
+        mask = Image.fromarray(alpha, mode="L")
+        mask = mask.resize((w, h), Image.Resampling.LANCZOS)
+
+        return [mask]
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Download the model files.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The absolute path to the model files.
+        """
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        return os.path.abspath(os.path.expanduser(model_path))
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Get the name of the model.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the model.
+        """
+        return "ben-custom"