Browse Source

fix linters

Daniel Gatis 1 năm trước cách đây
mục cha
commit
a30359986a
4 tập tin đã thay đổi với 22 bổ sung47 xóa
  1. 6 6
      rembg/bg.py
  2. 2 2
      rembg/commands/b_command.py
  3. 2 2
      rembg/commands/p_command.py
  4. 12 37
      rembg/sessions/sam.py

+ 6 - 6
rembg/bg.py

@@ -55,11 +55,11 @@ def alpha_matting_cutout(
     if img.mode == "RGBA" or img.mode == "CMYK":
         img = img.convert("RGB")
 
-    img = np.asarray(img)
-    mask = np.asarray(mask)
+    img_array = np.asarray(img)
+    mask_array = np.asarray(mask)
 
-    is_foreground = mask > foreground_threshold
-    is_background = mask < background_threshold
+    is_foreground = mask_array > foreground_threshold
+    is_background = mask_array < background_threshold
 
     structure = None
     if erode_structure_size > 0:
@@ -70,11 +70,11 @@ def alpha_matting_cutout(
     is_foreground = binary_erosion(is_foreground, structure=structure)
     is_background = binary_erosion(is_background, structure=structure, border_value=1)
 
-    trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
+    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
     trimap[is_foreground] = 255
     trimap[is_background] = 0
 
-    img_normalized = img / 255.0
+    img_normalized = img_array / 255.0
     trimap_normalized = trimap / 255.0
 
     alpha = estimate_alpha_cf(img_normalized, trimap_normalized)

+ 2 - 2
rembg/commands/b_command.py

@@ -6,7 +6,7 @@ import sys
 from typing import IO
 
 import click
-from PIL import Image
+from PIL.Image import Image as PILImage
 
 from ..bg import remove
 from ..session_factory import new_session
@@ -134,7 +134,7 @@ def b_command(
         if not os.path.isdir(output_dir):
             os.makedirs(output_dir, exist_ok=True)
 
-    def img_to_byte_array(img: Image) -> bytes:
+    def img_to_byte_array(img: PILImage) -> bytes:
         buff = io.BytesIO()
         img.save(buff, format="PNG")
         return buff.getvalue()

+ 2 - 2
rembg/commands/p_command.py

@@ -186,9 +186,9 @@ def p_command(
 
     inputs = list(input.glob("**/*"))
     if not watch:
-        inputs = tqdm(inputs)
+        inputs_tqdm = tqdm(inputs)
 
-    for each_input in inputs:
+    for each_input in inputs_tqdm:
         if not each_input.is_dir():
             process(each_input)
 

+ 12 - 37
rembg/sessions/sam.py

@@ -1,6 +1,6 @@
 import os
 from copy import deepcopy
-from typing import List
+from typing import Dict, List, Tuple
 
 import cv2
 import numpy as np
@@ -87,8 +87,9 @@ class SamSession(BaseSession):
         self,
         model_name: str,
         sess_opts: ort.SessionOptions,
+        providers=None,
         *args,
-        **kwargs,
+        **kwargs
     ):
         """
         Initialize a new SamSession with the given model name and session options.
@@ -101,52 +102,27 @@ class SamSession(BaseSession):
         """
         self.model_name = model_name
 
-        self.providers = []
+        valid_providers = []
+        available_providers = ort.get_available_providers()
 
-        _providers = ort.get_available_providers()
-        for provider in kwargs.get("providers", []):
-            if provider in _providers:
-                self.providers.append(provider)
+        for provider in (providers or []):
+            if provider in available_providers:
+                valid_providers.append(provider)
         else:
-            self.providers.extend(_providers)
+            valid_providers.extend(available_providers)
 
         paths = self.__class__.download_models(*args, **kwargs)
         self.encoder = ort.InferenceSession(
             str(paths[0]),
-            providers=self.providers,
+            providers=valid_providers,
             sess_options=sess_opts,
         )
         self.decoder = ort.InferenceSession(
             str(paths[1]),
-            providers=self.providers,
+            providers=valid_providers,
             sess_options=sess_opts,
         )
 
-    def normalize(
-        self,
-        img: np.ndarray,
-        mean=(),
-        std=(),
-        size=(),
-        *args,
-        **kwargs,
-    ):
-        """
-        Normalize the input image by subtracting the mean and dividing by the standard deviation.
-
-        Args:
-            img (np.ndarray): The input image.
-            mean (tuple, optional): The mean values for normalization. Defaults to ().
-            std (tuple, optional): The standard deviation values for normalization. Defaults to ().
-            size (tuple, optional): The target size of the image. Defaults to ().
-            *args: Variable length argument list.
-            **kwargs: Arbitrary keyword arguments.
-
-        Returns:
-            np.ndarray: The normalized image.
-        """
-        return img
-
     def predict(
         self,
         img: PILImage,
@@ -269,8 +245,7 @@ class SamSession(BaseSession):
         for m in masks[0, :, :, :]:
             mask[m > 0.0] = [255, 255, 255]
 
-        mask = Image.fromarray(mask).convert("L")
-        return [mask]
+        return [Image.fromarray(mask).convert("L")]
 
     @classmethod
     def download_models(cls, *args, **kwargs):