|
@@ -9,6 +9,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import torchvision
|
|
|
+from hsh.library.hash import Hasher
|
|
|
from PIL import Image
|
|
|
from skimage import transform
|
|
|
from torchvision import transforms
|
|
@@ -18,8 +19,8 @@ from . import data_loader, u2net
|
|
|
|
|
|
|
|
|
def download_file_from_google_drive(id, fname, destination):
|
|
|
- if os.path.exists(destination):
|
|
|
- return
|
|
|
+ head, tail = os.path.split(destination)
|
|
|
+ os.makedirs(head, exist_ok=True)
|
|
|
|
|
|
URL = "https://docs.google.com/uc?export=download"
|
|
|
|
|
@@ -39,7 +40,11 @@ def download_file_from_google_drive(id, fname, destination):
|
|
|
total = int(response.headers.get("content-length", 0))
|
|
|
|
|
|
with open(destination, "wb") as file, tqdm(
|
|
|
- desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
|
|
+ desc=f"Downloading {tail} to {head}",
|
|
|
+ total=total,
|
|
|
+ unit="iB",
|
|
|
+ unit_scale=True,
|
|
|
+ unit_divisor=1024,
|
|
|
) as bar:
|
|
|
for data in response.iter_content(chunk_size=1024):
|
|
|
size = file.write(data)
|
|
@@ -47,20 +52,39 @@ def download_file_from_google_drive(id, fname, destination):
|
|
|
|
|
|
|
|
|
def load_model(model_name: str = "u2net"):
|
|
|
- os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
|
|
|
+ hasher = Hasher()
|
|
|
|
|
|
if model_name == "u2netp":
|
|
|
net = u2net.U2NETP(3, 1)
|
|
|
- path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
|
|
- download_file_from_google_drive(
|
|
|
- "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
|
|
|
+ path = os.environ.get(
|
|
|
+ "U2NETP_PATH",
|
|
|
+ os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
|
|
)
|
|
|
+ if (
|
|
|
+ not os.path.exists(path)
|
|
|
+ or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
|
|
|
+ ):
|
|
|
+ download_file_from_google_drive(
|
|
|
+ "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
|
|
|
+ "u2netp.pth",
|
|
|
+ path,
|
|
|
+ )
|
|
|
+
|
|
|
elif model_name == "u2net":
|
|
|
net = u2net.U2NET(3, 1)
|
|
|
- path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
|
|
- download_file_from_google_drive(
|
|
|
- "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
|
|
|
+ path = os.environ.get(
|
|
|
+ "U2NET_PATH",
|
|
|
+ os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
|
|
)
|
|
|
+ if (
|
|
|
+ not os.path.exists(path)
|
|
|
+ or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
|
|
+ ):
|
|
|
+ download_file_from_google_drive(
|
|
|
+ "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
|
|
+ "u2net.pth",
|
|
|
+ path,
|
|
|
+ )
|
|
|
else:
|
|
|
print("Choose between u2net or u2netp", file=sys.stderr)
|
|
|
|
|
@@ -69,7 +93,12 @@ def load_model(model_name: str = "u2net"):
|
|
|
net.load_state_dict(torch.load(path))
|
|
|
net.to(torch.device("cuda"))
|
|
|
else:
|
|
|
- net.load_state_dict(torch.load(path, map_location="cpu",))
|
|
|
+ net.load_state_dict(
|
|
|
+ torch.load(
|
|
|
+ path,
|
|
|
+ map_location="cpu",
|
|
|
+ )
|
|
|
+ )
|
|
|
except FileNotFoundError:
|
|
|
raise FileNotFoundError(
|
|
|
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
|