Daniel Gatis 3 年之前
父節點
當前提交
3b18bad858
共有 12 個文件被更改,包括 111 次插入1068 次删除
  1. 5 5
      Dockerfile
  2. 6 33
      README.md
  3. 0 2
      rembg/bg.py
  4. 1 12
      rembg/cli.py
  5. 0 326
      rembg/data_loader.py
  6. 84 131
      rembg/detect.py
  7. 1 10
      rembg/server.py
  8. 0 541
      rembg/u2net.py
  9. 1 0
      requirements-cpu.txt
  10. 1 0
      requirements-gpu.txt
  11. 5 8
      requirements.txt
  12. 7 0
      setup.py

+ 5 - 5
Dockerfile

@@ -2,16 +2,16 @@ FROM nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04
 
 RUN apt-get update &&\
     apt-get install -y --no-install-recommends \
-        python3 \
-        python3-pip \
-        python3-dev \
-        build-essential
+    python3 \
+    python3-pip \
+    python3-dev \
+    build-essential
 
 WORKDIR /rembg
 
 COPY . .
 
-RUN pip3 install .
+RUN GPU=1 pip3 install .
 
 # First run to compile AOT & download model
 RUN rembg pixel.png >/dev/null

+ 6 - 33
README.md

@@ -36,38 +36,19 @@ Rembg is a tool to remove images background. That is it.
 
 #### *** If you want to remove background from videos try this this fork: https://github.com/ecsplendid/rembg-greenscreen ***
 
