2
0
Daniel Gatis 2 жил өмнө
parent
commit
54cf4f8c11
2 өөрчлөгдсөн 23 нэмэгдсэн , 19 устгасан
  1. 7 7
      rembg/bg.py
  2. 16 12
      rembg/cli.py

+ 7 - 7
rembg/bg.py

@@ -1,6 +1,6 @@
 import io
 import io
 from enum import Enum
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 
 import numpy as np
 import numpy as np
 from cv2 import (
 from cv2 import (
@@ -105,9 +105,9 @@ def post_process(mask: np.ndarray) -> np.ndarray:
     return mask
     return mask
 
 
 
 
-def apply_background_color(img: PILImage, color: List[int]) -> PILImage:
-    r, g, b = color
-    colored_image = Image.new("RGBA", img.size, (r, g, b, 255))
+def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
+    r, g, b, a = color
+    colored_image = Image.new("RGBA", img.size, (r, g, b, a))
     colored_image.paste(img, mask=img)
     colored_image.paste(img, mask=img)
 
 
     return colored_image
     return colored_image
@@ -122,7 +122,7 @@ def remove(
     session: Optional[BaseSession] = None,
     session: Optional[BaseSession] = None,
     only_mask: bool = False,
     only_mask: bool = False,
     post_process_mask: bool = False,
     post_process_mask: bool = False,
-    color: Optional[List[int]] = None,
+    bgcolor: Optional[Tuple[int, int, int, int]] = None,
 ) -> Union[bytes, PILImage, np.ndarray]:
 ) -> Union[bytes, PILImage, np.ndarray]:
     if isinstance(data, PILImage):
     if isinstance(data, PILImage):
         return_type = ReturnType.PILLOW
         return_type = ReturnType.PILLOW
@@ -170,8 +170,8 @@ def remove(
     if len(cutouts) > 0:
     if len(cutouts) > 0:
         cutout = get_concat_v_multi(cutouts)
         cutout = get_concat_v_multi(cutouts)
 
 
-    if color is not None:
-        cutout = apply_background_color(cutout, color)
+    if bgcolor is not None and not only_mask:
+        cutout = apply_background_color(cutout, bgcolor)
 
 
     if ReturnType.PILLOW == return_type:
     if ReturnType.PILLOW == return_type:
         return cutout
         return cutout

+ 16 - 12
rembg/cli.py

@@ -2,7 +2,7 @@ import pathlib
 import sys
 import sys
 import time
 import time
 from enum import Enum
 from enum import Enum
-from typing import IO, cast
+from typing import IO, Optional, cast
 
 
 import aiohttp
 import aiohttp
 import click
 import click
@@ -93,12 +93,12 @@ def main() -> None:
     help="post process the mask",
     help="post process the mask",
 )
 )
 @click.option(
 @click.option(
-    "-c",
-    "--color",
+    "-bgc",
+    "--bgcolor",
     default=None,
     default=None,
-    nargs=3,
-    type=int,
-    help="Background color (R G B) to replace the removed background with",
+    type=(int, int, int, int),
+    nargs=4,
+    help="Background color (R G B A) to replace the removed background with",
 )
 )
 @click.argument(
 @click.argument(
     "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
     "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
@@ -185,13 +185,12 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
     help="watches a folder for changes",
     help="watches a folder for changes",
 )
 )
 @click.option(
 @click.option(
-    "-c",
-    "--color",
+    "-bgc",
+    "--bgcolor",
     default=None,
     default=None,
-    type=(int, int, int),
-    nargs=3,
-    metavar="R G B",
-    help="background color (RGB) to replace removed areas",
+    type=(int, int, int, int),
+    nargs=4,
+    help="Background color (R G B A) to replace the removed background with",
 )
 )
 @click.argument(
 @click.argument(
     "input",
     "input",
@@ -369,6 +368,7 @@ def s(port: int, log_level: str, threads: int) -> None:
             ),
             ),
             om: bool = Query(default=False, description="Only Mask"),
             om: bool = Query(default=False, description="Only Mask"),
             ppm: bool = Query(default=False, description="Post Process Mask"),
             ppm: bool = Query(default=False, description="Post Process Mask"),
+            bgc: Optional[str] = Query(default=None, description="Background Color"),
         ):
         ):
             self.model = model
             self.model = model
             self.a = a
             self.a = a
@@ -377,6 +377,7 @@ def s(port: int, log_level: str, threads: int) -> None:
             self.ae = ae
             self.ae = ae
             self.om = om
             self.om = om
             self.ppm = ppm
             self.ppm = ppm
+            self.bgc = map(int, bgc.split(",")) if bgc else None
 
 
     class CommonQueryPostParams:
     class CommonQueryPostParams:
         def __init__(
         def __init__(
@@ -403,6 +404,7 @@ def s(port: int, log_level: str, threads: int) -> None:
             ),
             ),
             om: bool = Form(default=False, description="Only Mask"),
             om: bool = Form(default=False, description="Only Mask"),
             ppm: bool = Form(default=False, description="Post Process Mask"),
             ppm: bool = Form(default=False, description="Post Process Mask"),
+            bgc: Optional[str] = Query(default=None, description="Background Color"),
         ):
         ):
             self.model = model
             self.model = model
             self.a = a
             self.a = a
@@ -411,6 +413,7 @@ def s(port: int, log_level: str, threads: int) -> None:
             self.ae = ae
             self.ae = ae
             self.om = om
             self.om = om
             self.ppm = ppm
             self.ppm = ppm
+            self.bgc = map(int, bgc.split(",")) if bgc else None
 
 
     def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
     def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
         return Response(
         return Response(
@@ -425,6 +428,7 @@ def s(port: int, log_level: str, threads: int) -> None:
                 alpha_matting_erode_size=commons.ae,
                 alpha_matting_erode_size=commons.ae,
                 only_mask=commons.om,
                 only_mask=commons.om,
                 post_process_mask=commons.ppm,
                 post_process_mask=commons.ppm,
+                bgcolor=commons.bgc,
             ),
             ),
             media_type="image/png",
             media_type="image/png",
         )
         )