|
@@ -1,6 +1,5 @@
|
|
|
import io
|
|
|
-from typing import Optional
|
|
|
-from typing import Union
|
|
|
+from typing import Optional, Union
|
|
|
|
|
|
import numpy as np
|
|
|
import onnxruntime as ort
|
|
@@ -74,13 +73,15 @@ def remove(
|
|
|
alpha_matting_erode_size: int = 10,
|
|
|
session: Optional[ort.InferenceSession] = None,
|
|
|
only_mask: bool = False,
|
|
|
-) -> bytes:
|
|
|
+) -> Union[bytes, PILImage]:
|
|
|
+ return_type = "bytes"
|
|
|
if isinstance(data, PILImage):
|
|
|
+ return_type = "pillow"
|
|
|
img = data.convert("RGB")
|
|
|
elif isinstance(data, bytes):
|
|
|
img = Image.open(io.BytesIO(data)).convert("RGB")
|
|
|
else:
|
|
|
- raise ValueError('Input type {} is not supported.'.format(type(data)))
|
|
|
+ raise ValueError("Input type {} is not supported.".format(type(data)))
|
|
|
|
|
|
if session is None:
|
|
|
session = ort_session("u2net")
|
|
@@ -105,6 +106,9 @@ def remove(
|
|
|
else:
|
|
|
cutout = naive_cutout(img, mask)
|
|
|
|
|
|
+ if return_type == "pillow":
|
|
|
+ return cutout
|
|
|
+
|
|
|
bio = io.BytesIO()
|
|
|
cutout.save(bio, "PNG")
|
|
|
bio.seek(0)
|