|
@@ -105,6 +105,14 @@ def post_process(mask: np.ndarray) -> np.ndarray:
|
|
|
return mask
|
|
|
|
|
|
|
|
|
+def apply_background_color(img: PILImage, color: List[int]) -> PILImage:
|
|
|
+ r, g, b = color
|
|
|
+ colored_image = Image.new("RGBA", img.size, (r, g, b, 255))
|
|
|
+ colored_image.paste(img, mask=img)
|
|
|
+
|
|
|
+ return colored_image
|
|
|
+
|
|
|
+
|
|
|
def remove(
|
|
|
data: Union[bytes, PILImage, np.ndarray],
|
|
|
alpha_matting: bool = False,
|
|
@@ -114,6 +122,7 @@ def remove(
|
|
|
session: Optional[BaseSession] = None,
|
|
|
only_mask: bool = False,
|
|
|
post_process_mask: bool = False,
|
|
|
+ color: Optional[List[int]] = None,
|
|
|
) -> Union[bytes, PILImage, np.ndarray]:
|
|
|
if isinstance(data, PILImage):
|
|
|
return_type = ReturnType.PILLOW
|
|
@@ -161,6 +170,9 @@ def remove(
|
|
|
if len(cutouts) > 0:
|
|
|
cutout = get_concat_v_multi(cutouts)
|
|
|
|
|
|
+ if color is not None:
|
|
|
+ cutout = apply_background_color(cutout, color)
|
|
|
+
|
|
|
if ReturnType.PILLOW == return_type:
|
|
|
return cutout
|
|
|
|