2
0
Эх сурвалжийг харах

If PIL.Image is entered, PIL.Image will be the return value

iory 3 жил өмнө
parent
commit
fe8028229d
1 өөрчлөгдсөн 8 нэмэгдсэн , 4 устгасан
  1. 8 4
      rembg/bg.py

+ 8 - 4
rembg/bg.py

@@ -1,6 +1,5 @@
 import io
 import io
-from typing import Optional
-from typing import Union
+from typing import Optional, Union
 
 
 import numpy as np
 import numpy as np
 import onnxruntime as ort
 import onnxruntime as ort
@@ -74,13 +73,15 @@ def remove(
     alpha_matting_erode_size: int = 10,
     alpha_matting_erode_size: int = 10,
     session: Optional[ort.InferenceSession] = None,
     session: Optional[ort.InferenceSession] = None,
     only_mask: bool = False,
     only_mask: bool = False,
-) -> bytes:
+) -> Union[bytes, PILImage]:
+    return_type = "bytes"
     if isinstance(data, PILImage):
     if isinstance(data, PILImage):
+        return_type = "pillow"
         img = data.convert("RGB")
         img = data.convert("RGB")
     elif isinstance(data, bytes):
     elif isinstance(data, bytes):
         img = Image.open(io.BytesIO(data)).convert("RGB")
         img = Image.open(io.BytesIO(data)).convert("RGB")
     else:
     else:
-        raise ValueError('Input type {} is not supported.'.format(type(data)))
+        raise ValueError("Input type {} is not supported.".format(type(data)))
 
 
     if session is None:
     if session is None:
         session = ort_session("u2net")
         session = ort_session("u2net")
@@ -105,6 +106,9 @@ def remove(
     else:
     else:
         cutout = naive_cutout(img, mask)
         cutout = naive_cutout(img, mask)
 
 
+    if return_type == "pillow":
+        return cutout
+
     bio = io.BytesIO()
     bio = io.BytesIO()
     cutout.save(bio, "PNG")
     cutout.save(bio, "PNG")
     bio.seek(0)
     bio.seek(0)