Daniel Gatis 5 年之前
當前提交
7fb6683169

+ 11 - 0
.editorconfig

@@ -0,0 +1,11 @@
+# https://editorconfig.org/
+
+root = true
+
+[*]
+indent_style = space
+indent_size = 4
+insert_final_newline = true
+trim_trailing_whitespace = true
+end_of_line = lf
+charset = utf-8

+ 13 - 0
.gitignore

@@ -0,0 +1,13 @@
+# general things to ignore
+build/
+dist/
+*.egg-info/
+*.egg
+*.py[cod]
+__pycache__/
+*.so
+*~
+
+# due to using tox and pytest
+.tox
+.cache

+ 21 - 0
LICENSE.txt

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Daniel Gatis
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 12 - 0
MANIFEST.in

@@ -0,0 +1,12 @@
+include pyproject.toml
+
+# Include the README
+include *.md
+
+# Include the license file
+include LICENSE.txt
+
+# Include the data files
+recursive-include data *
+
+include requirements.txt

+ 96 - 0
README.md

@@ -0,0 +1,96 @@
+
+# Rembg
+
+Rembg is a tool to remove images background. That is it.
+
+<p style="display: flex;align-items: center;justify-content: center;">
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/car-1.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/car-1.out.png" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/car-2.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/car-2.out.png" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/car-3.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/car-3.out.png" width="100" />
+</p>
+
+<p style="display: flex;align-items: center;justify-content: center;">
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/animal-1.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/animal-1.out.png" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/animal-2.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/animal-2.out.png" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/animal-3.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/animal-3.out.png" width="100" />
+</p>
+
+<p style="display: flex;align-items: center;justify-content: center;">
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/girl-1.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/girl-1.out.png" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/girl-2.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/girl-2.out.png" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/girl-3.jpg" width="100" />
+  <img src="https://raw.githubusercontent.com/danielgatis/rembg/master/examples/girl-3.out.png" width="100" />
+</p>
+
+### Installation
+
+Install it from pypi
+
+```bash
+    pip install rembg
+```
+
+### Usage as a cli
+
+Remove the backaground from a remote image
+```bash
+    curl -s http://input.png | rembg > output.png
+```
+
+Remove the backaground from a local file
+```bash
+    rembg -o path/to/output.png paht/to/input.png
+```
+
+Remove the backaground from all images in a folder
+```bash
+    rembg -p path/to/inputs
+```
+
+### Usage as a server
+
+Start the server
+```bash
+    rembg-server
+```
+
+Open your browser to
+```
+    http://localhost:5000?url=http://image.png
+```
+
+### Usage as a library
+
+In `app.py`
+
+```python
+import sys
+from rembg.bg import remove
+
+sys.stdout.buffer.write(remove(sys.stdin.buffer.read()))
+
+```
+
+Then run
+```
+    cat input.png | python app.py > out.png
+```
+
+### References
+
+- https://arxiv.org/pdf/2005.09007.pdf
+- https://github.com/NathanUA/U-2-Net
+
+### License
+
+Copyright (c) 2020-present [Daniel Gatis](https://github.com/danielgatis)
+
+Licensed under [MIT License](./LICENSE.txt)

二進制
examples/animal-1.jpg


二進制
examples/animal-1.out.png


二進制
examples/animal-2.jpg


二進制
examples/animal-2.out.png


二進制
examples/animal-3.jpg


二進制
examples/animal-3.out.png


二進制
examples/car-1.jpg


二進制
examples/car-1.out.png


二進制
examples/car-2.jpg


二進制
examples/car-2.out.png


二進制
examples/car-3.jpg


二進制
examples/car-3.out.png


二進制
examples/girl-1.jpg


二進制
examples/girl-1.out.png


二進制
examples/girl-2.jpg


二進制
examples/girl-2.out.png


二進制
examples/girl-3.jpg


二進制
examples/girl-3.out.png


+ 5 - 0
pyproject.toml

@@ -0,0 +1,5 @@
+[build-system]
+# These are the assumed default build requirements from pip:
+# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
+requires = ["setuptools>=40.8.0", "wheel"]
+build-backend = "setuptools.build_meta"

+ 8 - 0
requirements.txt

@@ -0,0 +1,8 @@
+flask==1.1.2
+numpy==1.19.1
+pillow==7.2.0
+scikit-image==0.17.2
+torch==1.6.0
+torchvision==0.7.0
+waitress==1.4.4
+tqdm==4.48.2

+ 4 - 0
setup.cfg

@@ -0,0 +1,4 @@
+[metadata]
+# This includes the license file(s) in the wheel.
+# https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file
+license_files = LICENSE.txt

+ 36 - 0
setup.py

@@ -0,0 +1,36 @@
+import pathlib
+
+from setuptools import find_packages, setup
+
+here = pathlib.Path(__file__).parent.resolve()
+
+long_description = (here / "README.md").read_text(encoding="utf-8")
+
+with open("requirements.txt") as f:
+    requireds = f.read().splitlines()
+
+setup(
+    name="rembg",
+    version="1.0.3",
+    description="Remove image background",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/danielgatis/rembg",
+    author="Daniel Gatis",
+    author_email="[email protected]",
+    classifiers=[
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 3 :: Only",
+    ],
+    keywords="remove, background, u2net",
+    package_dir={"": "src"},
+    packages=find_packages(where="src"),
+    python_requires=">=3.5, <4",
+    install_requires=requireds,
+    entry_points={
+        "console_scripts": [
+            "rembg=rembg.cmd.cli:main",
+            "rembg-server=rembg.cmd.server:main",
+        ],
+    },
+)

+ 0 - 0
src/rembg/__init__.py


+ 30 - 0
src/rembg/bg.py

@@ -0,0 +1,30 @@
+import argparse
+import io
+import os
+
+import numpy as np
+from PIL import Image
+
+from .u2net import detect
+
+model_u2net = detect.load_model(model_name="u2net")
+model_u2netp = detect.load_model(model_name="u2netp")
+
+
+def remove(data, model_name="u2net"):
+    model = model_u2net
+
+    if model == "u2netp":
+        model = model_u2netp
+
+    img = Image.open(io.BytesIO(data))
+    roi = detect.predict(model, np.array(img))
+    roi = roi.resize((img.size), resample=Image.LANCZOS)
+
+    empty = Image.new("RGBA", (img.size), 0)
+    out = Image.composite(img, empty, roi.convert("L"))
+
+    bio = io.BytesIO()
+    out.save(bio, "PNG")
+
+    return bio.getbuffer()

+ 0 - 0
src/rembg/cmd/__init__.py


+ 74 - 0
src/rembg/cmd/cli.py

@@ -0,0 +1,74 @@
+import argparse
+import glob
+import imghdr
+import io
+import os
+
+import numpy as np
+from PIL import Image
+
+from ..bg import remove
+
+
+def main():
+    ap = argparse.ArgumentParser()
+
+    ap.add_argument(
+        "-m",
+        "--model",
+        default="u2net",
+        type=str,
+        choices=("u2net", "u2netp"),
+        help="The model name.",
+    )
+
+    ap.add_argument(
+        "-p", "--path", nargs="+", help="Path of a file or a folder of files.",
+    )
+
+    ap.add_argument(
+        "-o",
+        "--output",
+        nargs="?",
+        default="-",
+        type=argparse.FileType("wb"),
+        help="Path to the output png image.",
+    )
+
+    ap.add_argument(
+        "input",
+        nargs="?",
+        default="-",
+        type=argparse.FileType("rb"),
+        help="Path to the input image.",
+    )
+
+    args = ap.parse_args()
+
+    r = lambda i: i.buffer.read() if hasattr(i, "buffer") else i.read()
+    w = lambda o, data: o.buffer.write(data) if hasattr(o, "buffer") else o.write(data)
+
+    if args.path:
+        full_paths = [os.path.abspath(path) for path in args.path]
+        files = set()
+
+        for path in full_paths:
+            if os.path.isfile(path):
+                files.add(path)
+            else:
+                full_paths += glob.glob(path + "/*")
+
+        for fi in files:
+            if imghdr.what(fi) is None:
+                continue
+
+            with open(fi, "rb") as input:
+                with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output:
+                    w(output, remove(r(input), args.model))
+
+    else:
+        w(args.output, remove(r(args.input), args.model))
+
+
+if __name__ == "__main__":
+    main()

+ 51 - 0
src/rembg/cmd/server.py

@@ -0,0 +1,51 @@
+import argparse
+from io import BytesIO
+from urllib.parse import unquote_plus
+from urllib.request import urlopen
+
+from flask import Flask, request, send_file
+from waitress import serve
+
+from ..bg import remove
+
+
+def index():
+    model = request.args.get("model", type=str, default="u2net")
+    if model not in ("u2net", "u2netp"):
+        return {"error": "invalid query param 'model'"}, 400
+
+    url = request.args.get("url", type=str)
+    if url is None:
+        return {"error": "missing query param 'url'"}, 400
+
+    try:
+        return send_file(
+            BytesIO(remove(urlopen(unquote_plus(url)).read(), model)),
+            mimetype="image/png",
+        )
+    except Exception as e:
+        app.logger.exception(e.message, exc_info=True)
+        return {"error": "oops, something went wrong!"}, 500
+
+
+def main():
+    ap = argparse.ArgumentParser()
+
+    ap.add_argument(
+        "-a", "--addr", default="0.0.0.0", type=str, help="The IP address to bind to.",
+    )
+
+    ap.add_argument(
+        "-p", "--port", default=5000, type=int, help="The port to bind to.",
+    )
+
+    args = ap.parse_args()
+
+    app = Flask(__name__)
+    app.add_url_rule("/", "index", index)
+
+    serve(app, host=args.addr, port=args.port)
+
+
+if __name__ == "__main__":
+    main()

+ 0 - 0
src/rembg/u2net/__init__.py


+ 329 - 0
src/rembg/u2net/data_loader.py

@@ -0,0 +1,329 @@
+# data loader
+from __future__ import division, print_function
+
+import glob
+import math
+import random
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from PIL import Image
+from skimage import color, io, transform
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms, utils
+
+
+# ==========================dataset load==========================
+class RescaleT(object):
+    def __init__(self, output_size):
+        assert isinstance(output_size, (int, tuple))
+        self.output_size = output_size
+
+    def __call__(self, sample):
+        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
+
+        h, w = image.shape[:2]
+
+        if isinstance(self.output_size, int):
+            if h > w:
+                new_h, new_w = self.output_size * h / w, self.output_size
+            else:
+                new_h, new_w = self.output_size, self.output_size * w / h
+        else:
+            new_h, new_w = self.output_size
+
+        new_h, new_w = int(new_h), int(new_w)
+
+        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
+        # img = transform.resize(image,(new_h,new_w),mode='constant')
+        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
+
+        img = transform.resize(
+            image, (self.output_size, self.output_size), mode="constant"
+        )
+        lbl = transform.resize(
+            label,
+            (self.output_size, self.output_size),
+            mode="constant",
+            order=0,
+            preserve_range=True,
+        )
+
+        return {"imidx": imidx, "image": img, "label": lbl}
+
+
+class Rescale(object):
+    def __init__(self, output_size):
+        assert isinstance(output_size, (int, tuple))
+        self.output_size = output_size
+
+    def __call__(self, sample):
+        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
+
+        if random.random() >= 0.5:
+            image = image[::-1]
+            label = label[::-1]
+
+        h, w = image.shape[:2]
+
+        if isinstance(self.output_size, int):
+            if h > w:
+                new_h, new_w = self.output_size * h / w, self.output_size
+            else:
+                new_h, new_w = self.output_size, self.output_size * w / h
+        else:
+            new_h, new_w = self.output_size
+
+        new_h, new_w = int(new_h), int(new_w)
+
+        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
+        img = transform.resize(image, (new_h, new_w), mode="constant")
+        lbl = transform.resize(
+            label, (new_h, new_w), mode="constant", order=0, preserve_range=True
+        )
+
+        return {"imidx": imidx, "image": img, "label": lbl}
+
+
+class RandomCrop(object):
+    def __init__(self, output_size):
+        assert isinstance(output_size, (int, tuple))
+        if isinstance(output_size, int):
+            self.output_size = (output_size, output_size)
+        else:
+            assert len(output_size) == 2
+            self.output_size = output_size
+
+    def __call__(self, sample):
+        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
+
+        if random.random() >= 0.5:
+            image = image[::-1]
+            label = label[::-1]
+
+        h, w = image.shape[:2]
+        new_h, new_w = self.output_size
+
+        top = np.random.randint(0, h - new_h)
+        left = np.random.randint(0, w - new_w)
+
+        image = image[top : top + new_h, left : left + new_w]
+        label = label[top : top + new_h, left : left + new_w]
+
+        return {"imidx": imidx, "image": image, "label": label}
+
+
+class ToTensor(object):
+    """Convert ndarrays in sample to Tensors."""
+
+    def __call__(self, sample):
+
+        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
+
+        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+        tmpLbl = np.zeros(label.shape)
+
+        image = image / np.max(image)
+        if np.max(label) < 1e-6:
+            label = label
+        else:
+            label = label / np.max(label)
+
+        if image.shape[2] == 1:
+            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
+            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
+        else:
+            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
+            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
+
+        tmpLbl[:, :, 0] = label[:, :, 0]
+
+        # change the r,g,b to b,r,g from [0,255] to [0,1]
+        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
+        tmpImg = tmpImg.transpose((2, 0, 1))
+        tmpLbl = label.transpose((2, 0, 1))
+
+        return {
+            "imidx": torch.from_numpy(imidx),
+            "image": torch.from_numpy(tmpImg),
+            "label": torch.from_numpy(tmpLbl),
+        }
+
+
+class ToTensorLab(object):
+    """Convert ndarrays in sample to Tensors."""
+
+    def __init__(self, flag=0):
+        self.flag = flag
+
+    def __call__(self, sample):
+
+        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
+
+        tmpLbl = np.zeros(label.shape)
+
+        if np.max(label) < 1e-6:
+            label = label
+        else:
+            label = label / np.max(label)
+
+        # change the color space
+        if self.flag == 2:  # with rgb and Lab colors
+            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
+            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
+            if image.shape[2] == 1:
+                tmpImgt[:, :, 0] = image[:, :, 0]
+                tmpImgt[:, :, 1] = image[:, :, 0]
+                tmpImgt[:, :, 2] = image[:, :, 0]
+            else:
+                tmpImgt = image
+            tmpImgtl = color.rgb2lab(tmpImgt)
+
+            # nomalize image to range [0,1]
+            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
+                np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0])
+            )
+            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
+                np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1])
+            )
+            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
+                np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2])
+            )
+            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
+                np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0])
+            )
+            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
+                np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1])
+            )
+            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
+                np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2])
+            )
+
+            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
+
+            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
+                tmpImg[:, :, 0]
+            )
+            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
+                tmpImg[:, :, 1]
+            )
+            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
+                tmpImg[:, :, 2]
+            )
+            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(
+                tmpImg[:, :, 3]
+            )
+            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(
+                tmpImg[:, :, 4]
+            )
+            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(
+                tmpImg[:, :, 5]
+            )
+
+        elif self.flag == 1:  # with Lab color
+            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+
+            if image.shape[2] == 1:
+                tmpImg[:, :, 0] = image[:, :, 0]
+                tmpImg[:, :, 1] = image[:, :, 0]
+                tmpImg[:, :, 2] = image[:, :, 0]
+            else:
+                tmpImg = image
+
+            tmpImg = color.rgb2lab(tmpImg)
+
+            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
+
+            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
+                np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0])
+            )
+            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
+                np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1])
+            )
+            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
+                np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2])
+            )
+
+            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
+                tmpImg[:, :, 0]
+            )
+            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
+                tmpImg[:, :, 1]
+            )
+            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
+                tmpImg[:, :, 2]
+            )
+
+        else:  # with rgb color
+            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+            image = image / np.max(image)
+            if image.shape[2] == 1:
+                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
+                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
+            else:
+                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
+                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
+
+        tmpLbl[:, :, 0] = label[:, :, 0]
+
+        # change the r,g,b to b,r,g from [0,255] to [0,1]
+        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
+        tmpImg = tmpImg.transpose((2, 0, 1))
+        tmpLbl = label.transpose((2, 0, 1))
+
+        return {
+            "imidx": torch.from_numpy(imidx),
+            "image": torch.from_numpy(tmpImg),
+            "label": torch.from_numpy(tmpLbl),
+        }
+
+
+class SalObjDataset(Dataset):
+    def __init__(self, img_name_list, lbl_name_list, transform=None):
+        # self.root_dir = root_dir
+        # self.image_name_list = glob.glob(image_dir+'*.png')
+        # self.label_name_list = glob.glob(label_dir+'*.png')
+        self.image_name_list = img_name_list
+        self.label_name_list = lbl_name_list
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_name_list)
+
+    def __getitem__(self, idx):
+
+        # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
+        # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
+
+        image = io.imread(self.image_name_list[idx])
+        imname = self.image_name_list[idx]
+        imidx = np.array([idx])
+
+        if 0 == len(self.label_name_list):
+            label_3 = np.zeros(image.shape)
+        else:
+            label_3 = io.imread(self.label_name_list[idx])
+
+        label = np.zeros(label_3.shape[0:2])
+        if 3 == len(label_3.shape):
+            label = label_3[:, :, 0]
+        elif 2 == len(label_3.shape):
+            label = label_3
+
+        if 3 == len(image.shape) and 2 == len(label.shape):
+            label = label[:, :, np.newaxis]
+        elif 2 == len(image.shape) and 2 == len(label.shape):
+            image = image[:, :, np.newaxis]
+            label = label[:, :, np.newaxis]
+
+        sample = {"imidx": imidx, "image": image, "label": label}
+
+        if self.transform:
+            sample = self.transform(sample)
+
+        return sample

