Daniel Gatis 3 éve
szülő
commit
77ce4d7c4e
4 módosított fájl, 13 hozzáadás és 10 törlés
  1. 6 6
      rembg/bg.py
  2. 3 2
      rembg/session_base.py
  3. 2 1
      rembg/session_cloth.py
  4. 2 1
      rembg/session_simple.py

+ 6 - 6
rembg/bg.py

@@ -21,12 +21,12 @@ class ReturnType(Enum):
 
 
 def alpha_matting_cutout(
-    img: Image,
-    mask: Image,
+    img: PILImage,
+    mask: PILImage,
     foreground_threshold: int,
     background_threshold: int,
     erode_structure_size: int,
-) -> Image:
+) -> PILImage:
     img = np.asarray(img)
     mask = np.asarray(mask)
 
@@ -59,20 +59,20 @@ def alpha_matting_cutout(
     return cutout
 
 
-def naive_cutout(img: Image, mask: Image) -> Image:
+def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
     empty = Image.new("RGBA", (img.size), 0)
     cutout = Image.composite(img, empty, mask)
     return cutout
 
 
-def get_concat_v_multi(imgs: List[Image]) -> Image:
+def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
     pivot = imgs.pop(0)
     for im in imgs:
         pivot = get_concat_v(pivot, im)
     return pivot
 
 
-def get_concat_v(img1: Image, img2: Image) -> Image:
+def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
     dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
     dst.paste(img1, (0, 0))
     dst.paste(img2, (0, img1.height))

+ 3 - 2
rembg/session_base.py

@@ -3,6 +3,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import onnxruntime as ort
 from PIL import Image
+from PIL.Image import Image as PILImage
 
 
 class BaseSession:
@@ -12,7 +13,7 @@ class BaseSession:
 
     def normalize(
         self,
-        img: Image,
+        img: PILImage,
         mean: Tuple[float, float, float],
         std: Tuple[float, float, float],
         size: Tuple[int, int],
@@ -35,5 +36,5 @@ class BaseSession:
             .astype(np.float32)
         }
 
-    def predict(self, im: Image) -> List[Image]:
+    def predict(self, img: PILImage) -> List[PILImage]:
         raise NotImplementedError

+ 2 - 1
rembg/session_cloth.py

@@ -2,6 +2,7 @@ from typing import List
 
 import numpy as np
 from PIL import Image
+from PIL.Image import Image as PILImage
 from scipy.special import log_softmax
 
 from .session_base import BaseSession
@@ -53,7 +54,7 @@ pallete3 = [
 
 
 class ClothSession(BaseSession):
-    def predict(self, img: Image) -> List[Image]:
+    def predict(self, img: PILImage) -> List[PILImage]:
         ort_outs = self.inner_session.run(
             None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
         )

+ 2 - 1
rembg/session_simple.py

@@ -2,12 +2,13 @@ from typing import List
 
 import numpy as np
 from PIL import Image
+from PIL.Image import Image as PILImage
 
 from .session_base import BaseSession
 
 
 class SimpleSession(BaseSession):
-    def predict(self, img: Image) -> List[Image]:
+    def predict(self, img: PILImage) -> List[PILImage]:
         ort_outs = self.inner_session.run(
             None,
             self.normalize(