Bladeren bron

Merge pull request #431 from Flippchen/main

Added support for Facebook''s Segment Anything
Daniel Gatis 2 jaren geleden
bovenliggende
commit
1e311331e6
4 gewijzigde bestanden met toevoegingen van 181 en 16 verwijderingen
  1. 11 2
      rembg/bg.py
  2. 51 13
      rembg/session_factory.py
  3. 118 0
      rembg/session_sam.py
  4. 1 1
      requirements.txt

+ 11 - 2
rembg/bg.py

@@ -20,6 +20,7 @@ from scipy.ndimage import binary_erosion
 
 from .session_base import BaseSession
 from .session_factory import new_session
+from .session_sam import SamSession
 
 kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
 
@@ -119,10 +120,12 @@ def remove(
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_size: int = 10,
-    session: Optional[BaseSession] = None,
+    session: Optional[Union[BaseSession, SamSession]] = None,
     only_mask: bool = False,
     post_process_mask: bool = False,
     bgcolor: Optional[Tuple[int, int, int, int]] = None,
+    input_point: Optional[np.ndarray] = None,
+    input_label: Optional[np.ndarray] = None,
 ) -> Union[bytes, PILImage, np.ndarray]:
     if isinstance(data, PILImage):
         return_type = ReturnType.PILLOW
