Selaa lähdekoodia

Merge pull request #198 from iory/pretrainedmodel

Fixed downloading pretrained model file
Daniel Gatis 3 vuotta sitten
vanhempi
commit
a351881de2
1 muutettua tiedostoa jossa 3 lisäystä ja 3 poistoa
  1. 3 3
      rembg/detect.py

+ 3 - 3
rembg/detect.py

@@ -27,9 +27,9 @@ def ort_session(model_name: str) -> ort.InferenceSession:
     home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
     home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
     path = Path(home).expanduser() / f"{model_name}.onnx"
     path = Path(home).expanduser() / f"{model_name}.onnx"
     path.parents[0].mkdir(parents=True, exist_ok=True)
     path.parents[0].mkdir(parents=True, exist_ok=True)
-    hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
-
-    if not (path.exists() and hashing.hexdigest() == md5):
+    if path.exists():
+        hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
+    if not path.exists() or hashing.hexdigest() != md5:
         with redirect_stdout(sys.stderr):
         with redirect_stdout(sys.stderr):
             gdown.download(url, str(path), use_cookies=False)
             gdown.download(url, str(path), use_cookies=False)