2
0
Daniel Gatis 3 жил өмнө
parent
commit
521e49a18a
1 өөрчлөгдсөн 12 нэмэгдсэн , 11 устгасан
  1. 12 11
      rembg/detect.py

+ 12 - 11
rembg/detect.py

@@ -1,6 +1,8 @@
+import hashlib
 import os
 import sys
 from contextlib import redirect_stdout
+from pathlib import Path
 
 import gdown
 import numpy as np
@@ -10,11 +12,6 @@ from skimage import transform
 
 
 def ort_session(model_name: str) -> ort.InferenceSession:
-    path = os.environ.get(
-        "U2NETP_PATH",
-        os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
-    )
-
     if model_name == "u2netp":
         md5 = "8e83ca70e441ab06c318d82300c84806"
         url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
@@ -27,18 +24,21 @@ def ort_session(model_name: str) -> ort.InferenceSession:
     else:
         assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
 
-    with redirect_stdout(sys.stderr):
-        gdown.cached_download(url, path, md5=md5)
+    home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
+    path = Path(home).expanduser() / f"{model_name}.onnx"
+    path.parents[0].mkdir(parents=True, exist_ok=True)
+
+    if not (path.exists() and hashlib.md5(path.read_bytes()).hexdigest() == md5):
+        with redirect_stdout(sys.stderr):
+            gdown.download(url, str(path), use_cookies=False)
 
-    return ort.InferenceSession(path, providers=ort.get_available_providers())
+    return ort.InferenceSession(str(path), providers=ort.get_available_providers())
 
 
 def norm_pred(d: np.ndarray) -> np.ndarray:
     ma = np.max(d)
     mi = np.min(d)
-    dn = (d - mi) / (ma - mi)
-
-    return dn
+    return (d - mi) / (ma - mi)
 
 
 def rescale(sample: dict, output_size: int) -> dict:
@@ -80,6 +80,7 @@ def color(sample: dict) -> dict:
 
     tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
     image = image / np.max(image)
+
     if image.shape[2] == 1:
         tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
         tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229