Flippchen преди 2 години
родител
ревизия
72d1c6c64c
променени са 2 файла, в които са добавени 4 реда и са изтрити 6 реда
  1. 1 1
      rembg/session_factory.py
  2. 3 5
      rembg/session_sam.py

+ 1 - 1
rembg/session_factory.py

@@ -11,8 +11,8 @@ import pooch
 from .session_base import BaseSession
 from .session_cloth import ClothSession
 from .session_dis import DisSession
-from .session_simple import SimpleSession
 from .session_sam import SamSession
+from .session_simple import SimpleSession
 
 
 def download_model(url: str, md5: str, fname: str, path: Path):

+ 3 - 5
rembg/session_sam.py

@@ -1,11 +1,9 @@
 from typing import List
 
-import numpy
 import numpy as np
 from PIL import Image
 from PIL.Image import Image as PILImage
 import onnxruntime as ort
-from matplotlib import pyplot as plt
 
 from .session_base import BaseSession
 
@@ -39,7 +37,7 @@ def resize_longes_side(img: PILImage, size=1024):
     return img.resize((new_w, new_h))
 
 
-def pad_to_square(img: numpy.ndarray, size=1024):
+def pad_to_square(img: np.ndarray, size=1024):
     h, w = img.shape[:2]
     padh = size - h
     padw = size - w
@@ -60,7 +58,7 @@ class SamSession(BaseSession):
 
     def normalize(
         self,
-        img: numpy.ndarray,
+        img: np.ndarray,
         mean=(123.675, 116.28, 103.53),
         std=(58.395, 57.12, 57.375),
         size=(1024, 1024),
@@ -78,7 +76,7 @@ class SamSession(BaseSession):
     ) -> List[PILImage]:
         # Preprocess image
         image = resize_longes_side(img)
-        image = numpy.array(image)
+        image = np.array(image)
         image = self.normalize(image)
         image = pad_to_square(image)