|
@@ -5,14 +5,11 @@ import numpy as np
|
|
|
|
|
|
from PIL import Image
|
|
|
from PIL.Image import Image as PILImage
|
|
|
-import torch
|
|
|
|
|
|
from .base import BaseSession
|
|
|
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
-import torchvision.transforms as transforms
|
|
|
-import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
class BenCustomSession(BaseSession):
|
|
@@ -34,38 +31,6 @@ class BenCustomSession(BaseSession):
|
|
|
|
|
|
super().__init__(model_name, sess_opts, *args, **kwargs)
|
|
|
|
|
|
- def preprocess_image(self, image):
|
|
|
- original_size = image.size
|
|
|
- transform = transforms.Compose([
|
|
|
- transforms.Resize((1024, 1024)),
|
|
|
- transforms.ToTensor(),
|
|
|
- ])
|
|
|
-
|
|
|
- img_tensor = transform(image)
|
|
|
-
|
|
|
- img_tensor = img_tensor.unsqueeze(0)
|
|
|
- return img_tensor.numpy(), image, original_size
|
|
|
-
|
|
|
- def postprocess_image(self, result_np: np.ndarray, im_size: list) -> np.ndarray:
|
|
|
-
|
|
|
- result = torch.from_numpy(result_np)
|
|
|
-
|
|
|
-
|
|
|
- if len(result.shape) == 3:
|
|
|
- result = result.unsqueeze(0)
|
|
|
-
|
|
|
-
|
|
|
- result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
|
|
|
-
|
|
|
-
|
|
|
- ma = torch.max(result)
|
|
|
- mi = torch.min(result)
|
|
|
- result = (result - mi) / (ma - mi)
|
|
|
-
|
|
|
- im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
|
|
|
- im_array = np.squeeze(im_array)
|
|
|
- return im_array
|
|
|
-
|
|
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
|
|
"""
|
|
|
Predicts the mask image for the input image.
|
|
@@ -79,17 +44,22 @@ class BenCustomSession(BaseSession):
|
|
|
List[PILImage]: A list of PILImage objects representing the generated mask image.
|
|
|
"""
|
|
|
|
|
|
- input_data, original_image, (w, h) = self.preprocess_image(img)
|
|
|
+ ort_outs = self.inner_session.run(
|
|
|
+ None,
|
|
|
+ self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
|
|
|
+ )
|
|
|
|
|
|
- input_name = self.inner_session.get_inputs()[0].name
|
|
|
+ pred = ort_outs[0][:, 0, :, :]
|
|
|
|
|
|
- outputs = self.inner_session.run(None, {input_name: input_data})
|
|
|
+ ma = np.max(pred)
|
|
|
+ mi = np.min(pred)
|
|
|
|
|
|
+ pred = (pred - mi) / (ma - mi)
|
|
|
+ pred = np.squeeze(pred)
|
|
|
|
|
|
- alpha = self.postprocess_image(outputs[0], im_size=[w, h])
|
|
|
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
|
|
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
|
|
|
|
|
|
- mask = Image.fromarray(alpha, mode="L")
|
|
|
- mask = mask.resize((w, h), Image.Resampling.LANCZOS)
|
|
|
|
|
|
return [mask]
|
|
|
|