|
@@ -17,16 +17,31 @@ from tqdm import tqdm
|
|
|
from . import data_loader, u2net
|
|
|
|
|
|
|
|
|
-def download(url, fname, path):
|
|
|
- if os.path.exists(path):
|
|
|
+def download_file_from_google_drive(id, fname, destination):
|
|
|
+ if os.path.exists(destination):
|
|
|
return
|
|
|
|
|
|
- resp = requests.get(url, stream=True)
|
|
|
- total = int(resp.headers.get("content-length", 0))
|
|
|
- with open(path, "wb") as file, tqdm(
|
|
|
+ URL = "https://docs.google.com/uc?export=download"
|
|
|
+
|
|
|
+ session = requests.Session()
|
|
|
+ response = session.get(URL, params={"id": id}, stream=True)
|
|
|
+
|
|
|
+ token = None
|
|
|
+ for key, value in response.cookies.items():
|
|
|
+ if key.startswith("download_warning"):
|
|
|
+ token = value
|
|
|
+ break
|
|
|
+
|
|
|
+ if token:
|
|
|
+ params = {"id": id, "confirm": token}
|
|
|
+ response = session.get(URL, params=params, stream=True)
|
|
|
+
|
|
|
+ total = int(response.headers.get("content-length", 0))
|
|
|
+
|
|
|
+ with open(destination, "wb") as file, tqdm(
|
|
|
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
|
|
) as bar:
|
|
|
- for data in resp.iter_content(chunk_size=1024):
|
|
|
+ for data in response.iter_content(chunk_size=1024):
|
|
|
size = file.write(data)
|
|
|
bar.update(size)
|
|
|
|
|
@@ -37,18 +52,14 @@ def load_model(model_name: str = "u2net"):
|
|
|
if model_name == "u2netp":
|
|
|
net = u2net.U2NETP(3, 1)
|
|
|
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
|
|
- download(
|
|
|
- "https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1",
|
|
|
- "u2netp.pth",
|
|
|
- path,
|
|
|
+ download_file_from_google_drive(
|
|
|
+ "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
|
|
|
)
|
|
|
elif model_name == "u2net":
|
|
|
net = u2net.U2NET(3, 1)
|
|
|
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
|
|
- download(
|
|
|
- "https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1",
|
|
|
- "u2net.pth",
|
|
|
- path,
|
|
|
+ download_file_from_google_drive(
|
|
|
+ "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
|
|
|
)
|
|
|
else:
|
|
|
print("Choose between u2net or u2netp", file=sys.stderr)
|