Explorar el Código

added input for remove function

Flippchen hace 2 años
padre
commit
d7828b0369
Se han modificado 3 ficheros con 12 adiciones y 4 borrados
  1. 9 1
      rembg/bg.py
  2. 1 1
      rembg/session_factory.py
  3. 2 2
      rembg/session_sam.py

+ 9 - 1
rembg/bg.py

@@ -123,6 +123,8 @@ def remove(
     only_mask: bool = False,
     post_process_mask: bool = False,
     bgcolor: Optional[Tuple[int, int, int, int]] = None,
+    input_point: Optional[np.ndarray] = None,
+    input_label: Optional[np.ndarray] = None,
 ) -> Union[bytes, PILImage, np.ndarray]:
     if isinstance(data, PILImage):
         return_type = ReturnType.PILLOW
@@ -139,7 +141,13 @@ def remove(
     if session is None:
         session = new_session("u2net")
 
-    masks = session.predict(img)
+    if session.model_name == "sam":
+        if input_point is None or input_label is None:
+            raise ValueError("Input point and label are required for SAM model.")
+        masks = session.predict(img, input_point, input_label)
+    else:
+        masks = session.predict(img)
+
     cutouts = []
 
     for mask in masks:

+ 1 - 1
rembg/session_factory.py

@@ -64,7 +64,7 @@ def new_session(model_name: str = "u2net") -> BaseSession:
         md5 = "fc16ebd8b0c10d971d3513d564d01e29"
         url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
         session_class = DisSession
-    elif model_name == "SAM":
+    elif model_name == "sam":
         path = Path(u2net_home).expanduser()
 
         fname_encoder = f"{model_name}_encoder.onnx"

+ 2 - 2
rembg/session_sam.py

@@ -71,8 +71,8 @@ class SamSession(BaseSession):
     def predict(
         self,
         img: PILImage,
-        input_point=np.array([[500, 375]]),
-        input_label=np.array([1]),
+        input_point: np.ndarray,
+        input_label: np.ndarray,
     ) -> List[PILImage]:
         # Preprocess image
         image = resize_longes_side(img)