|
@@ -27,11 +27,15 @@ def ort_session(model_name: str) -> ort.InferenceSession:
|
|
home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
|
|
home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
|
|
path = Path(home).expanduser() / f"{model_name}.onnx"
|
|
path = Path(home).expanduser() / f"{model_name}.onnx"
|
|
path.parents[0].mkdir(parents=True, exist_ok=True)
|
|
path.parents[0].mkdir(parents=True, exist_ok=True)
|
|
- if path.exists():
|
|
|
|
- hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
|
|
|
|
- if not path.exists() or hashing.hexdigest() != md5:
|
|
|
|
|
|
+
|
|
|
|
+ if not path.exists():
|
|
with redirect_stdout(sys.stderr):
|
|
with redirect_stdout(sys.stderr):
|
|
gdown.download(url, str(path), use_cookies=False)
|
|
gdown.download(url, str(path), use_cookies=False)
|
|
|
|
+ else:
|
|
|
|
+ hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
|
|
|
|
+ if hashing.hexdigest() != md5:
|
|
|
|
+ with redirect_stdout(sys.stderr):
|
|
|
|
+ gdown.download(url, str(path), use_cookies=False)
|
|
|
|
|
|
return ort.InferenceSession(str(path), providers=ort.get_available_providers())
|
|
return ort.InferenceSession(str(path), providers=ort.get_available_providers())
|
|
|
|
|