|
@@ -1,9 +1,11 @@
|
|
import io
|
|
import io
|
|
from typing import Optional
|
|
from typing import Optional
|
|
|
|
+from typing import Union
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
import onnxruntime as ort
|
|
from PIL import Image
|
|
from PIL import Image
|
|
|
|
+from PIL.Image import Image as PILImage
|
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
|
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
from pymatting.util.util import stack_images
|
|
from pymatting.util.util import stack_images
|
|
@@ -65,7 +67,7 @@ def naive_cutout(img: Image, mask: Image) -> Image:
|
|
|
|
|
|
|
|
|
|
def remove(
|
|
def remove(
|
|
- data: bytes,
|
|
|
|
|
|
+ data: Union[bytes, PILImage],
|
|
alpha_matting: bool = False,
|
|
alpha_matting: bool = False,
|
|
alpha_matting_foreground_threshold: int = 240,
|
|
alpha_matting_foreground_threshold: int = 240,
|
|
alpha_matting_background_threshold: int = 10,
|
|
alpha_matting_background_threshold: int = 10,
|
|
@@ -73,7 +75,12 @@ def remove(
|
|
session: Optional[ort.InferenceSession] = None,
|
|
session: Optional[ort.InferenceSession] = None,
|
|
only_mask: bool = False,
|
|
only_mask: bool = False,
|
|
) -> bytes:
|
|
) -> bytes:
|
|
- img = Image.open(io.BytesIO(data)).convert("RGB")
|
|
|
|
|
|
+ if isinstance(data, PILImage):
|
|
|
|
+ img = data.convert("RGB")
|
|
|
|
+ elif isinstance(data, bytes):
|
|
|
|
+ img = Image.open(io.BytesIO(data)).convert("RGB")
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError('Input type {} is not supported.'.format(type(data)))
|
|
|
|
|
|
if session is None:
|
|
if session is None:
|
|
session = ort_session("u2net")
|
|
session = ort_session("u2net")
|