瀏覽代碼

fix sam session (#531)

Daniel Gatis 1 年之前
父節點
當前提交
47701001ab

+ 5 - 1
README.md

@@ -159,10 +159,14 @@ rembg i -a path/to/input.png path/to/output.png
 Passing extras parameters
 
 ```
-rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
+SAM example
+
+rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png
 ```
 
 ```
+Custom model example
+
 rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
 ```
 

二進制
examples/plants-1.jpg


二進制
examples/plants-1.out.png


+ 179 - 68
rembg/sessions/sam.py

@@ -1,9 +1,12 @@
 import os
+from copy import deepcopy
 from typing import List
 
+import cv2
 import numpy as np
 import onnxruntime as ort
 import pooch
+from jsonschema import validate
 from PIL import Image
 from PIL.Image import Image as PILImage
 
@@ -15,37 +18,58 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
     newh, neww = oldh * scale, oldw * scale
     neww = int(neww + 0.5)
     newh = int(newh + 0.5)
+
     return (newh, neww)
 
 
-def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
+def apply_coords(coords: np.ndarray, original_size, target_length):
     old_h, old_w = original_size
     new_h, new_w = get_preprocess_shape(
         original_size[0], original_size[1], target_length
     )
-    coords = coords.copy().astype(float)
+
+    coords = deepcopy(coords).astype(float)
     coords[..., 0] = coords[..., 0] * (new_w / old_w)
     coords[..., 1] = coords[..., 1] * (new_h / old_h)
+
     return coords
 
 
-def resize_longes_side(img: PILImage, size=1024):
-    w, h = img.size
-    if h > w:
-        new_h, new_w = size, int(w * size / h)
-    else:
-        new_h, new_w = int(h * size / w), size
+def get_input_points(prompt):
+    points = []
+    labels = []
+
+    for mark in prompt:
+        if mark["type"] == "point":
+            points.append(mark["data"])
+            labels.append(mark["label"])
+        elif mark["type"] == "rectangle":
+            points.append([mark["data"][0], mark["data"][1]])
+            points.append([mark["data"][2], mark["data"][3]])
+            labels.append(2)
+            labels.append(3)
 
-    return img.resize((new_w, new_h))
+    points, labels = np.array(points), np.array(labels)
+    return points, labels
 
 
-def pad_to_square(img: np.ndarray, size=1024):
-    h, w = img.shape[:2]
-    padh = size - h
-    padw = size - w
-    img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
-    img = img.astype(np.float32)
-    return img
+def transform_masks(masks, original_size, transform_matrix):
+    output_masks = []
+
+    for batch in range(masks.shape[0]):
+        batch_masks = []
+        for mask_id in range(masks.shape[1]):
+            mask = masks[batch, mask_id]
+            mask = cv2.warpAffine(
+                mask,
+                transform_matrix[:2],
+                (original_size[1], original_size[0]),
+                flags=cv2.INTER_LINEAR,
+            )
+            batch_masks.append(mask)
+        output_masks.append(batch_masks)
+
+    return np.array(output_masks)
 
 
 class SamSession(BaseSession):
@@ -70,7 +94,7 @@ class SamSession(BaseSession):
             **kwargs: Arbitrary keyword arguments.
         """
         self.model_name = model_name
-        paths = self.__class__.download_models()
+        paths = self.__class__.download_models(*args, **kwargs)
         self.encoder = ort.InferenceSession(
             str(paths[0]),
             providers=ort.get_available_providers(),
@@ -85,9 +109,9 @@ class SamSession(BaseSession):
     def normalize(
         self,
         img: np.ndarray,
-        mean=(123.675, 116.28, 103.53),
-        std=(58.395, 57.12, 57.375),
-        size=(1024, 1024),
+        mean=(),
+        std=(),
+        size=(),
         *args,
         **kwargs,
     ):
@@ -96,19 +120,16 @@ class SamSession(BaseSession):
 
         Args:
             img (np.ndarray): The input image.
-            mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53).
-            std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375).
-            size (tuple, optional): The target size of the image. Defaults to (1024, 1024).
+            mean (tuple, optional): The mean values for normalization. Defaults to ().
+            std (tuple, optional): The standard deviation values for normalization. Defaults to ().
+            size (tuple, optional): The target size of the image. Defaults to ().
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.
 
         Returns:
             np.ndarray: The normalized image.
         """
-        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
-        pixel_std = np.array([*std]).reshape(1, 1, -1)
-        x = (img - pixel_mean) / pixel_std
-        return x
+        return img
 
     def predict(
         self,
@@ -129,36 +150,89 @@ class SamSession(BaseSession):
         Returns:
             List[PILImage]: A list of masks generated by the decoder.
         """
-        # Preprocess image
-        image = resize_longes_side(img)
-        image = np.array(image)
-        image = self.normalize(image)
-        image = pad_to_square(image)
-
-        input_labels = kwargs.get("input_labels")
-        input_points = kwargs.get("input_points")
-
-        if input_labels is None:
-            raise ValueError("input_labels is required")
-        if input_points is None:
-            raise ValueError("input_points is required")
-
-        # Transpose
-        image = image.transpose(2, 0, 1)[None, :, :, :]
-        # Run encoder (Image embedding)
-        encoded = self.encoder.run(None, {"x": image})
-        image_embedding = encoded[0]
-
-        # Add a batch index, concatenate a padding point, and transform.
+        prompt = kwargs.get("sam_prompt", "{}")
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"type": "string"},
+                    "label": {"type": "integer"},
+                    "data": {
+                        "type": "array",
+                        "items": {"type": "number"},
+                    },
+                },
+            },
+        }
+
+        validate(instance=prompt, schema=schema)
+
+        target_size = 1024
+        input_size = (684, 1024)
+        encoder_input_name = self.encoder.get_inputs()[0].name
+
+        img = img.convert("RGB")
+        cv_image = np.array(img)
+        original_size = cv_image.shape[:2]
+
+        scale_x = input_size[1] / cv_image.shape[1]
+        scale_y = input_size[0] / cv_image.shape[0]
+        scale = min(scale_x, scale_y)
+
+        transform_matrix = np.array(
+            [
+                [scale, 0, 0],
+                [0, scale, 0],
+                [0, 0, 1],
+            ]
+        )
+
+        cv_image = cv2.warpAffine(
+            cv_image,
+            transform_matrix[:2],
+            (input_size[1], input_size[0]),
+            flags=cv2.INTER_LINEAR,
+        )
+
+        ## encoder
+
+        encoder_inputs = {
+            encoder_input_name: cv_image.astype(np.float32),
+        }
+
+        encoder_output = self.encoder.run(None, encoder_inputs)
+        image_embedding = encoder_output[0]
+
+        embedding = {
+            "image_embedding": image_embedding,
+            "original_size": original_size,
+            "transform_matrix": transform_matrix,
+        }
+
+        ## decoder
+
+        input_points, input_labels = get_input_points(prompt)
         onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
             None, :, :
         ]
         onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
             None, :
         ].astype(np.float32)
-        onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
+        onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
+            np.float32
+        )
+
+        onnx_coord = np.concatenate(
+            [
+                onnx_coord,
+                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
+            ],
+            axis=2,
+        )
+        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
+        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
 
-        # Create an empty mask input and an indicator for no mask.
         onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
         onnx_has_mask_input = np.zeros(1, dtype=np.float32)
 
@@ -168,17 +242,19 @@ class SamSession(BaseSession):
             "point_labels": onnx_label,
             "mask_input": onnx_mask_input,
             "has_mask_input": onnx_has_mask_input,
-            "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
+            "orig_im_size": np.array(input_size, dtype=np.float32),
         }
 
-        masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
-        masks = masks > 0.0
-        masks = [
-            Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
-            for i in range(masks.shape[0])
-        ]
+        masks, _, _ = self.decoder.run(None, decoder_inputs)
+        inv_transform_matrix = np.linalg.inv(transform_matrix)
+        masks = transform_masks(masks, original_size, inv_transform_matrix)
+
+        mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
+        for m in masks[0, :, :, :]:
+            mask[m > 0.0] = [255, 255, 255]
 
-        return masks
+        mask = Image.fromarray(mask).convert("L")
+        return [mask]
 
     @classmethod
     def download_models(cls, *args, **kwargs):
@@ -195,29 +271,64 @@ class SamSession(BaseSession):
         Returns:
             tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
         """
-        fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
-        fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
+        model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
+        quant = kwargs.get("sam_quant", False)
+
+        fname_encoder = f"{model_name}.encoder.onnx"
+        fname_decoder = f"{model_name}.decoder.onnx"
+
+        if quant:
+            fname_encoder = f"{model_name}.encoder.quant.onnx"
+            fname_decoder = f"{model_name}.decoder.quant.onnx"
 
         pooch.retrieve(
-            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
+            None,
             fname=fname_encoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 
         pooch.retrieve(
-            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
+            None,
             fname=fname_decoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )
 
+        if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
+            os.path.join(
+                cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
+            )
+        ):
+            content = bytearray()
+
+            for i in range(1, 4):
+                pooch.retrieve(
+                    f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    None,
+                    fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    path=cls.u2net_home(*args, **kwargs),
+                    progressbar=True,
+                )
+
+                fbin = os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                )
+                content.extend(open(fbin, "rb").read())
+                os.remove(fbin)
+
+            with open(
+                os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    "sam_vit_h_4b8939.encoder_data.bin",
+                ),
+                "wb",
+            ) as fp:
+                fp.write(content)
+
         return (
             os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
             os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),