@@ -139,7 +142,13 @@ def remove(
     if session is None:
         session = new_session("u2net")
 
-    masks = session.predict(img)
+    if isinstance(session, SamSession):
+        if input_point is None or input_label is None:
+            raise ValueError("Input point and label are required for SAM model.")
+        masks = session.predict_sam(img, input_point, input_label)
+    else:
+        masks = session.predict(img)
+
     cutouts = []
 
     for mask in masks:

+ 51 - 13
rembg/session_factory.py

@@ -11,10 +11,30 @@ import pooch
 from .session_base import BaseSession
 from .session_cloth import ClothSession
 from .session_dis import DisSession
+from .session_sam import SamSession
 from .session_simple import SimpleSession
 
 
+def download_model(url: str, md5: str, fname: str, path: Path):
+    pooch.retrieve(
+        url,
+        f"md5:{md5}",
+        fname=fname,
+        path=path,
+        progressbar=True,
+    )
+
+
 def new_session(model_name: str = "u2net") -> BaseSession:
+    # Define the model path
+    u2net_home = os.getenv(
+        "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
+    )
+
+    fname = f"{model_name}.onnx"
+    path = Path(u2net_home).expanduser()
+    full_path = Path(u2net_home).expanduser() / fname
+
     session_class: Type[BaseSession]
     md5 = "60024c5c889badc19c04ad937298a77b"
     url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
@@ -44,22 +64,40 @@ def new_session(model_name: str = "u2net") -> BaseSession:
         md5 = "fc16ebd8b0c10d971d3513d564d01e29"
         url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx"
         session_class = DisSession
+    elif model_name == "sam":
+        path = Path(u2net_home).expanduser()
 
-    u2net_home = os.getenv(
-        "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
-    )
+        fname_encoder = f"{model_name}_encoder.onnx"
+        encoder_md5 = "13d97c5c79ab13ef86d67cbde5f1b250"
+        encoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-encoder-quant.onnx"
 
-    fname = f"{model_name}.onnx"
-    path = Path(u2net_home).expanduser()
-    full_path = Path(u2net_home).expanduser() / fname
+        fname_decoder = f"{model_name}_decoder.onnx"
+        decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127"
+        decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx"
 
-    pooch.retrieve(
-        url,
-        f"md5:{md5}",
-        fname=fname,
-        path=Path(u2net_home).expanduser(),
-        progressbar=True,
-    )
+        download_model(encoder_url, encoder_md5, fname_encoder, path)
+        download_model(decoder_url, decoder_md5, fname_decoder, path)
+
+        sess_opts = ort.SessionOptions()
+
+        if "OMP_NUM_THREADS" in os.environ:
+            sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+
+        return SamSession(
+            model_name,
+            ort.InferenceSession(
+                str(path / fname_encoder),
+                providers=ort.get_available_providers(),
+                sess_options=sess_opts,
+            ),
+            ort.InferenceSession(
+                str(path / fname_decoder),
+                providers=ort.get_available_providers(),
+                sess_options=sess_opts,
+            ),
+        )
+
+    download_model(url, md5, fname, path)
 
     sess_opts = ort.SessionOptions()
 

+ 118 - 0
rembg/session_sam.py

@@ -0,0 +1,118 @@
+from typing import List
+
+import numpy as np
+import onnxruntime as ort
+from PIL import Image
+from PIL.Image import Image as PILImage
+
+from .session_base import BaseSession
+
+
+def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
+    scale = long_side_length * 1.0 / max(oldh, oldw)
+    newh, neww = oldh * scale, oldw * scale
+    neww = int(neww + 0.5)
+    newh = int(newh + 0.5)
+    return (newh, neww)
+
+
+def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
+    old_h, old_w = original_size
+    new_h, new_w = get_preprocess_shape(
+        original_size[0], original_size[1], target_length
+    )
+    coords = coords.copy().astype(float)
+    coords[..., 0] = coords[..., 0] * (new_w / old_w)
+    coords[..., 1] = coords[..., 1] * (new_h / old_h)
+    return coords
+
+
+def resize_longes_side(img: PILImage, size=1024):
+    w, h = img.size
+    if h > w:
+        new_h, new_w = size, int(w * size / h)
+    else:
+        new_h, new_w = int(h * size / w), size
+
+    return img.resize((new_w, new_h))
+
+
+def pad_to_square(img: np.ndarray, size=1024):
+    h, w = img.shape[:2]
+    padh = size - h
+    padw = size - w
+    img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
+    img = img.astype(np.float32)
+    return img
+
+
+class SamSession(BaseSession):
+    def __init__(
+        self,
+        model_name: str,
+        encoder: ort.InferenceSession,
+        decoder: ort.InferenceSession,
+    ):
+        super().__init__(model_name, encoder)
+        self.decoder = decoder
+
+    def normalize(
+        self,
+        img: np.ndarray,
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        size=(1024, 1024),
+    ):
+        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
+        pixel_std = np.array([*std]).reshape(1, 1, -1)
+        x = (img - pixel_mean) / pixel_std
+        return x
+
+    def predict_sam(
+        self,
+        img: PILImage,
+        input_point: np.ndarray,
+        input_label: np.ndarray,
+    ) -> List[PILImage]:
+        # Preprocess image
+        image = resize_longes_side(img)
+        image = np.array(image)
+        image = self.normalize(image)
+        image = pad_to_square(image)
+
+        # Transpose
+        image = image.transpose(2, 0, 1)[None, :, :, :]
+        # Run encoder (Image embedding)
+        encoded = self.inner_session.run(None, {"x": image})
+        image_embedding = encoded[0]
+
+        # Add a batch index, concatenate a padding point, and transform.
+        onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
+            None, :, :
+        ]
+        onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
+            None, :
+        ].astype(np.float32)
+        onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
+
+        # Create an empty mask input and an indicator for no mask.
+        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+        onnx_has_mask_input = np.zeros(1, dtype=np.float32)
+
+        decoder_inputs = {
+            "image_embeddings": image_embedding,
+            "point_coords": onnx_coord,
+            "point_labels": onnx_label,
+            "mask_input": onnx_mask_input,
+            "has_mask_input": onnx_has_mask_input,
+            "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
+        }
+
+        masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
+        masks = masks > 0.0
+        masks = [
+            Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
+            for i in range(masks.shape[0])
+        ]
+
+        return masks

+ 1 - 1
requirements.txt

@@ -6,7 +6,7 @@ filetype==1.2.0
 pooch==1.6.0
 imagehash==4.3.1
 numpy==1.23.5
-onnxruntime==1.13.1
+onnxruntime==1.14.1
 opencv-python-headless==4.6.0.66
 pillow==9.3.0
 pymatting==1.1.8