|
@@ -123,6 +123,8 @@ def remove(
|
|
|
only_mask: bool = False,
|
|
|
post_process_mask: bool = False,
|
|
|
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
|
|
+ input_point: Optional[np.ndarray] = None,
|
|
|
+ input_label: Optional[np.ndarray] = None,
|
|
|
) -> Union[bytes, PILImage, np.ndarray]:
|
|
|
if isinstance(data, PILImage):
|
|
|
return_type = ReturnType.PILLOW
|
|
@@ -139,7 +141,13 @@ def remove(
|
|
|
if session is None:
|
|
|
session = new_session("u2net")
|
|
|
|
|
|
- masks = session.predict(img)
|
|
|
+ if session.model_name == "sam":
|
|
|
+ if input_point is None or input_label is None:
|
|
|
+ raise ValueError("Input point and label are required for SAM model.")
|
|
|
+ masks = session.predict(img, input_point, input_label)
|
|
|
+ else:
|
|
|
+ masks = session.predict(img)
|
|
|
+
|
|
|
cutouts = []
|
|
|
|
|
|
for mask in masks:
|