|
@@ -20,6 +20,7 @@ from scipy.ndimage import binary_erosion
|
|
|
|
|
|
from .session_base import BaseSession
|
|
|
from .session_factory import new_session
|
|
|
+from .session_sam import SamSession
|
|
|
|
|
|
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
|
|
|
|
@@ -119,7 +120,7 @@ def remove(
|
|
|
alpha_matting_foreground_threshold: int = 240,
|
|
|
alpha_matting_background_threshold: int = 10,
|
|
|
alpha_matting_erode_size: int = 10,
|
|
|
- session: Optional[BaseSession] = None,
|
|
|
+ session: Optional[Union[BaseSession, SamSession]] = None,
|
|
|
only_mask: bool = False,
|
|
|
post_process_mask: bool = False,
|
|
|
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
|
@@ -141,10 +142,10 @@ def remove(
|
|
|
if session is None:
|
|
|
session = new_session("u2net")
|
|
|
|
|
|
- if session.model_name == "sam":
|
|
|
+ if isinstance(session, SamSession):
|
|
|
if input_point is None or input_label is None:
|
|
|
raise ValueError("Input point and label are required for SAM model.")
|
|
|
- masks = session.predict(img, input_point, input_label)
|
|
|
+ masks = session.predict_sam(img, input_point, input_label)
|
|
|
else:
|
|
|
masks = session.predict(img)
|
|
|
|