-### Requirements
-
-* python 3.8 or newer
-
-* torch and torchvision stable version (https://pytorch.org)
-
-#### How to install torch/torchvision
-
-Go to https://pytorch.org and scrool down to `INSTALL PYTORCH` section and follow the instructions.
-
-For example:
-```
-PyTorch Build: Stable (1.7.1)
-Your OS: Windows
-Package: Pip
-Language: Python
-CUDA: None
-```
-
-The install cmd is:
-```
-pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
-```
 
 ### Installation
 
-Install it from pypi
-
+CPU support:
 ```bash
 pip install rembg
 ```
 
+GPU support:
+```bash
+GPU=1 pip install rembg
+```
+
 ### Usage as a cli
 
 Remove the background from a remote image
@@ -85,14 +66,6 @@ Remove the background from all images in a folder
 rembg -p path/to/input path/to/output
 ```
 
-### Add a custom model
-
-Copy the `custom-model.pth` file to `~/.u2net` and run:
-
-```bash
-curl -s http://input.png | rembg -m custom-model > output.png
-```
-
 ### Usage as a server
 
 Start the server

+ 0 - 2
rembg/bg.py

@@ -1,4 +1,3 @@
-import functools
 import io
 
 import numpy as np
@@ -68,7 +67,6 @@ def naive_cutout(img, mask):
     return cutout
 
 
[email protected]_cache(maxsize=None)
 def get_model(model_name):
     if model_name == "u2netp":
         return load_model(model_name="u2netp")

+ 1 - 12
rembg/cli.py

@@ -10,17 +10,6 @@ from .bg import remove
 
 
 def main():
-    model_path = os.environ.get(
-        "U2NETP_PATH",
-        os.path.expanduser(os.path.join("~", ".u2net")),
-    )
-    model_choices = [
-        os.path.splitext(os.path.basename(x))[0]
-        for x in set(glob.glob(model_path + "/*"))
-    ]
-
-    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
-
     ap = argparse.ArgumentParser()
 
     ap.add_argument(
@@ -28,7 +17,7 @@ def main():
         "--model",
         default="u2net",
         type=str,
-        choices=model_choices,
+        choices=["u2net", "u2netp", "u2net_human_seg"],
         help="The model name.",
     )
 

+ 0 - 326
rembg/data_loader.py

@@ -1,326 +0,0 @@
-# data loader
-
-import random
-
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from PIL import Image
-from skimage import color, io, transform
-from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms, utils
-
-
-# ==========================dataset load==========================
-class RescaleT:
-    def __init__(self, output_size):
-        assert isinstance(output_size, (int, tuple))
-        self.output_size = output_size
-
-    def __call__(self, sample):
-        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
-
-        h, w = image.shape[:2]
-
-        if isinstance(self.output_size, int):
-            if h > w:
-                new_h, new_w = self.output_size * h / w, self.output_size
-            else:
-                new_h, new_w = self.output_size, self.output_size * w / h
-        else:
-            new_h, new_w = self.output_size
-
-        new_h, new_w = int(new_h), int(new_w)
-
-        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
-        # img = transform.resize(image,(new_h,new_w),mode='constant')
-        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
-
-        img = transform.resize(
-            image, (self.output_size, self.output_size), mode="constant"
-        )
-        lbl = transform.resize(
-            label,
-            (self.output_size, self.output_size),
-            mode="constant",
-            order=0,
-            preserve_range=True,
-        )
-
-        return {"imidx": imidx, "image": img, "label": lbl}
-
-
-class Rescale:
-    def __init__(self, output_size):
-        assert isinstance(output_size, (int, tuple))
-        self.output_size = output_size
-
-    def __call__(self, sample):
-        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
-
-        if random.random() >= 0.5:
-            image = image[::-1]
-            label = label[::-1]
-
-        h, w = image.shape[:2]
-
-        if isinstance(self.output_size, int):
-            if h > w:
-                new_h, new_w = self.output_size * h / w, self.output_size
-            else:
-                new_h, new_w = self.output_size, self.output_size * w / h
-        else:
-            new_h, new_w = self.output_size
-
-        new_h, new_w = int(new_h), int(new_w)
-
-        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
-        img = transform.resize(image, (new_h, new_w), mode="constant")
-        lbl = transform.resize(
-            label, (new_h, new_w), mode="constant", order=0, preserve_range=True
-        )
-
-        return {"imidx": imidx, "image": img, "label": lbl}
-
-
-class RandomCrop:
-    def __init__(self, output_size):
-        assert isinstance(output_size, (int, tuple))
-        if isinstance(output_size, int):
-            self.output_size = (output_size, output_size)
-        else:
-            assert len(output_size) == 2
-            self.output_size = output_size
-
-    def __call__(self, sample):
-        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
-
-        if random.random() >= 0.5:
-            image = image[::-1]
-            label = label[::-1]
-
-        h, w = image.shape[:2]
-        new_h, new_w = self.output_size
-
-        top = np.random.randint(0, h - new_h)
-        left = np.random.randint(0, w - new_w)
-
-        image = image[top : top + new_h, left : left + new_w]
-        label = label[top : top + new_h, left : left + new_w]
-
-        return {"imidx": imidx, "image": image, "label": label}
-
-
-class ToTensor:
-    """Convert ndarrays in sample to Tensors."""
-
-    def __call__(self, sample):
-
-        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
-
-        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
-        tmpLbl = np.zeros(label.shape)
-
-        image = image / np.max(image)
-        if np.max(label) < 1e-6:
-            label = label
-        else:
-            label = label / np.max(label)
-
-        if image.shape[2] == 1:
-            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
-            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
-            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
-        else:
-            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
-            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
-            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
-
-        tmpLbl[:, :, 0] = label[:, :, 0]
-
-        # change the r,g,b to b,r,g from [0,255] to [0,1]
-        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
-        tmpImg = tmpImg.transpose((2, 0, 1))
-        tmpLbl = label.transpose((2, 0, 1))
-
-        return {
-            "imidx": torch.from_numpy(imidx),
-            "image": torch.from_numpy(tmpImg),
-            "label": torch.from_numpy(tmpLbl),
-        }
-
-
-class ToTensorLab:
-    """Convert ndarrays in sample to Tensors."""
-
-    def __init__(self, flag=0):
-        self.flag = flag
-
-    def __call__(self, sample):
-
-        imidx, image, label = sample["imidx"], sample["image"], sample["label"]
-
-        tmpLbl = np.zeros(label.shape)
-
-        if np.max(label) < 1e-6:
-            label = label
-        else:
-            label = label / np.max(label)
-
-        # change the color space
-        if self.flag == 2:  # with rgb and Lab colors
-            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
-            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
-            if image.shape[2] == 1:
-                tmpImgt[:, :, 0] = image[:, :, 0]
-                tmpImgt[:, :, 1] = image[:, :, 0]
-                tmpImgt[:, :, 2] = image[:, :, 0]
-            else:
-                tmpImgt = image
-            tmpImgtl = color.rgb2lab(tmpImgt)
-
-            # nomalize image to range [0,1]
-            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
-                np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0])
-            )
-            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
-                np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1])
-            )
-            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
-                np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2])
-            )
-            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
-                np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0])
-            )
-            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
-                np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1])
-            )
-            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
-                np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2])
-            )
-
-            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
-
-            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
-                tmpImg[:, :, 0]
-            )
-            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
-                tmpImg[:, :, 1]
-            )
-            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
-                tmpImg[:, :, 2]
-            )
-            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(
-                tmpImg[:, :, 3]
-            )
-            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(
-                tmpImg[:, :, 4]
-            )
-            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(
-                tmpImg[:, :, 5]
-            )
-
-        elif self.flag == 1:  # with Lab color
-            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
-
-            if image.shape[2] == 1:
-                tmpImg[:, :, 0] = image[:, :, 0]
-                tmpImg[:, :, 1] = image[:, :, 0]
-                tmpImg[:, :, 2] = image[:, :, 0]
-            else:
-                tmpImg = image
-
-            tmpImg = color.rgb2lab(tmpImg)
-
-            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
-
-            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
-                np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0])
-            )
-            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
-                np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1])
-            )
-            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
-                np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2])
-            )
-
-            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
-                tmpImg[:, :, 0]
-            )
-            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
-                tmpImg[:, :, 1]
-            )
-            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
-                tmpImg[:, :, 2]
-            )
-
-        else:  # with rgb color
-            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
-            image = image / np.max(image)
-            if image.shape[2] == 1:
-                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
-                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
-                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
-            else:
-                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
-                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
-                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
-
-        tmpLbl[:, :, 0] = label[:, :, 0]
-
-        # change the r,g,b to b,r,g from [0,255] to [0,1]
-        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
-        tmpImg = tmpImg.transpose((2, 0, 1))
-        tmpLbl = label.transpose((2, 0, 1))
-
-        return {
-            "imidx": torch.from_numpy(imidx),
-            "image": torch.from_numpy(tmpImg),
-            "label": torch.from_numpy(tmpLbl),
-        }
-
-
-class SalObjDataset(Dataset):
-    def __init__(self, img_name_list, lbl_name_list, transform=None):
-        # self.root_dir = root_dir
-        # self.image_name_list = glob.glob(image_dir+'*.png')
-        # self.label_name_list = glob.glob(label_dir+'*.png')
-        self.image_name_list = img_name_list
-        self.label_name_list = lbl_name_list
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.image_name_list)
-
-    def __getitem__(self, idx):
-
-        # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
-        # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
-
-        image = io.imread(self.image_name_list[idx])
-        imname = self.image_name_list[idx]
-        imidx = np.array([idx])
-
-        if 0 == len(self.label_name_list):
-            label_3 = np.zeros(image.shape)
-        else:
-            label_3 = io.imread(self.label_name_list[idx])
-
-        label = np.zeros(label_3.shape[0:2])
-        if 3 == len(label_3.shape):
-            label = label_3[:, :, 0]
-        elif 2 == len(label_3.shape):
-            label = label_3
-
-        if 3 == len(image.shape) and 2 == len(label.shape):
-            label = label[:, :, np.newaxis]
-        elif 2 == len(image.shape) and 2 == len(label.shape):
-            image = image[:, :, np.newaxis]
-            label = label[:, :, np.newaxis]
-
-        sample = {"imidx": imidx, "image": image, "label": label}
-
-        if self.transform:
-            sample = self.transform(sample)
-
-        return sample

+ 84 - 131
rembg/detect.py

@@ -1,137 +1,101 @@
-import errno
 import os
 import sys
-import urllib.request
-from hashlib import md5
 
+import gdown
 import numpy as np
-import requests
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
+import onnxruntime as ort
 from PIL import Image
 from skimage import transform
-from torchvision import transforms
-from tqdm import tqdm
 
-from .data_loader import RescaleT, ToTensorLab
-from .u2net import U2NET, U2NETP
+SESSIONS = {}
 
 
-def download_file_from_google_drive(id, fname, destination):
-    head, tail = os.path.split(destination)
-    os.makedirs(head, exist_ok=True)
-
-    URL = "https://docs.google.com/uc?export=download"
+def load_model(model_name: str = "u2net"):
+    path = os.environ.get(
+        "U2NETP_PATH",
+        os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
+    )
 
-    session = requests.Session()
-    response = session.get(URL, params={"id": id}, stream=True)
+    if model_name == "u2netp":
+        md5 = "8e83ca70e441ab06c318d82300c84806"
+        url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
+    elif model_name == "u2net":
+        md5 = "60024c5c889badc19c04ad937298a77b"
+        url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
+    elif model_name == "u2net_human_seg":
+        md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
+        url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
+    else:
+        print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
 
-    token = None
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            token = value
-            break
+    if SESSIONS.get(md5) is None:
+        gdown.cached_download(url, path, md5=md5, quiet=False)
+        SESSIONS[md5] = ort.InferenceSession(path)
 
-    if token:
-        params = {"id": id, "confirm": token}
-        response = session.get(URL, params=params, stream=True)
+    return SESSIONS[md5]
 
-    total = int(response.headers.get("content-length", 0))
 
-    with open(destination, "wb") as file, tqdm(
-        desc=f"Downloading {tail} to {head}",
-        total=total,
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as bar:
-        for data in response.iter_content(chunk_size=1024):
-            size = file.write(data)
-            bar.update(size)
+def norm_pred(d):
+    ma = np.max(d)
+    mi = np.min(d)
+    dn = (d - mi) / (ma - mi)
 
+    return dn
 
-def load_model(model_name: str = "u2net"):
-    hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
 
-    if model_name == "u2netp":
-        net = U2NETP(3, 1)
-        path = os.environ.get(
-            "U2NETP_PATH",
-            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
-        )
-        if (
-            not os.path.exists(path)
-            or hashfile(path) != "e4f636406ca4e2af789941e7f139ee2e"
-        ):
-            download_file_from_google_drive(
-                "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
-                "u2netp.pth",
-                path,
-            )
+def rescale(sample, output_size):
+    imidx, image, label = sample["imidx"], sample["image"], sample["label"]
 
-    elif model_name == "u2net":
-        net = U2NET(3, 1)
-        path = os.environ.get(
-            "U2NET_PATH",
-            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
-        )
-        if (
-            not os.path.exists(path)
-            or hashfile(path) != "347c3d51b01528e5c6c071e3cff1cb55"
-        ):
-            download_file_from_google_drive(
-                "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
-                "pth",
-                path,
-            )
+    h, w = image.shape[:2]
 
-    elif model_name == "u2net_human_seg":
-        net = U2NET(3, 1)
-        path = os.environ.get(
-            "U2NET_PATH",
-            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
-        )
-        if (
-            not os.path.exists(path)
-            or hashfile(path) != "09fb4e49b7f785c9f855baf94916840a"
-        ):
-            download_file_from_google_drive(
-                "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
-                "u2net_human_seg.pth",
-                path,
-            )
+    if isinstance(output_size, int):
+        if h > w:
+            new_h, new_w = output_size * h / w, output_size
+        else:
+            new_h, new_w = output_size, output_size * w / h
     else:
-        print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
+        new_h, new_w = output_size
 
-    try:
-        if torch.cuda.is_available():
-            net.load_state_dict(torch.load(path))
-            net.to(torch.device("cuda"))
-        else:
-            net.load_state_dict(
-                torch.load(
-                    path,
-                    map_location="cpu",
-                )
-            )
-    except FileNotFoundError:
-        raise FileNotFoundError(
-            errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
-        )
+    new_h, new_w = int(new_h), int(new_w)
 
-    net.eval()
+    img = transform.resize(image, (output_size, output_size), mode="constant")
+    lbl = transform.resize(
+        label,
+        (output_size, output_size),
+        mode="constant",
+        order=0,
+        preserve_range=True,
+    )
 
-    return net
+    return {"imidx": imidx, "image": img, "label": lbl}
 
 
-def norm_pred(d):
-    ma = torch.max(d)
-    mi = torch.min(d)
-    dn = (d - mi) / (ma - mi)
+def color(sample):
+    imidx, image, label = sample["imidx"], sample["image"], sample["label"]
 
-    return dn
+    tmpLbl = np.zeros(label.shape)
+
+    if np.max(label) < 1e-6:
+        label = label
+    else:
+        label = label / np.max(label)
+
+    tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+    image = image / np.max(image)
+    if image.shape[2] == 1:
+        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+        tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
+        tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
+    else:
+        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+        tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
+        tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
+
+    tmpLbl[:, :, 0] = label[:, :, 0]
+    tmpImg = tmpImg.transpose((2, 0, 1))
+    tmpLbl = label.transpose((2, 0, 1))
+
+    return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
 
 
 def preprocess(image):
@@ -149,34 +113,23 @@ def preprocess(image):
         image = image[:, :, np.newaxis]
         label = label[:, :, np.newaxis]
 
-    transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
-    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
+    sample = {"imidx": np.array([0]), "image": image, "label": label}
+    sample = rescale(sample, 320)
+    sample = color(sample)
 
     return sample
 
 
-def predict(net, item):
-
+def predict(ort_session, item):
     sample = preprocess(item)
+    inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)
 
-    with torch.no_grad():
-
-        if torch.cuda.is_available():
-            inputs_test = torch.cuda.FloatTensor(
-                sample["image"].unsqueeze(0).cuda().float()
-            )
-        else:
-            inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
-
-        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
-
-        pred = d1[:, 0, :, :]
-        predict = norm_pred(pred)
-
-        predict = predict.squeeze()
-        predict_np = predict.cpu().detach().numpy()
-        img = Image.fromarray(predict_np * 255).convert("RGB")
+    ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}
+    ort_outs = ort_session.run(None, ort_inputs)
 
-        del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample
+    d1 = ort_outs[0]
+    pred = d1[:, 0, :, :]
+    predict = np.squeeze(norm_pred(pred))
+    img = Image.fromarray(predict * 255).convert("RGB")
 
-        return img
+    return img

+ 1 - 10
rembg/server.py

@@ -46,16 +46,7 @@ def index():
     height = request.args.get("height", type=int)
 
     model = request.values.get("model", type=str, default="u2net")
-    model_path = os.environ.get(
-        "U2NETP_PATH",
-        os.path.expanduser(os.path.join("~", ".u2net")),
-    )
-    model_choices = [
-        os.path.splitext(os.path.basename(x))[0]
-        for x in set(glob.glob(model_path + "/*"))
-    ]
-
-    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
+    model_choices = ["u2net", "u2netp", "u2net_human_seg"]
 
     if model not in model_choices:
         return {

+ 0 - 541
rembg/u2net.py

@@ -1,541 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchvision import models
-
-
-class REBNCONV(nn.Module):
-    def __init__(self, in_ch=3, out_ch=3, dirate=1):
-        super().__init__()
-
-        self.conv_s1 = nn.Conv2d(
-            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
-        )
-        self.bn_s1 = nn.BatchNorm2d(out_ch)
-        self.relu_s1 = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-
-        hx = x
-        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
-
-        return xout
-
-
-## upsample tensor 'src' to have the same spatial size with tensor 'tar'
-def _upsample_like(src, tar):
-
-    src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
-
-    return src
-
-
-### RSU-7 ###
-class RSU7(nn.Module):  # UNet07DRES(nn.Module):
-    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super().__init__()
-
-        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-    def forward(self, x):
-
-        hx = x
-        hxin = self.rebnconvin(hx)
-
-        hx1 = self.rebnconv1(hxin)
-        hx = self.pool1(hx1)
-
-        hx2 = self.rebnconv2(hx)
-        hx = self.pool2(hx2)
-
-        hx3 = self.rebnconv3(hx)
-        hx = self.pool3(hx3)
-
-        hx4 = self.rebnconv4(hx)
-        hx = self.pool4(hx4)
-
-        hx5 = self.rebnconv5(hx)
-        hx = self.pool5(hx5)
-
-        hx6 = self.rebnconv6(hx)
-
-        hx7 = self.rebnconv7(hx6)
-
-        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
-        hx6dup = _upsample_like(hx6d, hx5)
-
-        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
-        hx5dup = _upsample_like(hx5d, hx4)
-
-        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
-        hx4dup = _upsample_like(hx4d, hx3)
-
-        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-        hx3dup = _upsample_like(hx3d, hx2)
-
-        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-        hx2dup = _upsample_like(hx2d, hx1)
-
-        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-        return hx1d + hxin
-
-
-### RSU-6 ###
-class RSU6(nn.Module):  # UNet06DRES(nn.Module):
-    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super().__init__()
-
-        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-    def forward(self, x):
-
-        hx = x
-
-        hxin = self.rebnconvin(hx)
-
-        hx1 = self.rebnconv1(hxin)
-        hx = self.pool1(hx1)
-
-        hx2 = self.rebnconv2(hx)
-        hx = self.pool2(hx2)
-
-        hx3 = self.rebnconv3(hx)
-        hx = self.pool3(hx3)
-
-        hx4 = self.rebnconv4(hx)
-        hx = self.pool4(hx4)
-
-        hx5 = self.rebnconv5(hx)
-
-        hx6 = self.rebnconv6(hx5)
-
-        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
-        hx5dup = _upsample_like(hx5d, hx4)
-
-        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
-        hx4dup = _upsample_like(hx4d, hx3)
-
-        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-        hx3dup = _upsample_like(hx3d, hx2)
-
-        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-        hx2dup = _upsample_like(hx2d, hx1)
-
-        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-        return hx1d + hxin
-
-
-### RSU-5 ###
-class RSU5(nn.Module):  # UNet05DRES(nn.Module):
-    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super().__init__()
-
-        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-    def forward(self, x):
-
-        hx = x
-
-        hxin = self.rebnconvin(hx)
-
-        hx1 = self.rebnconv1(hxin)
-        hx = self.pool1(hx1)
-
-        hx2 = self.rebnconv2(hx)
-        hx = self.pool2(hx2)
-
-        hx3 = self.rebnconv3(hx)
-        hx = self.pool3(hx3)
-
-        hx4 = self.rebnconv4(hx)
-
-        hx5 = self.rebnconv5(hx4)
-
-        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
-        hx4dup = _upsample_like(hx4d, hx3)
-
-        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
-        hx3dup = _upsample_like(hx3d, hx2)
-
-        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-        hx2dup = _upsample_like(hx2d, hx1)
-
-        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-        return hx1d + hxin
-
-
-### RSU-4 ###
-class RSU4(nn.Module):  # UNet04DRES(nn.Module):
-    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super().__init__()
-
-        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
-        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
-
-        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
-
-        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
-        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-    def forward(self, x):
-
-        hx = x
-
-        hxin = self.rebnconvin(hx)
-
-        hx1 = self.rebnconv1(hxin)
-        hx = self.pool1(hx1)
-
-        hx2 = self.rebnconv2(hx)
-        hx = self.pool2(hx2)
-
-        hx3 = self.rebnconv3(hx)
-
-        hx4 = self.rebnconv4(hx3)
-
-        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
-        hx3dup = _upsample_like(hx3d, hx2)
-
-        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
-        hx2dup = _upsample_like(hx2d, hx1)
-
-        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
-
-        return hx1d + hxin
-
-
-### RSU-4F ###
-class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
-    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
-        super().__init__()
-
-        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
-
-        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
-        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
-        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
-
-        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
-
-        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
-        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
-        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
-
-    def forward(self, x):
-
-        hx = x
-
-        hxin = self.rebnconvin(hx)
-
-        hx1 = self.rebnconv1(hxin)
-        hx2 = self.rebnconv2(hx1)
-        hx3 = self.rebnconv3(hx2)
-
-        hx4 = self.rebnconv4(hx3)
-
-        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
-        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
-        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
-
-        return hx1d + hxin
-
-
-##### U^2-Net ####
-class U2NET(nn.Module):
-    def __init__(self, in_ch=3, out_ch=1):
-        super().__init__()
-
-        self.stage1 = RSU7(in_ch, 32, 64)
-        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage2 = RSU6(64, 32, 128)
-        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage3 = RSU5(128, 64, 256)
-        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage4 = RSU4(256, 128, 512)
-        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage5 = RSU4F(512, 256, 512)
-        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage6 = RSU4F(512, 256, 512)
-
-        # decoder
-        self.stage5d = RSU4F(1024, 256, 512)
-        self.stage4d = RSU4(1024, 128, 256)
-        self.stage3d = RSU5(512, 64, 128)
-        self.stage2d = RSU6(256, 32, 64)
-        self.stage1d = RSU7(128, 16, 64)
-
-        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
-        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
-        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
-        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
-        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
-        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
-
-        self.outconv = nn.Conv2d(6, out_ch, 1)
-
-    def forward(self, x):
-
-        hx = x
-
-        # stage 1
-        hx1 = self.stage1(hx)
-        hx = self.pool12(hx1)
-
-        # stage 2
-        hx2 = self.stage2(hx)
-        hx = self.pool23(hx2)
-
-        # stage 3
-        hx3 = self.stage3(hx)
-        hx = self.pool34(hx3)
-
-        # stage 4
-        hx4 = self.stage4(hx)
-        hx = self.pool45(hx4)
-
-        # stage 5
-        hx5 = self.stage5(hx)
-        hx = self.pool56(hx5)
-
-        # stage 6
-        hx6 = self.stage6(hx)
-        hx6up = _upsample_like(hx6, hx5)
-
-        # -------------------- decoder --------------------
-        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
-        hx5dup = _upsample_like(hx5d, hx4)
-
-        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
-        hx4dup = _upsample_like(hx4d, hx3)
-
-        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
-        hx3dup = _upsample_like(hx3d, hx2)
-
-        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
-        hx2dup = _upsample_like(hx2d, hx1)
-
-        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
-
-        # side output
-        d1 = self.side1(hx1d)
-
-        d2 = self.side2(hx2d)
-        d2 = _upsample_like(d2, d1)
-
-        d3 = self.side3(hx3d)
-        d3 = _upsample_like(d3, d1)
-
-        d4 = self.side4(hx4d)
-        d4 = _upsample_like(d4, d1)
-
-        d5 = self.side5(hx5d)
-        d5 = _upsample_like(d5, d1)
-
-        d6 = self.side6(hx6)
-        d6 = _upsample_like(d6, d1)
-
-        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
-
-        return (
-            torch.sigmoid(d0),
-            torch.sigmoid(d1),
-            torch.sigmoid(d2),
-            torch.sigmoid(d3),
-            torch.sigmoid(d4),
-            torch.sigmoid(d5),
-            torch.sigmoid(d6),
-        )
-
-
-### U^2-Net small ###
-class U2NETP(nn.Module):
-    def __init__(self, in_ch=3, out_ch=1):
-        super().__init__()
-
-        self.stage1 = RSU7(in_ch, 16, 64)
-        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage2 = RSU6(64, 16, 64)
-        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage3 = RSU5(64, 16, 64)
-        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage4 = RSU4(64, 16, 64)
-        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage5 = RSU4F(64, 16, 64)
-        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
-
-        self.stage6 = RSU4F(64, 16, 64)
-
-        # decoder
-        self.stage5d = RSU4F(128, 16, 64)
-        self.stage4d = RSU4(128, 16, 64)
-        self.stage3d = RSU5(128, 16, 64)
-        self.stage2d = RSU6(128, 16, 64)
-        self.stage1d = RSU7(128, 16, 64)
-
-        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
-        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
-        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
-        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
-        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
-        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
-
-        self.outconv = nn.Conv2d(6, out_ch, 1)
-
-    def forward(self, x):
-
-        hx = x
-
-        # stage 1
-        hx1 = self.stage1(hx)
-        hx = self.pool12(hx1)
-
-        # stage 2
-        hx2 = self.stage2(hx)
-        hx = self.pool23(hx2)
-
-        # stage 3
-        hx3 = self.stage3(hx)
-        hx = self.pool34(hx3)
-
-        # stage 4
-        hx4 = self.stage4(hx)
-        hx = self.pool45(hx4)
-
-        # stage 5
-        hx5 = self.stage5(hx)
-        hx = self.pool56(hx5)
-
-        # stage 6
-        hx6 = self.stage6(hx)
-        hx6up = _upsample_like(hx6, hx5)
-
-        # decoder
-        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
-        hx5dup = _upsample_like(hx5d, hx4)
-
-        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
-        hx4dup = _upsample_like(hx4d, hx3)
-
-        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
-        hx3dup = _upsample_like(hx3d, hx2)
-
-        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
-        hx2dup = _upsample_like(hx2d, hx1)
-
-        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
-
-        # side output
-        d1 = self.side1(hx1d)
-
-        d2 = self.side2(hx2d)
-        d2 = _upsample_like(d2, d1)
-
-        d3 = self.side3(hx3d)
-        d3 = _upsample_like(d3, d1)
-
-        d4 = self.side4(hx4d)
-        d4 = _upsample_like(d4, d1)
-
-        d5 = self.side5(hx5d)
-        d5 = _upsample_like(d5, d1)
-
-        d6 = self.side6(hx6)
-        d6 = _upsample_like(d6, d1)
-
-        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
-
-        return (
-            torch.sigmoid(d0),
-            torch.sigmoid(d1),
-            torch.sigmoid(d2),
-            torch.sigmoid(d3),
-            torch.sigmoid(d4),
-            torch.sigmoid(d5),
-            torch.sigmoid(d6),
-        )

+ 1 - 0
requirements-cpu.txt

@@ -0,0 +1 @@
+onnxruntime==1.10.0

+ 1 - 0
requirements-gpu.txt

@@ -0,0 +1 @@
+onnxruntime-gpu==1.10.0

+ 5 - 8
requirements.txt

@@ -1,13 +1,10 @@
+filetype==1.0.7
 flask==1.1.2
+gdown==4.2.0
 numpy==1.20.0
 pillow==8.3.2
+pymatting==1.1.5
 scikit-image==0.19.1
-torch==1.9.1
-torchvision==0.10.1
-waitress==1.4.4
-tqdm==4.51.0
-requests==2.24.0
 scipy==1.5.4
-pymatting==1.1.1
-filetype==1.0.7
-matplotlib==3.5.1
+tqdm==4.51.0
+waitress==1.4.4

+ 7 - 0
setup.py

@@ -14,6 +14,13 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
 with open("requirements.txt") as f:
     requireds = f.read().splitlines()
 
+if os.getenv("GPU") is None:
+    with open("requirements-cpu.txt") as f:
+        requireds += f.read().splitlines()
+else:
+    with open("requirements-gpu.txt") as f:
+        requireds += f.read().splitlines()
+
 setup(
     name="rembg",
     description="Remove image background",