Quellcode durchsuchen

Merge pull request #268 from deshwalmahesh/main

added post processing
Daniel Gatis vor 3 Jahren
Ursprung
Commit
5e3d2e6f01
1 geänderte Dateien mit 18 neuen und 0 gelöschten Zeilen
  1. 18 0
      rembg/bg.py

+ 18 - 0
rembg/bg.py

@@ -4,6 +4,7 @@ from typing import List, Optional, Union
 
 import numpy as np
 from PIL import Image
+from cv2 import getStructuringElement, morphologyEx, GaussianBlur, MORPH_OPEN, MORPH_ELLIPSE, BORDER_DEFAULT
 from PIL.Image import Image as PILImage
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
 from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
@@ -13,6 +14,7 @@ from scipy.ndimage.morphology import binary_erosion
 from .session_base import BaseSession
 from .session_factory import new_session
 
+kernel = getStructuringElement(MORPH_ELLIPSE,(3,3)) # to save API calls, it has been declared global 
 
 class ReturnType(Enum):
     BYTES = 0
@@ -79,6 +81,19 @@ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
     return dst
 
 
+def post_process_mask(mask:np.ndarray)->np.ndarray:
+    '''
+    Post Process the mask for a smooth boundary by applying Morphological Operations
+    Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757
+    args:
+        mask: Binary Numpy Mask
+    '''
+    mask = morphologyEx(mask,MORPH_OPEN,kernel)
+    mask = GaussianBlur(mask, (5,5), sigmaX = 2, sigmaY = 2, borderType = BORDER_DEFAULT) # Blur
+    mask = np.where( mask < 127, 0, 255).astype(np.uint8) # convert again to binary
+    return mask
+
+
 def remove(
     data: Union[bytes, PILImage, np.ndarray],
     alpha_matting: bool = False,
@@ -87,6 +102,7 @@ def remove(
     alpha_matting_erode_size: int = 10,
     session: Optional[BaseSession] = None,
     only_mask: bool = False,
+    post_process:bool = True
 ) -> Union[bytes, PILImage, np.ndarray]:
 
     if isinstance(data, PILImage):
@@ -108,6 +124,8 @@ def remove(
     cutouts = []
 
     for mask in masks:
+        if post_process:
+            mask = Image.fromarray(post_process_mask(np.array(mask))) # Apply post processing to mask
         if only_mask:
             cutout = mask