Daniel Gatis 3 years ago
parent
commit
034bac2388
2 changed files with 27 additions and 25 deletions
  1. 21 19
      rembg/bg.py
  2. 6 6
      rembg/detect.py

+ 21 - 19
rembg/bg.py

@@ -1,6 +1,8 @@
 import io
+from typing import Optional
 
 import numpy as np
+import onnxruntime as ort
 from PIL import Image
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
 from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
@@ -11,13 +13,13 @@ from .detect import ort_session, predict
 
 
 def alpha_matting_cutout(
-    img,
-    mask,
-    foreground_threshold,
-    background_threshold,
-    erode_structure_size,
-    base_size,
-):
+    img: Image,
+    mask: Image,
+    foreground_threshold: int,
+    background_threshold: int,
+    erode_structure_size: int,
+    base_size: int,
+) -> Image:
     size = img.size
 
     img.thumbnail((base_size, base_size), Image.LANCZOS)
@@ -61,13 +63,13 @@ def alpha_matting_cutout(
     return cutout
 
 
-def naive_cutout(img, mask):
+def naive_cutout(img: Image, mask: Image) -> Image:
     empty = Image.new("RGBA", (img.size), 0)
     cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
     return cutout
 
 
-def resize_image(img, width, height):
+def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Image:
     original_width, original_height = img.size
     width = original_width if width is None else width
     height = original_height if height is None else height
@@ -79,16 +81,16 @@ def resize_image(img, width, height):
 
 
 def remove(
-    data,
-    session=None,
-    alpha_matting=False,
-    alpha_matting_foreground_threshold=240,
-    alpha_matting_background_threshold=10,
-    alpha_matting_erode_structure_size=10,
-    alpha_matting_base_size=1000,
-    width=None,
-    height=None,
-):
+    data: bytes,
+    session: Optional[ort.InferenceSession] = None,
+    alpha_matting: bool = False,
+    alpha_matting_foreground_threshold: int = 240,
+    alpha_matting_background_threshold: int = 10,
+    alpha_matting_erode_structure_size: int = 10,
+    alpha_matting_base_size: int = 1000,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+) -> bytes:
     img = Image.open(io.BytesIO(data)).convert("RGB")
     if width is not None or height is not None:
         img = resize_image(img, width, height)

+ 6 - 6
rembg/detect.py

@@ -8,7 +8,7 @@ from PIL import Image
 from skimage import transform
 
 
-def ort_session(model_name):
+def ort_session(model_name: str) -> ort.InferenceSession:
     path = os.environ.get(
         "U2NETP_PATH",
         os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
@@ -30,7 +30,7 @@ def ort_session(model_name):
     return ort.InferenceSession(path)
 
 
-def norm_pred(d):
+def norm_pred(d: np.array) -> np.array:
     ma = np.max(d)
     mi = np.min(d)
     dn = (d - mi) / (ma - mi)
@@ -38,7 +38,7 @@ def norm_pred(d):
     return dn
 
 
-def rescale(sample, output_size):
+def rescale(sample: dict, output_size: int) -> dict:
     imidx, image, label = sample["imidx"], sample["image"], sample["label"]
 
     h, w = image.shape[:2]
@@ -65,7 +65,7 @@ def rescale(sample, output_size):
     return {"imidx": imidx, "image": img, "label": lbl}
 
 
-def color(sample):
+def color(sample: dict) -> dict:
     imidx, image, label = sample["imidx"], sample["image"], sample["label"]
 
     tmpLbl = np.zeros(label.shape)
@@ -93,7 +93,7 @@ def color(sample):
     return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
 
 
-def preprocess(image):
+def preprocess(image: np.array) -> dict:
     label_3 = np.zeros(image.shape)
     label = np.zeros(label_3.shape[0:2])
 
@@ -115,7 +115,7 @@ def preprocess(image):
     return sample
 
 
-def predict(ort_session, item):
+def predict(ort_session: ort.InferenceSession, item: np.array) -> Image:
     sample = preprocess(item)
     inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)