iory преди 3 години
родител
ревизия
194215c792
променени са 1 файла, в които са добавени 9 реда и са изтрити 2 реда
  1. 9 2
      rembg/bg.py

+ 9 - 2
rembg/bg.py

@@ -1,9 +1,11 @@
 import io
 import io
 from typing import Optional
 from typing import Optional
+from typing import Union
 
 
 import numpy as np
 import numpy as np
 import onnxruntime as ort
 import onnxruntime as ort
 from PIL import Image
 from PIL import Image
+from PIL.Image import Image as PILImage
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
 from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
 from pymatting.util.util import stack_images
 from pymatting.util.util import stack_images
@@ -65,7 +67,7 @@ def naive_cutout(img: Image, mask: Image) -> Image:
 
 
 
 
 def remove(
 def remove(
-    data: bytes,
+    data: Union[bytes, PILImage],
     alpha_matting: bool = False,
     alpha_matting: bool = False,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_foreground_threshold: int = 240,
     alpha_matting_background_threshold: int = 10,
     alpha_matting_background_threshold: int = 10,
@@ -73,7 +75,12 @@ def remove(
     session: Optional[ort.InferenceSession] = None,
     session: Optional[ort.InferenceSession] = None,
     only_mask: bool = False,
     only_mask: bool = False,
 ) -> bytes:
 ) -> bytes:
-    img = Image.open(io.BytesIO(data)).convert("RGB")
+    if isinstance(data, PILImage):
+        img = data.convert("RGB")
+    elif isinstance(data, bytes):
+        img = Image.open(io.BytesIO(data)).convert("RGB")
+    else:
+        raise ValueError('Input type {} is not supported.'.format(type(data)))
 
 
     if session is None:
     if session is None:
         session = ort_session("u2net")
         session = ort_session("u2net")