|
@@ -17,6 +17,7 @@ from .detect import ort_session, predict
|
|
|
class ReturnType(Enum):
|
|
|
BYTES = 0
|
|
|
PILLOW = 1
|
|
|
+ NDARRAY = 2
|
|
|
|
|
|
|
|
|
def alpha_matting_cutout(
|
|
@@ -65,27 +66,31 @@ def naive_cutout(img: Image, mask: Image) -> Image:
|
|
|
|
|
|
|
|
|
def remove(
|
|
|
- data: Union[bytes, PILImage],
|
|
|
+ data: Union[bytes, PILImage, np.ndarray],
|
|
|
alpha_matting: bool = False,
|
|
|
alpha_matting_foreground_threshold: int = 240,
|
|
|
alpha_matting_background_threshold: int = 10,
|
|
|
alpha_matting_erode_size: int = 10,
|
|
|
session: Optional[ort.InferenceSession] = None,
|
|
|
only_mask: bool = False,
|
|
|
-) -> Union[bytes, PILImage]:
|
|
|
+) -> Union[bytes, PILImage, np.ndarray]:
|
|
|
+
|
|
|
if isinstance(data, PILImage):
|
|
|
return_type = ReturnType.PILLOW
|
|
|
- img = data.convert("RGB")
|
|
|
+ img = data
|
|
|
elif isinstance(data, bytes):
|
|
|
return_type = ReturnType.BYTES
|
|
|
- img = Image.open(io.BytesIO(data)).convert("RGB")
|
|
|
+ img = Image.open(io.BytesIO(data))
|
|
|
+ elif isinstance(data, np.ndarray):
|
|
|
+ return_type = ReturnType.NDARRAY
|
|
|
+ img = Image.fromarray(data)
|
|
|
else:
|
|
|
raise ValueError("Input type {} is not supported.".format(type(data)))
|
|
|
|
|
|
if session is None:
|
|
|
session = ort_session("u2net")
|
|
|
|
|
|
- mask = predict(session, np.array(img)).convert("L")
|
|
|
+ mask = predict(session, np.array(img.convert("RGB"))).convert("L")
|
|
|
mask = mask.resize(img.size, Image.LANCZOS)
|
|
|
|
|
|
if only_mask:
|
|
@@ -105,6 +110,9 @@ def remove(
|
|
|
if ReturnType.PILLOW == return_type:
|
|
|
return cutout
|
|
|
|
|
|
+ if ReturnType.NDARRAY == return_type:
|
|
|
+ return np.asarray(cutout)
|
|
|
+
|
|
|
bio = io.BytesIO()
|
|
|
cutout.save(bio, "PNG")
|
|
|
bio.seek(0)
|