+ 132 - 0
src/rembg/u2net/detect.py

@@ -0,0 +1,132 @@
+import errno
+import os
+import time
+import urllib.request
+import sys
+
+import numpy as np
+import pkg_resources
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+from PIL import Image
+from skimage import transform
+from torchvision import transforms
+from tqdm import tqdm
+
+from . import data_loader, u2net
+
+
+class DownloadProgressBar(tqdm):
+    def update_to(self, b=1, bsize=1, tsize=None):
+        if tsize is not None:
+            self.total = tsize
+        self.update(b * bsize - self.n)
+
+
+def download_url(url, model_name, output_path):
+    if os.path.exists(output_path):
+        return
+
+    print(
+        f"Downloading model to {output_path}".format(output_path=output_path),
+        file=sys.stderr,
+    )
+
+    with DownloadProgressBar(
+        unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
+    ) as t:
+        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
+
+
+def load_model(model_name: str = "u2net"):
+    if model_name == "u2netp":
+        net = u2net.U2NETP(3, 1)
+        path = os.path.expanduser("~/.u2net/u2netp.pth")
+        download_url(
+            "https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1",
+            "u2netp.pth",
+            path,
+        )
+    elif model_name == "u2net":
+        net = u2net.U2NET(3, 1)
+        path = os.path.expanduser("~/.u2net/u2net.pth")
+        download_url(
+            "https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1",
+            "u2net.pth",
+            path,
+        )
+    else:
+        print("Choose between u2net or u2netp", file=sys.stderr)
+
+    try:
+        if torch.cuda.is_available():
+            net.load_state_dict(torch.load(path))
+            net.to(torch.device("cuda"))
+        else:
+            net.load_state_dict(torch.load(path, map_location="cpu",))
+    except FileNotFoundError:
+        raise FileNotFoundError(
+            errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
+        )
+
+    net.eval()
+
+    return net
+
+
+def norm_pred(d):
+    ma = torch.max(d)
+    mi = torch.min(d)
+    dn = (d - mi) / (ma - mi)
+
+    return dn
+
+
+def preprocess(image):
+    label_3 = np.zeros(image.shape)
+    label = np.zeros(label_3.shape[0:2])
+
+    if 3 == len(label_3.shape):
+        label = label_3[:, :, 0]
+    elif 2 == len(label_3.shape):
+        label = label_3
+
+    if 3 == len(image.shape) and 2 == len(label.shape):
+        label = label[:, :, np.newaxis]
+    elif 2 == len(image.shape) and 2 == len(label.shape):
+        image = image[:, :, np.newaxis]
+        label = label[:, :, np.newaxis]
+
+    transform = transforms.Compose(
+        [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
+    )
+    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
+
+    return sample
+
+
+def predict(net, item):
+
+    sample = preprocess(item)
+
+    with torch.no_grad():
+
+        if torch.cuda.is_available():
+            inputs_test = torch.cuda.FloatTensor(sample["image"].unsqueeze(0).float())
+        else:
+            inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
+
+        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
+
+        pred = d1[:, 0, :, :]
+        predict = norm_pred(pred)
+
+        predict = predict.squeeze()
+        predict_np = predict.cpu().detach().numpy()
+        img = Image.fromarray(predict_np * 255).convert("RGB")
+
+        del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample
+
+        return img

+ 541 - 0
src/rembg/u2net/u2net.py

@@ -0,0 +1,541 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+
+class REBNCONV(nn.Module):
+    def __init__(self, in_ch=3, out_ch=3, dirate=1):
+        super(REBNCONV, self).__init__()
+
+        self.conv_s1 = nn.Conv2d(
+            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
+        )
+        self.bn_s1 = nn.BatchNorm2d(out_ch)
+        self.relu_s1 = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+
+        hx = x
+        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+        return xout
+
+
+## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+def _upsample_like(src, tar):
+
+    src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
+
+    return src
+
+
+### RSU-7 ###
+class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU7, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+
+        hx = x
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+
+        hx5 = self.rebnconv5(hx)
+        hx = self.pool5(hx5)
+
+        hx6 = self.rebnconv6(hx)
+
+        hx7 = self.rebnconv7(hx6)
+
+        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+        hx6dup = _upsample_like(hx6d, hx5)
+
+        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-6 ###
+class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU6, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+
+        hx5 = self.rebnconv5(hx)
+
+        hx6 = self.rebnconv6(hx5)
+
+        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-5 ###
+class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU5, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+
+        hx5 = self.rebnconv5(hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-4 ###
+class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+
+        hx4 = self.rebnconv4(hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-4F ###
+class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4F, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx2 = self.rebnconv2(hx1)
+        hx3 = self.rebnconv3(hx2)
+
+        hx4 = self.rebnconv4(hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+        return hx1d + hxin
+
+
+##### U^2-Net ####
+class U2NET(nn.Module):
+    def __init__(self, in_ch=3, out_ch=1):
+        super(U2NET, self).__init__()
+
+        self.stage1 = RSU7(in_ch, 32, 64)
+        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage2 = RSU6(64, 32, 128)
+        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage3 = RSU5(128, 64, 256)
+        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage4 = RSU4(256, 128, 512)
+        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage5 = RSU4F(512, 256, 512)
+        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage6 = RSU4F(512, 256, 512)
+
+        # decoder
+        self.stage5d = RSU4F(1024, 256, 512)
+        self.stage4d = RSU4(1024, 128, 256)
+        self.stage3d = RSU5(512, 64, 128)
+        self.stage2d = RSU6(256, 32, 64)
+        self.stage1d = RSU7(128, 16, 64)
+
+        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+        self.outconv = nn.Conv2d(6, out_ch, 1)
+
+    def forward(self, x):
+
+        hx = x
+
+        # stage 1
+        hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+
+        # stage 2
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+
+        # stage 3
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+
+        # stage 4
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+
+        # stage 5
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+
+        # stage 6
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+
+        # -------------------- decoder --------------------
+        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+        # side output
+        d1 = self.side1(hx1d)
+
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+
+        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+        return (
+            torch.sigmoid(d0),
+            torch.sigmoid(d1),
+            torch.sigmoid(d2),
+            torch.sigmoid(d3),
+            torch.sigmoid(d4),
+            torch.sigmoid(d5),
+            torch.sigmoid(d6),
+        )
+
+
+### U^2-Net small ###
+class U2NETP(nn.Module):
+    def __init__(self, in_ch=3, out_ch=1):
+        super(U2NETP, self).__init__()
+
+        self.stage1 = RSU7(in_ch, 16, 64)
+        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage2 = RSU6(64, 16, 64)
+        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage3 = RSU5(64, 16, 64)
+        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage4 = RSU4(64, 16, 64)
+        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage5 = RSU4F(64, 16, 64)
+        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage6 = RSU4F(64, 16, 64)
+
+        # decoder
+        self.stage5d = RSU4F(128, 16, 64)
+        self.stage4d = RSU4(128, 16, 64)
+        self.stage3d = RSU5(128, 16, 64)
+        self.stage2d = RSU6(128, 16, 64)
+        self.stage1d = RSU7(128, 16, 64)
+
+        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+        self.outconv = nn.Conv2d(6, out_ch, 1)
+
+    def forward(self, x):
+
+        hx = x
+
+        # stage 1
+        hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+
+        # stage 2
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+
+        # stage 3
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+
+        # stage 4
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+
+        # stage 5
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+
+        # stage 6
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+
+        # decoder
+        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+        # side output
+        d1 = self.side1(hx1d)
+
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+
+        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+        return (
+            torch.sigmoid(d0),
+            torch.sigmoid(d1),
+            torch.sigmoid(d2),
+            torch.sigmoid(d3),
+            torch.sigmoid(d4),
+            torch.sigmoid(d5),
+            torch.sigmoid(d6),
+        )