Flippchen committed 2 years ago
commit bb3c58f411
2 changed files with 41 additions and 11 deletions
  1. rembg/session_factory.py (+10 −3)
  2. rembg/session_sam.py (+31 −8)

+ 10 - 3
rembg/session_factory.py

@@ -75,7 +75,6 @@ def new_session(model_name: str = "u2net") -> BaseSession:
         decoder_md5 = "fa3d1c36a3187d3de1c8deebf33dd127"
         decoder_url = "https://github.com/Flippchen/rembg/releases/download/test/vit_b-decoder-quant.onnx"
 
-
         download_model(encoder_url, encoder_md5, fname_encoder, path)
         download_model(decoder_url, decoder_md5, fname_decoder, path)
 
@@ -86,8 +85,16 @@ def new_session(model_name: str = "u2net") -> BaseSession:
 
         return SamSession(
             model_name,
-            ort.InferenceSession(str(path / fname_encoder), providers=ort.get_available_providers(), sess_options=sess_opts),
-            ort.InferenceSession(str(path / fname_decoder), providers=ort.get_available_providers(), sess_options=sess_opts)
+            ort.InferenceSession(
+                str(path / fname_encoder),
+                providers=ort.get_available_providers(),
+                sess_options=sess_opts
+            ),
+            ort.InferenceSession(
+                str(path / fname_decoder),
+                providers=ort.get_available_providers(),
+                sess_options=sess_opts
+            ),
         )
 
     download_model(url, md5, fname, path)
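
For context (not part of the commit): a minimal sketch of how the updated factory path would be used. It assumes this fork registers the SAM model under the name "sam"; the registered model-name string is not visible in these hunks.

    from rembg.session_factory import new_session

    # Assumed model name "sam". On first use this branch downloads the
    # quantized vit_b encoder/decoder ONNX files referenced above and wraps
    # them in two ort.InferenceSession objects inside a SamSession.
    session = new_session("sam")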

+ 31 - 8
rembg/session_sam.py

@@ -43,23 +43,39 @@ def pad_to_square(img: numpy.ndarray, size=1024):
     h, w = img.shape[:2]
     padh = size - h
     padw = size - w
-    img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode='constant')
+    img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
     img = img.astype(np.float32)
     return img
 
 
 class SamSession(BaseSession):
-    def __init__(self, model_name: str, encoder: ort.InferenceSession, decoder: ort.InferenceSession):
+    def __init__(
+        self,
+        model_name: str,
+        encoder: ort.InferenceSession,
+        decoder: ort.InferenceSession
+    ):
         super().__init__(model_name, encoder)
         self.decoder = decoder
 
-    def normalize(self, img: numpy.ndarray, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), size=(1024, 1024)):
+    def normalize(
+        self,
+        img: numpy.ndarray,
+        mean=(0.485, 0.456, 0.406),
+        std=(0.229, 0.224, 0.225),
+        size=(1024, 1024)
+    ):
         pixel_mean = np.array([123.675, 116.28, 103.53]).reshape(1, 1, -1)
         pixel_std = np.array([58.395, 57.12, 57.375]).reshape(1, 1, -1)
         x = (img - pixel_mean) / pixel_std
         return x
 
-    def predict(self, img: PILImage, input_point=np.array([[500, 375]]), input_label=np.array([1])) -> List[PILImage]:
+    def predict(
+        self,
+        img: PILImage,
+        input_point=np.array([[500, 375]]),
+        input_label=np.array([1])
+    ) -> List[PILImage]:
         # Preprocess image
         image = resize_longes_side(img)
         image = numpy.array(image)
@@ -73,8 +89,12 @@ class SamSession(BaseSession):
         image_embedding = encoded[0]
 
         # Add a batch index, concatenate a padding point, and transform.
-        onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
-        onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
+        onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
+                     None, :, :
+        ]
+        onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
+                     None, :
+        ].astype(np.float32)
         onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
 
         # Create an empty mask input and an indicator for no mask.
@@ -87,11 +107,14 @@ class SamSession(BaseSession):
             "point_labels": onnx_label,
             "mask_input": onnx_mask_input,
             "has_mask_input": onnx_has_mask_input,
-            "orig_im_size": np.array(img.size[::-1], dtype=np.float32)
+            "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
         }
 
         masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
         masks = masks > 0.0
-        masks = [Image.fromarray((masks[i, 0] * 255).astype(np.uint8)) for i in range(masks.shape[0])]
+        masks = [
+            Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
+            for i in range(masks.shape[0])
+        ]
 
         return masks
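
For context (not part of the commit): a minimal sketch of calling the reformatted predict(). The image path and prompt point are placeholder values, and `session` is the SamSession from the factory sketch above; predict() returns one PIL mask per decoder output, thresholded at 0.0 and scaled to 0-255.

    import numpy as np
    from PIL import Image

    img = Image.open("example.jpg").convert("RGB")  # placeholder input image
    point = np.array([[500, 375]])                  # (x, y) prompt; same value as the default
    label = np.array([1])                           # 1 marks the point as foreground
    masks = session.predict(img, input_point=point, input_label=label)
    masks[0].save("mask.png")                       # first predicted mask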