|
@@ -15,8 +15,8 @@ from skimage import transform
|
|
|
from torchvision import transforms
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
-from . import data_loader, u2net
|
|
|
-
|
|
|
+from .data_loader import RescaleT, ToTensorLab
|
|
|
+from .u2net import U2NETP, U2NET
|
|
|
|
|
|
def download_file_from_google_drive(id, fname, destination):
|
|
|
head, tail = os.path.split(destination)
|
|
@@ -55,7 +55,7 @@ def load_model(model_name: str = "u2net"):
|
|
|
hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
|
|
|
|
|
|
if model_name == "u2netp":
|
|
|
- net = u2net.U2NETP(3, 1)
|
|
|
+ net = U2NETP(3, 1)
|
|
|
path = os.environ.get(
|
|
|
"U2NETP_PATH",
|
|
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
|
@@ -71,7 +71,7 @@ def load_model(model_name: str = "u2net"):
|
|
|
)
|
|
|
|
|
|
elif model_name == "u2net":
|
|
|
- net = u2net.U2NET(3, 1)
|
|
|
+ net = U2NET(3, 1)
|
|
|
path = os.environ.get(
|
|
|
"U2NET_PATH",
|
|
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
|
@@ -82,12 +82,12 @@ def load_model(model_name: str = "u2net"):
|
|
|
):
|
|
|
download_file_from_google_drive(
|
|
|
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
|
|
- "u2net.pth",
|
|
|
+ "pth",
|
|
|
path,
|
|
|
)
|
|
|
|
|
|
elif model_name == "u2net_human_seg":
|
|
|
- net = u2net.U2NET(3, 1)
|
|
|
+ net = U2NET(3, 1)
|
|
|
path = os.environ.get(
|
|
|
"U2NET_PATH",
|
|
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
|
@@ -149,7 +149,7 @@ def preprocess(image):
|
|
|
label = label[:, :, np.newaxis]
|
|
|
|
|
|
transform = transforms.Compose(
|
|
|
- [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
|
|
|
+ [RescaleT(320), ToTensorLab(flag=0)]
|
|
|
)
|
|
|
sample = transform({"imidx": np.array([0]), "image": image, "label": label})
|
|
|
|