Browse Source

fix pylint

Flippchen 2 years ago
parent
commit
394ab21ab9
2 changed files with 5 additions and 4 deletions
  1. 4 3
      rembg/bg.py
  2. 1 1
      rembg/session_sam.py

+ 4 - 3
rembg/bg.py

@@ -20,6 +20,7 @@ from scipy.ndimage import binary_erosion
 
 from .session_base import BaseSession
 from .session_factory import new_session
+from .session_sam import SamSession
 
 kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
 
@@ -119,7 +120,7 @@ def remove(
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_size: int = 10,
-    session: Optional[BaseSession] = None,
+    session: Optional[Union[BaseSession, SamSession]] = None,
     only_mask: bool = False,
     post_process_mask: bool = False,
     bgcolor: Optional[Tuple[int, int, int, int]] = None,
@@ -141,10 +142,10 @@ def remove(
     if session is None:
         session = new_session("u2net")
 
-    if session.model_name == "sam":
+    if isinstance(session, SamSession):
         if input_point is None or input_label is None:
             raise ValueError("Input point and label are required for SAM model.")
-        masks = session.predict(img, input_point, input_label)
+        masks = session.predict_sam(img, input_point, input_label)
     else:
         masks = session.predict(img)
 

+ 1 - 1
rembg/session_sam.py

@@ -68,7 +68,7 @@ class SamSession(BaseSession):
         x = (img - pixel_mean) / pixel_std
         return x
 
-    def predict(
+    def predict_sam(
         self,
         img: PILImage,
         input_point: np.ndarray,