+ 1 - 0
setup.py

@@ -12,6 +12,7 @@ here = pathlib.Path(__file__).parent.resolve()
 long_description = (here / "README.md").read_text(encoding="utf-8")
 
 install_requires = [
+    "jsonschema",
     "numpy",
     "onnxruntime",
     "opencv-python-headless",

二進制
tests/fixtures/plants-1.jpg


二進制
tests/results/anime-girl-1.sam.png


二進制
tests/results/car-1.sam.png


二進制
tests/results/cloth-1.sam.png


二進制
tests/results/plants-1.isnet-anime.png


二進制
tests/results/plants-1.isnet-general-use.png


二進制
tests/results/plants-1.sam.png


二進制
tests/results/plants-1.silueta.png


二進制
tests/results/plants-1.u2net.png


二進制
tests/results/plants-1.u2net_cloth_seg.png


二進制
tests/results/plants-1.u2net_human_seg.png


二進制
tests/results/plants-1.u2netp.png


+ 8 - 7
tests/test_remove.py

@@ -12,18 +12,19 @@ def test_remove():
     kwargs = {
         "sam": {
             "anime-girl-1" : {
-                "input_points": [[400, 165]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [400, 165], "label": 1}],
             },
 
             "car-1" : {
-                "input_points": [[250, 200]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [250, 200], "label": 1}],
             },
 
             "cloth-1" : {
-                "input_points": [[370, 495]],
-                "input_labels": [1],
+                "sam_prompt" :[{"type": "point", "data": [370, 495], "label": 1}],
+            },
+
+            "plants-1" : {
+                "sam_prompt" :[{"type": "point", "data": [724, 740], "label": 1}],
             },
         }
     }
@@ -38,7 +39,7 @@ def test_remove():
         "isnet-anime",
         "sam"
     ]:
-        for picture in ["anime-girl-1", "car-1", "cloth-1"]:
+        for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
             image_path = Path(here / "fixtures" / f"{picture}.jpg")
             image = image_path.read_bytes()