|
@@ -2,6 +2,10 @@ import io
|
|
|
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
+from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
|
|
+from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
|
+from pymatting.util.util import stack_images
|
|
|
+from scipy.ndimage.morphology import binary_erosion
|
|
|
|
|
|
from .u2net import detect
|
|
|
|
|
@@ -9,20 +13,87 @@ model_u2net = detect.load_model(model_name="u2net")
|
|
|
model_u2netp = detect.load_model(model_name="u2netp")
|
|
|
|
|
|
|
|
|
-def remove(data, model_name="u2net"):
|
|
|
+def alpha_matting_cutout(
|
|
|
+ img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
|
|
+):
|
|
|
+ base_size = (1000, 1000)
|
|
|
+ size = img.size
|
|
|
+
|
|
|
+ img.thumbnail(base_size, Image.LANCZOS)
|
|
|
+ mask = mask.resize(img.size, Image.LANCZOS)
|
|
|
+
|
|
|
+ img = np.asarray(img)
|
|
|
+ mask = np.asarray(mask)
|
|
|
+
|
|
|
+ # guess likely foreground/background
|
|
|
+ is_foreground = mask > foreground_threshold
|
|
|
+ is_background = mask < background_threshold
|
|
|
+
|
|
|
+ # erode foreground/background
|
|
|
+ structure = None
|
|
|
+ if erode_structure_size > 0:
|
|
|
+ structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
|
|
|
+
|
|
|
+ is_foreground = binary_erosion(is_foreground, structure=structure)
|
|
|
+ is_background = binary_erosion(is_background, structure=structure, border_value=1)
|
|
|
+
|
|
|
+ # build trimap
|
|
|
+ # 0 = background
|
|
|
+ # 128 = unknown
|
|
|
+ # 255 = foreground
|
|
|
+ trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
|
|
|
+ trimap[is_foreground] = 255
|
|
|
+ trimap[is_background] = 0
|
|
|
+
|
|
|
+ # build the cutout image
|
|
|
+ img_normalized = img / 255.0
|
|
|
+ trimap_normalized = trimap / 255.0
|
|
|
+
|
|
|
+ alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
|
|
|
+ foreground = estimate_foreground_ml(img_normalized, alpha)
|
|
|
+ cutout = stack_images(foreground, alpha)
|
|
|
+
|
|
|
+ cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
|
|
|
+ cutout = Image.fromarray(cutout)
|
|
|
+ cutout = cutout.resize(size, Image.LANCZOS)
|
|
|
+
|
|
|
+ return cutout
|
|
|
+
|
|
|
+
|
|
|
+def naive_cutout(img, mask):
|
|
|
+ empty = Image.new("RGBA", (img.size), 0)
|
|
|
+ cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
|
|
|
+ return cutout
|
|
|
+
|
|
|
+
|
|
|
+def remove(
|
|
|
+ data,
|
|
|
+ model_name="u2net",
|
|
|
+ alpha_matting=False,
|
|
|
+ alpha_matting_foreground_threshold=235,
|
|
|
+ alpha_matting_background_threshold=15,
|
|
|
+ alpha_matting_erode_structure_size=15,
|
|
|
+):
|
|
|
model = model_u2net
|
|
|
|
|
|
if model == "u2netp":
|
|
|
model = model_u2netp
|
|
|
|
|
|
- img = Image.open(io.BytesIO(data))
|
|
|
- roi = detect.predict(model, np.array(img))
|
|
|
- roi = roi.resize((img.size), resample=Image.LANCZOS)
|
|
|
+ img = Image.open(io.BytesIO(data)).convert("RGB")
|
|
|
+ mask = detect.predict(model, np.array(img)).convert("L")
|
|
|
|
|
|
- empty = Image.new("RGBA", (img.size), 0)
|
|
|
- out = Image.composite(img, empty, roi.convert("L"))
|
|
|
+ if alpha_matting:
|
|
|
+ cutout = alpha_matting_cutout(
|
|
|
+ img,
|
|
|
+ mask,
|
|
|
+ alpha_matting_foreground_threshold,
|
|
|
+ alpha_matting_background_threshold,
|
|
|
+ alpha_matting_erode_structure_size,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ cutout = naive_cutout(img, mask)
|
|
|
|
|
|
bio = io.BytesIO()
|
|
|
- out.save(bio, "PNG")
|
|
|
+ cutout.save(bio, "PNG")
|
|
|
|
|
|
return bio.getbuffer()
|