Browse Source

fix lint and refactored normalizing

Flippchen 2 years ago
parent
commit
ff38b9a377
2 changed files with 11 additions and 11 deletions
  1. 2 2
      rembg/session_factory.py
  2. 9 9
      rembg/session_sam.py

+ 2 - 2
rembg/session_factory.py

@@ -88,12 +88,12 @@ def new_session(model_name: str = "u2net") -> BaseSession:
             ort.InferenceSession(
                 str(path / fname_encoder),
                 providers=ort.get_available_providers(),
-                sess_options=sess_opts
+                sess_options=sess_opts,
             ),
             ort.InferenceSession(
                 str(path / fname_decoder),
                 providers=ort.get_available_providers(),
-                sess_options=sess_opts
+                sess_options=sess_opts,
             ),
         )
 

+ 9 - 9
rembg/session_sam.py

@@ -53,7 +53,7 @@ class SamSession(BaseSession):
         self,
         model_name: str,
         encoder: ort.InferenceSession,
-        decoder: ort.InferenceSession
+        decoder: ort.InferenceSession,
     ):
         super().__init__(model_name, encoder)
         self.decoder = decoder
@@ -61,12 +61,12 @@ class SamSession(BaseSession):
     def normalize(
         self,
         img: numpy.ndarray,
-        mean=(0.485, 0.456, 0.406),
-        std=(0.229, 0.224, 0.225),
-        size=(1024, 1024)
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        size=(1024, 1024),
     ):
-        pixel_mean = np.array([123.675, 116.28, 103.53]).reshape(1, 1, -1)
-        pixel_std = np.array([58.395, 57.12, 57.375]).reshape(1, 1, -1)
+        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
+        pixel_std = np.array([*std]).reshape(1, 1, -1)
         x = (img - pixel_mean) / pixel_std
         return x
 
@@ -74,7 +74,7 @@ class SamSession(BaseSession):
         self,
         img: PILImage,
         input_point=np.array([[500, 375]]),
-        input_label=np.array([1])
+        input_label=np.array([1]),
     ) -> List[PILImage]:
         # Preprocess image
         image = resize_longes_side(img)
@@ -90,10 +90,10 @@ class SamSession(BaseSession):
 
         # Add a batch index, concatenate a padding point, and transform.
         onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[
-                     None, :, :
+            None, :, :
         ]
         onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[
-                     None, :
+            None, :
         ].astype(np.float32)
         onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)