|
@@ -1,6 +1,6 @@
|
|
|
import os
|
|
|
from copy import deepcopy
|
|
|
-from typing import List
|
|
|
+from typing import Dict, List, Tuple
|
|
|
|
|
|
import cv2
|
|
|
import numpy as np
|
|
@@ -87,8 +87,9 @@ class SamSession(BaseSession):
|
|
|
self,
|
|
|
model_name: str,
|
|
|
sess_opts: ort.SessionOptions,
|
|
|
+ providers=None,
|
|
|
*args,
|
|
|
- **kwargs,
|
|
|
+ **kwargs
|
|
|
):
|
|
|
"""
|
|
|
Initialize a new SamSession with the given model name and session options.
|
|
@@ -101,52 +102,27 @@ class SamSession(BaseSession):
|
|
|
"""
|
|
|
self.model_name = model_name
|
|
|
|
|
|
- self.providers = []
|
|
|
+ valid_providers = []
|
|
|
+ available_providers = ort.get_available_providers()
|
|
|
|
|
|
- _providers = ort.get_available_providers()
|
|
|
- for provider in kwargs.get("providers", []):
|
|
|
- if provider in _providers:
|
|
|
- self.providers.append(provider)
|
|
|
+ for provider in (providers or []):
|
|
|
+ if provider in available_providers:
|
|
|
+ valid_providers.append(provider)
|
|
|
else:
|
|
|
- self.providers.extend(_providers)
|
|
|
+ valid_providers.extend(available_providers)
|
|
|
|
|
|
paths = self.__class__.download_models(*args, **kwargs)
|
|
|
self.encoder = ort.InferenceSession(
|
|
|
str(paths[0]),
|
|
|
- providers=self.providers,
|
|
|
+ providers=valid_providers,
|
|
|
sess_options=sess_opts,
|
|
|
)
|
|
|
self.decoder = ort.InferenceSession(
|
|
|
str(paths[1]),
|
|
|
- providers=self.providers,
|
|
|
+ providers=valid_providers,
|
|
|
sess_options=sess_opts,
|
|
|
)
|
|
|
|
|
|
- def normalize(
|
|
|
- self,
|
|
|
- img: np.ndarray,
|
|
|
- mean=(),
|
|
|
- std=(),
|
|
|
- size=(),
|
|
|
- *args,
|
|
|
- **kwargs,
|
|
|
- ):
|
|
|
- """
|
|
|
- Normalize the input image by subtracting the mean and dividing by the standard deviation.
|
|
|
-
|
|
|
- Args:
|
|
|
- img (np.ndarray): The input image.
|
|
|
- mean (tuple, optional): The mean values for normalization. Defaults to ().
|
|
|
- std (tuple, optional): The standard deviation values for normalization. Defaults to ().
|
|
|
- size (tuple, optional): The target size of the image. Defaults to ().
|
|
|
- *args: Variable length argument list.
|
|
|
- **kwargs: Arbitrary keyword arguments.
|
|
|
-
|
|
|
- Returns:
|
|
|
- np.ndarray: The normalized image.
|
|
|
- """
|
|
|
- return img
|
|
|
-
|
|
|
def predict(
|
|
|
self,
|
|
|
img: PILImage,
|
|
@@ -269,8 +245,7 @@ class SamSession(BaseSession):
|
|
|
for m in masks[0, :, :, :]:
|
|
|
mask[m > 0.0] = [255, 255, 255]
|
|
|
|
|
|
- mask = Image.fromarray(mask).convert("L")
|
|
|
- return [mask]
|
|
|
+ return [Image.fromarray(mask).convert("L")]
|
|
|
|
|
|
@classmethod
|
|
|
def download_models(cls, *args, **kwargs):
|