Procházet zdrojové kódy

add model var envs

Daniel Gatis před 4 roky
rodič
revize
58fd707872
6 změnil soubory, kde provedl 69 přidání a 20 odebrání
  1. 1 0
      requirements.txt
  2. 1 1
      setup.py
  3. 6 2
      src/rembg/bg.py
  4. 7 3
      src/rembg/cmd/cli.py
  5. 14 3
      src/rembg/cmd/server.py
  6. 40 11
      src/rembg/u2net/detect.py

+ 1 - 0
requirements.txt

@@ -10,3 +10,4 @@ requests==2.24.0
 scipy==1.5.4
 pymatting==1.1.1
 filetype==1.0.7
+hsh==1.1.0

+ 1 - 1
setup.py

@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
 
 setup(
     name="rembg",
-    version="1.0.16",
+    version="1.0.17",
     description="Remove image background",
     long_description=long_description,
     long_description_content_type="text/markdown",

+ 6 - 2
src/rembg/bg.py

@@ -1,6 +1,6 @@
+import functools
 import io
 
-import functools
 import numpy as np
 from PIL import Image
 from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
@@ -12,7 +12,11 @@ from .u2net import detect
 
 
 def alpha_matting_cutout(
-    img, mask, foreground_threshold, background_threshold, erode_structure_size,
+    img,
+    mask,
+    foreground_threshold,
+    background_threshold,
+    erode_structure_size,
 ):
     base_size = (1000, 1000)
     size = img.size

+ 7 - 3
src/rembg/cmd/cli.py

@@ -1,8 +1,9 @@
 import argparse
 import glob
 import os
-import filetype
 from distutils.util import strtobool
+
+import filetype
 from tqdm import tqdm
 
 from ..bg import remove
@@ -55,7 +56,10 @@ def main():
     )
 
     ap.add_argument(
-        "-p", "--path", nargs="+", help="Path of a file or a folder of files.",
+        "-p",
+        "--path",
+        nargs="+",
+        help="Path of a file or a folder of files.",
     )
 
     ap.add_argument(
@@ -95,7 +99,7 @@ def main():
 
             if fi_type is None:
                 continue
-            elif fi_type.mime.find('image') < 0:
+            elif fi_type.mime.find("image") < 0:
                 continue
 
             with open(fi, "rb") as input:

+ 14 - 3
src/rembg/cmd/server.py

@@ -36,7 +36,10 @@ def index():
         return {"error": "invalid query param 'model'"}, 400
 
     try:
-        return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
+        return send_file(
+            BytesIO(remove(file_content, model)),
+            mimetype="image/png",
+        )
     except Exception as e:
         app.logger.exception(e, exc_info=True)
         return {"error": "oops, something went wrong!"}, 500
@@ -46,11 +49,19 @@ def main():
     ap = argparse.ArgumentParser()
 
     ap.add_argument(
-        "-a", "--addr", default="0.0.0.0", type=str, help="The IP address to bind to.",
+        "-a",
+        "--addr",
+        default="0.0.0.0",
+        type=str,
+        help="The IP address to bind to.",
     )
 
     ap.add_argument(
-        "-p", "--port", default=5000, type=int, help="The port to bind to.",
+        "-p",
+        "--port",
+        default=5000,
+        type=int,
+        help="The port to bind to.",
     )
 
     args = ap.parse_args()

+ 40 - 11
src/rembg/u2net/detect.py

@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
+from hsh.library.hash import Hasher
 from PIL import Image
 from skimage import transform
 from torchvision import transforms
@@ -18,8 +19,8 @@ from . import data_loader, u2net
 
 
 def download_file_from_google_drive(id, fname, destination):
-    if os.path.exists(destination):
-        return
+    head, tail = os.path.split(destination)
+    os.makedirs(head, exist_ok=True)
 
     URL = "https://docs.google.com/uc?export=download"
 
@@ -39,7 +40,11 @@ def download_file_from_google_drive(id, fname, destination):
     total = int(response.headers.get("content-length", 0))
 
     with open(destination, "wb") as file, tqdm(
-        desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
+        desc=f"Downloading {tail} to {head}",
+        total=total,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
     ) as bar:
         for data in response.iter_content(chunk_size=1024):
             size = file.write(data)
@@ -47,20 +52,39 @@ def download_file_from_google_drive(id, fname, destination):
 
 
 def load_model(model_name: str = "u2net"):
-    os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
+    hasher = Hasher()
 
     if model_name == "u2netp":
         net = u2net.U2NETP(3, 1)
-        path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
-        download_file_from_google_drive(
-            "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
+        path = os.environ.get(
+            "U2NETP_PATH",
+            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
         )
+        if (
+            not os.path.exists(path)
+            or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
+        ):
+            download_file_from_google_drive(
+                "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
+                "u2netp.pth",
+                path,
+            )
+
     elif model_name == "u2net":
         net = u2net.U2NET(3, 1)
-        path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
-        download_file_from_google_drive(
-            "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
+        path = os.environ.get(
+            "U2NET_PATH",
+            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
         )
+        if (
+            not os.path.exists(path)
+            or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
+        ):
+            download_file_from_google_drive(
+                "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
+                "u2net.pth",
+                path,
+            )
     else:
         print("Choose between u2net or u2netp", file=sys.stderr)
 
@@ -69,7 +93,12 @@ def load_model(model_name: str = "u2net"):
             net.load_state_dict(torch.load(path))
             net.to(torch.device("cuda"))
         else:
-            net.load_state_dict(torch.load(path, map_location="cpu",))
+            net.load_state_dict(
+                torch.load(
+                    path,
+                    map_location="cpu",
+                )
+            )
     except FileNotFoundError:
         raise FileNotFoundError(
             errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"