瀏覽代碼

If PIL.Image is entered, PIL.Image will be the return value

iory 3 年之前
父節點
當前提交
fe8028229d
共有 1 個文件被更改,包括 8 次插入4 次删除
  1. 8 4
      rembg/bg.py

+ 8 - 4
rembg/bg.py

@@ -1,6 +1,5 @@
 import io
-from typing import Optional
-from typing import Union
+from typing import Optional, Union
 
 import numpy as np
 import onnxruntime as ort
@@ -74,13 +73,15 @@ def remove(
     alpha_matting_erode_size: int = 10,
     session: Optional[ort.InferenceSession] = None,
     only_mask: bool = False,
-) -> bytes:
+) -> Union[bytes, PILImage]:
+    return_type = "bytes"
     if isinstance(data, PILImage):
+        return_type = "pillow"
         img = data.convert("RGB")
     elif isinstance(data, bytes):
         img = Image.open(io.BytesIO(data)).convert("RGB")
     else:
-        raise ValueError('Input type {} is not supported.'.format(type(data)))
+        raise ValueError("Input type {} is not supported.".format(type(data)))
 
     if session is None:
         session = ort_session("u2net")
@@ -105,6 +106,9 @@ def remove(
     else:
         cutout = naive_cutout(img, mask)
 
+    if return_type == "pillow":
+        return cutout
+
     bio = io.BytesIO()
     cutout.save(bio, "PNG")
     bio.seek(0)