浏览代码

reordered imports

Flippchen 2 年之前
父节点
当前提交
72d1c6c64c
共有 2 个文件被更改,包括 4 次插入6 次删除
  1. 1 1
      rembg/session_factory.py
  2. 3 5
      rembg/session_sam.py

+ 1 - 1
rembg/session_factory.py

@@ -11,8 +11,8 @@ import pooch
 from .session_base import BaseSession
 from .session_base import BaseSession
 from .session_cloth import ClothSession
 from .session_cloth import ClothSession
 from .session_dis import DisSession
 from .session_dis import DisSession
-from .session_simple import SimpleSession
 from .session_sam import SamSession
 from .session_sam import SamSession
+from .session_simple import SimpleSession
 
 
 
 
 def download_model(url: str, md5: str, fname: str, path: Path):
 def download_model(url: str, md5: str, fname: str, path: Path):

+ 3 - 5
rembg/session_sam.py

@@ -1,11 +1,9 @@
 from typing import List
 from typing import List
 
 
-import numpy
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
 from PIL.Image import Image as PILImage
 from PIL.Image import Image as PILImage
 import onnxruntime as ort
 import onnxruntime as ort
-from matplotlib import pyplot as plt
 
 
 from .session_base import BaseSession
 from .session_base import BaseSession
 
 
@@ -39,7 +37,7 @@ def resize_longes_side(img: PILImage, size=1024):
     return img.resize((new_w, new_h))
     return img.resize((new_w, new_h))
 
 
 
 
-def pad_to_square(img: numpy.ndarray, size=1024):
+def pad_to_square(img: np.ndarray, size=1024):
     h, w = img.shape[:2]
     h, w = img.shape[:2]
     padh = size - h
     padh = size - h
     padw = size - w
     padw = size - w
@@ -60,7 +58,7 @@ class SamSession(BaseSession):
 
 
     def normalize(
     def normalize(
         self,
         self,
-        img: numpy.ndarray,
+        img: np.ndarray,
         mean=(123.675, 116.28, 103.53),
         mean=(123.675, 116.28, 103.53),
         std=(58.395, 57.12, 57.375),
         std=(58.395, 57.12, 57.375),
         size=(1024, 1024),
         size=(1024, 1024),
@@ -78,7 +76,7 @@ class SamSession(BaseSession):
     ) -> List[PILImage]:
     ) -> List[PILImage]:
         # Preprocess image
         # Preprocess image
         image = resize_longes_side(img)
         image = resize_longes_side(img)
-        image = numpy.array(image)
+        image = np.array(image)
         image = self.normalize(image)
         image = self.normalize(image)
         image = pad_to_square(image)
         image = pad_to_square(image)