Explorar el Código

Merge pull request #197 from iory/main

Enable PIL.Image input
Daniel Gatis hace 3 años
padre
commit
7627254b51
Se han modificado 1 ficheros con 15 adiciones y 4 borrados
  1. 15 4
      rembg/bg.py

+ 15 - 4
rembg/bg.py

@@ -1,9 +1,10 @@
 import io
 import io
-from typing import Optional
+from typing import Optional, Union
 
 
 import numpy as np
 import numpy as np
 import onnxruntime as ort
 import onnxruntime as ort
 from PIL import Image
 from PIL import Image
+from PIL.Image import Image as PILImage
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
 from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 from pymatting.util.util import stack_images
 from pymatting.util.util import stack_images
@@ -65,15 +66,22 @@ def naive_cutout(img: Image, mask: Image) -> Image:
 
 
 
 
 def remove(
 def remove(
-    data: bytes,
+    data: Union[bytes, PILImage],
     alpha_matting: bool = False,
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_erode_size: int = 10,
     alpha_matting_erode_size: int = 10,
     session: Optional[ort.InferenceSession] = None,
     session: Optional[ort.InferenceSession] = None,
     only_mask: bool = False,
     only_mask: bool = False,
-) -> bytes:
-    img = Image.open(io.BytesIO(data)).convert("RGB")
+) -> Union[bytes, PILImage]:
+    return_type = "bytes"
+    if isinstance(data, PILImage):
+        return_type = "pillow"
+        img = data.convert("RGB")
+    elif isinstance(data, bytes):
+        img = Image.open(io.BytesIO(data)).convert("RGB")
+    else:
+        raise ValueError("Input type {} is not supported.".format(type(data)))
 
 
     if session is None:
     if session is None:
         session = ort_session("u2net")
         session = ort_session("u2net")
@@ -98,6 +106,9 @@ def remove(
     else:
     else:
         cutout = naive_cutout(img, mask)
         cutout = naive_cutout(img, mask)
 
 
+    if return_type == "pillow":
+        return cutout
+
     bio = io.BytesIO()
     bio = io.BytesIO()
     cutout.save(bio, "PNG")
     cutout.save(bio, "PNG")
     bio.seek(0)
     bio.seek(0)