|
@@ -85,8 +85,24 @@ def load_model(model_name: str = "u2net"):
|
|
"u2net.pth",
|
|
"u2net.pth",
|
|
path,
|
|
path,
|
|
)
|
|
)
|
|
|
|
+
|
|
|
|
+ elif model_name == "u2net_human_seg":
|
|
|
|
+ net = u2net.U2NET(3, 1)
|
|
|
|
+ path = os.environ.get(
|
|
|
|
+ "U2NET_PATH",
|
|
|
|
+ os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
|
|
|
+ )
|
|
|
|
+ if (
|
|
|
|
+ not os.path.exists(path)
|
|
|
|
+ or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
|
|
|
|
+ ):
|
|
|
|
+ download_file_from_google_drive(
|
|
|
|
+ "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
|
|
|
|
+ "u2net_human_seg.pth",
|
|
|
|
+ path,
|
|
|
|
+ )
|
|
else:
|
|
else:
|
|
- print("Choose between u2net or u2netp", file=sys.stderr)
|
|
|
|
|
|
+ print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
|
|
|
|
|
|
try:
|
|
try:
|
|
if torch.cuda.is_available():
|
|
if torch.cuda.is_available():
|