|
@@ -1,9 +1,10 @@
|
|
import io
|
|
import io
|
|
-from typing import Optional
|
|
|
|
|
|
+from typing import Optional, Union
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
import onnxruntime as ort
|
|
from PIL import Image
|
|
from PIL import Image
|
|
|
|
+from PIL.Image import Image as PILImage
|
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
|
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
from pymatting.util.util import stack_images
|
|
from pymatting.util.util import stack_images
|
|
@@ -65,15 +66,22 @@ def naive_cutout(img: Image, mask: Image) -> Image:
|
|
|
|
|
|
|
|
|
|
def remove(
|
|
def remove(
|
|
- data: bytes,
|
|
|
|
|
|
+ data: Union[bytes, PILImage],
|
|
alpha_matting: bool = False,
|
|
alpha_matting: bool = False,
|
|
alpha_matting_foreground_threshold: int = 240,
|
|
alpha_matting_foreground_threshold: int = 240,
|
|
alpha_matting_background_threshold: int = 10,
|
|
alpha_matting_background_threshold: int = 10,
|
|
alpha_matting_erode_size: int = 10,
|
|
alpha_matting_erode_size: int = 10,
|
|
session: Optional[ort.InferenceSession] = None,
|
|
session: Optional[ort.InferenceSession] = None,
|
|
only_mask: bool = False,
|
|
only_mask: bool = False,
|
|
-) -> bytes:
|
|
|
|
- img = Image.open(io.BytesIO(data)).convert("RGB")
|
|
|
|
|
|
+) -> Union[bytes, PILImage]:
|
|
|
|
+ return_type = "bytes"
|
|
|
|
+ if isinstance(data, PILImage):
|
|
|
|
+ return_type = "pillow"
|
|
|
|
+ img = data.convert("RGB")
|
|
|
|
+ elif isinstance(data, bytes):
|
|
|
|
+ img = Image.open(io.BytesIO(data)).convert("RGB")
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError("Input type {} is not supported.".format(type(data)))
|
|
|
|
|
|
if session is None:
|
|
if session is None:
|
|
session = ort_session("u2net")
|
|
session = ort_session("u2net")
|
|
@@ -98,6 +106,9 @@ def remove(
|
|
else:
|
|
else:
|
|
cutout = naive_cutout(img, mask)
|
|
cutout = naive_cutout(img, mask)
|
|
|
|
|
|
|
|
+ if return_type == "pillow":
|
|
|
|
+ return cutout
|
|
|
|
+
|
|
bio = io.BytesIO()
|
|
bio = io.BytesIO()
|
|
cutout.save(bio, "PNG")
|
|
cutout.save(bio, "PNG")
|
|
bio.seek(0)
|
|
bio.seek(0)
|