Explorar o código

Fixed download pretrained model file

iory %!s(int64=3) %!d(string=hai) anos
pai
achega
b9c0deb4a3
Modificáronse 1 ficheiros con 3 adicións e 3 borrados
  1. 3 3
      rembg/detect.py

+ 3 - 3
rembg/detect.py

@@ -27,9 +27,9 @@ def ort_session(model_name: str) -> ort.InferenceSession:
     home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
     path = Path(home).expanduser() / f"{model_name}.onnx"
     path.parents[0].mkdir(parents=True, exist_ok=True)
-    hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
-
-    if not (path.exists() and hashing.hexdigest() == md5):
+    if path.exists():
+        hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
+    if not path.exists() or hashing.hexdigest() != md5:
         with redirect_stdout(sys.stderr):
             gdown.download(url, str(path), use_cookies=False)