BlackNoizE404 1 ヶ月 前
コミット
12f9f71f01
2 ファイル変更12 行追加42 行削除
  1. 11 41
      rembg/sessions/ben_custom.py
  2. 1 1
      rembg/sessions/dis_custom.py

+ 11 - 41
rembg/sessions/ben_custom.py

@@ -5,14 +5,11 @@ import numpy as np
 
 from PIL import Image
 from PIL.Image import Image as PILImage
-import torch
 
 from .base import BaseSession
 
 import numpy as np
 from PIL import Image
-import torchvision.transforms as transforms
-import torch.nn.functional as F
 
 
 class BenCustomSession(BaseSession):
@@ -34,38 +31,6 @@ class BenCustomSession(BaseSession):
 
         super().__init__(model_name, sess_opts, *args, **kwargs)
 
-    def preprocess_image(self, image):
-        original_size = image.size
-        transform = transforms.Compose([
-            transforms.Resize((1024, 1024)),
-            transforms.ToTensor(),
-        ])
-
-        img_tensor = transform(image)
-
-        img_tensor = img_tensor.unsqueeze(0)
-        return img_tensor.numpy(), image, original_size
-
-    def postprocess_image(self, result_np: np.ndarray, im_size: list) -> np.ndarray:
-
-        result = torch.from_numpy(result_np)
-
-
-        if len(result.shape) == 3:
-            result = result.unsqueeze(0)
-
-
-        result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
-
-
-        ma = torch.max(result)
-        mi = torch.min(result)
-        result = (result - mi) / (ma - mi)
-
-        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
-        im_array = np.squeeze(im_array)
-        return im_array
-
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         """
         Predicts the mask image for the input image.
@@ -79,17 +44,22 @@ class BenCustomSession(BaseSession):
             List[PILImage]: A list of PILImage objects representing the generated mask image.
         """
 
-        input_data, original_image, (w, h) = self.preprocess_image(img)
+        ort_outs = self.inner_session.run(
+            None,
+            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
+        )
 
-        input_name = self.inner_session.get_inputs()[0].name
+        pred = ort_outs[0][:, 0, :, :]
 
-        outputs = self.inner_session.run(None, {input_name: input_data})
+        ma = np.max(pred)
+        mi = np.min(pred)
 
+        pred = (pred - mi) / (ma - mi)
+        pred = np.squeeze(pred)
 
-        alpha = self.postprocess_image(outputs[0], im_size=[w, h])
+        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)
 
-        mask = Image.fromarray(alpha, mode="L")
-        mask = mask.resize((w, h), Image.Resampling.LANCZOS)
 
         return [mask]
 

+ 1 - 1
rembg/sessions/dis_custom.py

@@ -72,7 +72,7 @@ class DisCustomSession(BaseSession):
         """
         model_path = kwargs.get("model_path")
         if model_path is None:
-            return
+            raise ValueError("model_path is required")
 
         return os.path.abspath(os.path.expanduser(model_path))