|
@@ -5,7 +5,7 @@ from contextlib import redirect_stdout
|
|
|
from pathlib import Path
|
|
|
from typing import Type
|
|
|
|
|
|
-import gdown
|
|
|
+import pooch
|
|
|
import onnxruntime as ort
|
|
|
|
|
|
from .session_base import BaseSession
|
|
@@ -18,39 +18,40 @@ def new_session(model_name: str) -> BaseSession:
|
|
|
|
|
|
if model_name == "u2netp":
|
|
|
md5 = "8e83ca70e441ab06c318d82300c84806"
|
|
|
- url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
|
|
|
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
|
|
|
session_class = SimpleSession
|
|
|
elif model_name == "u2net":
|
|
|
md5 = "60024c5c889badc19c04ad937298a77b"
|
|
|
- url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
|
|
|
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
|
|
session_class = SimpleSession
|
|
|
elif model_name == "u2net_human_seg":
|
|
|
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
|
|
- url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
|
|
|
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
|
|
|
session_class = SimpleSession
|
|
|
elif model_name == "u2net_cloth_seg":
|
|
|
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
|
|
- url = "https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz"
|
|
|
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
|
|
|
session_class = ClothSession
|
|
|
else:
|
|
|
assert AssertionError(
|
|
|
"Choose between u2net, u2netp, u2net_human_seg or u2net_cloth_seg"
|
|
|
)
|
|
|
|
|
|
- home = os.getenv(
|
|
|
+ u2net_home = os.getenv(
|
|
|
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
|
|
)
|
|
|
- path = Path(home).expanduser() / f"{model_name}.onnx"
|
|
|
- path.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
- if not path.exists():
|
|
|
- with redirect_stdout(sys.stderr):
|
|
|
- gdown.download(url, str(path), use_cookies=False)
|
|
|
- else:
|
|
|
- hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
|
|
|
- if hashing.hexdigest() != md5:
|
|
|
- with redirect_stdout(sys.stderr):
|
|
|
- gdown.download(url, str(path), use_cookies=False)
|
|
|
+ fname = f"{model_name}.onnx"
|
|
|
+ path = Path(u2net_home).expanduser()
|
|
|
+ full_path = Path(u2net_home).expanduser() / fname
|
|
|
+
|
|
|
+ pooch.retrieve(
|
|
|
+ url,
|
|
|
+ f"md5:{md5}",
|
|
|
+ fname=fname,
|
|
|
+ path=Path(u2net_home).expanduser(),
|
|
|
+ progressbar=True
|
|
|
+ )
|
|
|
|
|
|
sess_opts = ort.SessionOptions()
|
|
|
|
|
@@ -60,6 +61,6 @@ def new_session(model_name: str) -> BaseSession:
|
|
|
return session_class(
|
|
|
model_name,
|
|
|
ort.InferenceSession(
|
|
|
- str(path), providers=ort.get_available_providers(), sess_options=sess_opts
|
|
|
+ str(full_path), providers=ort.get_available_providers(), sess_options=sess_opts
|
|
|
),
|
|
|
)
|