|
@@ -1,10 +1,10 @@
|
|
|
import errno
|
|
|
import os
|
|
|
-import urllib.request
|
|
|
import sys
|
|
|
+import urllib.request
|
|
|
|
|
|
import numpy as np
|
|
|
-import pkg_resources
|
|
|
+import requests
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
@@ -17,43 +17,35 @@ from tqdm import tqdm
|
|
|
from . import data_loader, u2net
|
|
|
|
|
|
|
|
|
-class DownloadProgressBar(tqdm):
|
|
|
- def update_to(self, b=1, bsize=1, tsize=None):
|
|
|
- if tsize is not None:
|
|
|
- self.total = tsize
|
|
|
- self.update(b * bsize - self.n)
|
|
|
-
|
|
|
-
|
|
|
-def download_url(url, model_name, output_path):
|
|
|
- if os.path.exists(output_path):
|
|
|
+def download(url, fname, path):
|
|
|
+ if os.path.exists(path):
|
|
|
return
|
|
|
|
|
|
- os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True)
|
|
|
-
|
|
|
- print(
|
|
|
- f"Downloading model to {output_path}".format(output_path=output_path),
|
|
|
- file=sys.stderr,
|
|
|
- )
|
|
|
-
|
|
|
- with DownloadProgressBar(
|
|
|
- unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
|
|
|
- ) as t:
|
|
|
- urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
|
|
|
+ resp = requests.get(url, stream=True)
|
|
|
+ total = int(resp.headers.get("content-length", 0))
|
|
|
+ with open(path, "wb") as file, tqdm(
|
|
|
+ desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
|
|
+ ) as bar:
|
|
|
+ for data in resp.iter_content(chunk_size=1024):
|
|
|
+ size = file.write(data)
|
|
|
+ bar.update(size)
|
|
|
|
|
|
|
|
|
def load_model(model_name: str = "u2net"):
|
|
|
+ os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
|
|
|
+
|
|
|
if model_name == "u2netp":
|
|
|
net = u2net.U2NETP(3, 1)
|
|
|
- path = os.path.expanduser("~/.u2net/u2netp.pth")
|
|
|
- download_url(
|
|
|
+ path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
|
|
+ download(
|
|
|
"https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1",
|
|
|
"u2netp.pth",
|
|
|
path,
|
|
|
)
|
|
|
elif model_name == "u2net":
|
|
|
net = u2net.U2NET(3, 1)
|
|
|
- path = os.path.expanduser("~/.u2net/u2net.pth")
|
|
|
- download_url(
|
|
|
+ path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
|
|
+ download(
|
|
|
"https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1",
|
|
|
"u2net.pth",
|
|
|
path,
|