Daniel Gatis 5 роки тому
батько
коміт
39257c4502
5 змінених файлів з 20 додано та 29 видалено
  1. 1 0
      requirements.txt
  2. 1 1
      setup.py
  3. 0 1
      src/rembg/cmd/cli.py
  4. 0 1
      src/rembg/cmd/server.py
  5. 18 26
      src/rembg/u2net/detect.py

+ 1 - 0
requirements.txt

@@ -6,3 +6,4 @@ torch==1.6.0
 torchvision==0.7.0
 waitress==1.4.4
 tqdm==4.48.2
+requests==2.24.0

+ 1 - 1
setup.py

@@ -11,7 +11,7 @@ with open("requirements.txt") as f:
 
 setup(
     name="rembg",
-    version="1.0.6",
+    version="1.0.7",
     description="Remove image background",
     long_description=long_description,
     long_description_content_type="text/markdown",

+ 0 - 1
src/rembg/cmd/cli.py

@@ -3,7 +3,6 @@ import glob
 import imghdr
 import os
 
-
 from ..bg import remove
 
 

+ 0 - 1
src/rembg/cmd/server.py

@@ -8,7 +8,6 @@ from waitress import serve
 
 from ..bg import remove
 
-
 app = Flask(__name__)
 
 

+ 18 - 26
src/rembg/u2net/detect.py

@@ -1,10 +1,10 @@
 import errno
 import os
-import urllib.request
 import sys
+import urllib.request
 
 import numpy as np
-import pkg_resources
+import requests
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -17,43 +17,35 @@ from tqdm import tqdm
 from . import data_loader, u2net
 
 
-class DownloadProgressBar(tqdm):
-    def update_to(self, b=1, bsize=1, tsize=None):
-        if tsize is not None:
-            self.total = tsize
-        self.update(b * bsize - self.n)
-
-
-def download_url(url, model_name, output_path):
-    if os.path.exists(output_path):
+def download(url, fname, path):
+    if os.path.exists(path):
         return
 
-    os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
-
-    print(
-        f"Downloading model to {output_path}".format(output_path=output_path),
-        file=sys.stderr,
-    )
-
-    with DownloadProgressBar(
-        unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
-    ) as t:
-        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
+    resp = requests.get(url, stream=True)
+    total = int(resp.headers.get("content-length", 0))
+    with open(path, "wb") as file, tqdm(
+        desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
+    ) as bar:
+        for data in resp.iter_content(chunk_size=1024):
+            size = file.write(data)
+            bar.update(size)
 
 
 def load_model(model_name: str = "u2net"):
+    os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
+
     if model_name == "u2netp":
         net = u2net.U2NETP(3, 1)
-        path = os.path.expanduser("~/.u2net/u2netp.pth")
-        download_url(
+        path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
+        download(
             "https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1",
             "u2netp.pth",
             path,
         )
     elif model_name == "u2net":
         net = u2net.U2NET(3, 1)
-        path = os.path.expanduser("~/.u2net/u2net.pth")
-        download_url(
+        path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
+        download(
             "https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1",
             "u2net.pth",
             path,