Daniel Gatis 3 лет назад
Родитель
Сommit
b9e16e0e54
3 измененных файлов с 40 добавлено и 11 удалено
  1. 20 11
      rembg/bg.py
  2. 19 0
      rembg/cli.py
  3. 1 0
      requirements.txt

+ 20 - 11
rembg/bg.py

@@ -4,7 +4,14 @@ from typing import List, Optional, Union
 
 import numpy as np
 from PIL import Image
-from cv2 import getStructuringElement, morphologyEx, GaussianBlur, MORPH_OPEN, MORPH_ELLIPSE, BORDER_DEFAULT
+from cv2 import (
+    getStructuringElement,
+    morphologyEx,
+    GaussianBlur,
+    MORPH_OPEN,
+    MORPH_ELLIPSE,
+    BORDER_DEFAULT,
+)
 from PIL.Image import Image as PILImage
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
 from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
@@ -14,7 +21,8 @@ from scipy.ndimage.morphology import binary_erosion
 from .session_base import BaseSession
 from .session_factory import new_session
 
-kernel = getStructuringElement(MORPH_ELLIPSE,(3,3)) # to save API calls, it has been declared global 
+kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
+
 
 class ReturnType(Enum):
     BYTES = 0
@@ -81,16 +89,16 @@ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
     return dst
 
 
-def post_process_mask(mask:np.ndarray)->np.ndarray:
-    '''
+def post_process(mask: np.ndarray) -> np.ndarray:
+    """
     Post Process the mask for a smooth boundary by applying Morphological Operations
     Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757
     args:
         mask: Binary Numpy Mask
-    '''
-    mask = morphologyEx(mask,MORPH_OPEN,kernel)
-    mask = GaussianBlur(mask, (5,5), sigmaX = 2, sigmaY = 2, borderType = BORDER_DEFAULT) # Blur
-    mask = np.where( mask < 127, 0, 255).astype(np.uint8) # convert again to binary
+    """
+    mask = morphologyEx(mask, MORPH_OPEN, kernel)
+    mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
+    mask = np.where(mask < 127, 0, 255).astype(np.uint8)  # convert again to binary
     return mask
 
 
@@ -102,7 +110,7 @@ def remove(
     alpha_matting_erode_size: int = 10,
     session: Optional[BaseSession] = None,
     only_mask: bool = False,
-    post_process:bool = True
+    post_process_mask: bool = False,
 ) -> Union[bytes, PILImage, np.ndarray]:
 
     if isinstance(data, PILImage):
@@ -124,8 +132,9 @@ def remove(
     cutouts = []
 
     for mask in masks:
-        if post_process:
-            mask = Image.fromarray(post_process_mask(np.array(mask))) # Apply post processing to mask
+        if post_process_mask:
+            mask = Image.fromarray(post_process(np.array(mask)))
+
         if only_mask:
             cutout = mask
 

+ 19 - 0
rembg/cli.py

@@ -76,6 +76,13 @@ def main() -> None:
     show_default=True,
     help="output only the mask",
 )
[email protected](
+    "-ppm",
+    "--post-process-mask",
+    is_flag=True,
+    show_default=True,
+    help="post process the mask",
+)
 @click.argument(
     "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
 )
@@ -136,6 +143,13 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
     show_default=True,
     help="output only the mask",
 )
[email protected](
+    "-ppm",
+    "--post-process-mask",
+    is_flag=True,
+    show_default=True,
+    help="post process the mask",
+)
 @click.option(
     "-w",
     "--watch",
@@ -309,6 +323,7 @@ def s(port: int, log_level: str) -> None:
                 default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
             ),
             om: bool = Query(default=False, description="Only Mask"),
+            ppm: bool = Query(default=False, description="Post Process Mask"),
         ):
             self.model = model
             self.a = a
@@ -316,6 +331,7 @@ def s(port: int, log_level: str) -> None:
             self.ab = ab
             self.ae = ae
             self.om = om
+            self.ppm = ppm
 
     class CommonQueryPostParams:
         def __init__(
@@ -341,6 +357,7 @@ def s(port: int, log_level: str) -> None:
                 default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
             ),
             om: bool = Form(default=False, description="Only Mask"),
+            ppm: bool = Form(default=False, description="Post Process Mask"),
         ):
             self.model = model
             self.a = a
@@ -348,6 +365,7 @@ def s(port: int, log_level: str) -> None:
             self.ab = ab
             self.ae = ae
             self.om = om
+            self.ppm = ppm
 
     def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
         return Response(
@@ -361,6 +379,7 @@ def s(port: int, log_level: str) -> None:
                 alpha_matting_background_threshold=commons.ab,
                 alpha_matting_erode_size=commons.ae,
                 only_mask=commons.om,
+                post_process_mask=commons.ppm,
             ),
             media_type="image/png",
         )

+ 1 - 0
requirements.txt

@@ -6,6 +6,7 @@ filetype==1.0.9
 gdown==4.5.1
 numpy==1.22.3
 onnxruntime==1.10.0
+opencv-python==4.6.0.66
 pillow==9.0.1
 pymatting==1.1.7
 python-multipart==0.0.5