|
@@ -9,7 +9,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import torchvision
|
|
|
-from hsh.library.hash import Hasher
|
|
|
+from hashlib import md5
|
|
|
from PIL import Image
|
|
|
from skimage import transform
|
|
|
from torchvision import transforms
|
|
@@ -52,7 +52,7 @@ def download_file_from_google_drive(id, fname, destination):
|
|
|
|
|
|
|
|
|
def load_model(model_name: str = "u2net"):
|
|
|
- hasher = Hasher()
|
|
|
+ hashfile = lambda f: md5(open(f,"rb").read()).hexdigest()
|
|
|
|
|
|
if model_name == "u2netp":
|
|
|
net = u2net.U2NETP(3, 1)
|
|
@@ -62,7 +62,7 @@ def load_model(model_name: str = "u2net"):
|
|
|
)
|
|
|
if (
|
|
|
not os.path.exists(path)
|
|
|
- or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
|
|
|
+ or hashfile(path) != "e4f636406ca4e2af789941e7f139ee2e"
|
|
|
):
|
|
|
download_file_from_google_drive(
|
|
|
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
|
|
@@ -78,7 +78,7 @@ def load_model(model_name: str = "u2net"):
|
|
|
)
|
|
|
if (
|
|
|
not os.path.exists(path)
|
|
|
- or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
|
|
+ or hashfile(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
|
|
):
|
|
|
download_file_from_google_drive(
|
|
|
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
|
@@ -94,7 +94,7 @@ def load_model(model_name: str = "u2net"):
|
|
|
)
|
|
|
if (
|
|
|
not os.path.exists(path)
|
|
|
- or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
|
|
|
+ or hashfile(path) != "09fb4e49b7f785c9f855baf94916840a"
|
|
|
):
|
|
|
download_file_from_google_drive(
|
|
|
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
|