浏览代码

add alpha matting

Daniel Gatis 4 年之前
父节点
当前提交
215cb3e934
共有 10 个文件被更改,包括 176 次插入22 次删除
  1. 27 0
      README.md
  2. 二进制
      examples/food-1.jpg
  3. 二进制
      examples/food-1.out.alpha.jpg
  4. 二进制
      examples/food-1.out.jpg
  5. 2 0
      requirements.txt
  6. 1 1
      setup.py
  7. 78 7
      src/rembg/bg.py
  8. 57 2
      src/rembg/cmd/cli.py
  9. 8 11
      src/rembg/cmd/server.py
  10. 3 1
      src/rembg/u2net/detect.py

+ 27 - 0
README.md

@@ -96,10 +96,37 @@ Then run
     cat input.png | python app.py > out.png
 ```
 
+### Advance usage
+
+Sometimes it is possible to achieve better results by turning on alpha matting
+```bash
+    curl -s http://input.png -a -ae 15 | rembg > output.png
+```
+
+Example:
+
+<table>
+    <thead>
+        <tr>
+            <td>Original</td>
+            <td>Without alpha matting</td>
+            <td>With alpha matting (-a -ae 15)</td>
+        </tr>
+    </thead>
+    <tbody>
+        <tr>
+            <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.jpg" width="100" /></td>
+            <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.jpg" width="100" /></td>
+            <td><img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/food-1.out.alpha.jpg" width="100" /></td>
+        </tr>
+    </tbody>
+</table>
+
 ### References
 
 - https://arxiv.org/pdf/2005.09007.pdf
 - https://github.com/NathanUA/U-2-Net
+- https://github.com/pymatting/pymatting
 
 ### License
 

二进制
examples/food-1.jpg


二进制
examples/food-1.out.alpha.jpg


二进制
examples/food-1.out.jpg


+ 2 - 0
requirements.txt

@@ -7,3 +7,5 @@ torchvision==0.7.0
 waitress==1.4.4
 tqdm==4.48.2
 requests==2.24.0
+scipy==1.5.2
+pymatting==1.0.6

+ 1 - 1
setup.py

@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
 
 setup(
     name="rembg",
-    version="1.0.10",
+    version="1.0.11",
     description="Remove image background",
     long_description=long_description,
     long_description_content_type="text/markdown",

+ 78 - 7
src/rembg/bg.py

@@ -2,6 +2,10 @@ import io
 
 import numpy as np
 from PIL import Image
+from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
+from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
+from pymatting.util.util import stack_images
+from scipy.ndimage.morphology import binary_erosion
 
 from .u2net import detect
 
@@ -9,20 +13,87 @@ model_u2net = detect.load_model(model_name="u2net")
 model_u2netp = detect.load_model(model_name="u2netp")
 
 
-def remove(data, model_name="u2net"):
+def alpha_matting_cutout(
+    img, mask, foreground_threshold, background_threshold, erode_structure_size,
+):
+    base_size = (1000, 1000)
+    size = img.size
+
+    img.thumbnail(base_size, Image.LANCZOS)
+    mask = mask.resize(img.size, Image.LANCZOS)
+
+    img = np.asarray(img)
+    mask = np.asarray(mask)
+
+    # guess likely foreground/background
+    is_foreground = mask > foreground_threshold
+    is_background = mask < background_threshold
+
+    # erode foreground/background
+    structure = None
+    if erode_structure_size > 0:
+        structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int)
+
+    is_foreground = binary_erosion(is_foreground, structure=structure)
+    is_background = binary_erosion(is_background, structure=structure, border_value=1)
+
+    # build trimap
+    # 0   = background
+    # 128 = unknown
+    # 255 = foreground
+    trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
+    trimap[is_foreground] = 255
+    trimap[is_background] = 0
+
+    # build the cutout image
+    img_normalized = img / 255.0
+    trimap_normalized = trimap / 255.0
+
+    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
+    foreground = estimate_foreground_ml(img_normalized, alpha)
+    cutout = stack_images(foreground, alpha)
+
+    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
+    cutout = Image.fromarray(cutout)
+    cutout = cutout.resize(size, Image.LANCZOS)
+
+    return cutout
+
+
+def naive_cutout(img, mask):
+    empty = Image.new("RGBA", (img.size), 0)
+    cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS))
+    return cutout
+
+
+def remove(
+    data,
+    model_name="u2net",
+    alpha_matting=False,
+    alpha_matting_foreground_threshold=235,
+    alpha_matting_background_threshold=15,
+    alpha_matting_erode_structure_size=15,
+):
     model = model_u2net
 
     if model == "u2netp":
         model = model_u2netp
 
-    img = Image.open(io.BytesIO(data))
-    roi = detect.predict(model, np.array(img))
-    roi = roi.resize((img.size), resample=Image.LANCZOS)
+    img = Image.open(io.BytesIO(data)).convert("RGB")
+    mask = detect.predict(model, np.array(img)).convert("L")
 
-    empty = Image.new("RGBA", (img.size), 0)
-    out = Image.composite(img, empty, roi.convert("L"))
+    if alpha_matting:
+        cutout = alpha_matting_cutout(
+            img,
+            mask,
+            alpha_matting_foreground_threshold,
+            alpha_matting_background_threshold,
+            alpha_matting_erode_structure_size,
+        )
+    else:
+        cutout = naive_cutout(img, mask)
 
     bio = io.BytesIO()
-    out.save(bio, "PNG")
+    cutout.save(bio, "PNG")
 
     return bio.getbuffer()

+ 57 - 2
src/rembg/cmd/cli.py

@@ -2,6 +2,7 @@ import argparse
 import glob
 import imghdr
 import os
+from distutils.util import strtobool
 
 from ..bg import remove
 
@@ -18,6 +19,40 @@ def main():
         help="The model name.",
     )
 
+    ap.add_argument(
+        "-a",
+        "--alpha-matting",
+        nargs="?",
+        const=True,
+        default=False,
+        type=lambda x: bool(strtobool(x)),
+        help="When true use alpha matting cutout.",
+    )
+
+    ap.add_argument(
+        "-af",
+        "--alpha-matting-foreground-threshold",
+        default=235,
+        type=int,
+        help="The trimap foreground threshold.",
+    )
+
+    ap.add_argument(
+        "-ab",
+        "--alpha-matting-background-threshold",
+        default=15,
+        type=int,
+        help="The trimap background threshold.",
+    )
+
+    ap.add_argument(
+        "-ae",
+        "--alpha-matting-erode-size",
+        default=15,
+        type=int,
+        help="Size of element used for the erosion.",
+    )
+
     ap.add_argument(
         "-p", "--path", nargs="+", help="Path of a file or a folder of files.",
     )
@@ -60,10 +95,30 @@ def main():
 
             with open(fi, "rb") as input:
                 with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
-                    w(output, remove(r(input), args.model))
+                    w(
+                        output,
+                        remove(
+                            r(input),
+                            model_name=args.model,
+                            alpha_matting=args.alpha_matting,
+                            alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
+                            alpha_matting_background_threshold=args.alpha_matting_background_threshold,
+                            alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
+                        ),
+                    )
 
     else:
-        w(args.output, remove(r(args.input), args.model))
+        w(
+            args.output,
+            remove(
+                r(args.input),
+                model_name=args.model,
+                alpha_matting=args.alpha_matting,
+                alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
+                alpha_matting_background_threshold=args.alpha_matting_background_threshold,
+                alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
+            ),
+        )
 
 
 if __name__ == "__main__":

+ 8 - 11
src/rembg/cmd/server.py

@@ -11,24 +11,24 @@ from ..bg import remove
 app = Flask(__name__)
 
 
[email protected]('/', methods=['GET', 'POST'])
[email protected]("/", methods=["GET", "POST"])
 def index():
-    file_content = ''
+    file_content = ""
 
-    if request.method == 'POST':
-        if 'file' not in request.files:
+    if request.method == "POST":
+        if "file" not in request.files:
             return {"error": "missing post form param 'file'"}, 400
 
-        file_content = request.files['file'].read()
+        file_content = request.files["file"].read()
 
-    if request.method == 'GET':
+    if request.method == "GET":
         url = request.args.get("url", type=str)
         if url is None:
             return {"error": "missing query param 'url'"}, 400
 
         file_content = urlopen(unquote_plus(url)).read()
 
-    if file_content == '':
+    if file_content == "":
         return {"error": "File content is empty"}, 400
 
     model = request.args.get("model", type=str, default="u2net")
@@ -36,10 +36,7 @@ def index():
         return {"error": "invalid query param 'model'"}, 400
 
     try:
-        return send_file(
-            BytesIO(remove(file_content, model)),
-            mimetype="image/png",
-        )
+        return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
     except Exception as e:
         app.logger.exception(e, exc_info=True)
         return {"error": "oops, something went wrong!"}, 500

+ 3 - 1
src/rembg/u2net/detect.py

@@ -107,7 +107,9 @@ def predict(net, item):
     with torch.no_grad():
 
         if torch.cuda.is_available():
-            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).cuda().float())
+            inputs_test = torch.cuda.FloatTensor(
+                sample["image"].unsqueeze(0).cuda().float()
+            )
         else:
             inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())