|
@@ -1,6 +1,8 @@
|
|
|
|
+import hashlib
|
|
import os
|
|
import os
|
|
import sys
|
|
import sys
|
|
from contextlib import redirect_stdout
|
|
from contextlib import redirect_stdout
|
|
|
|
+from pathlib import Path
|
|
|
|
|
|
import gdown
|
|
import gdown
|
|
import numpy as np
|
|
import numpy as np
|
|
@@ -10,11 +12,6 @@ from skimage import transform
|
|
|
|
|
|
|
|
|
|
def ort_session(model_name: str) -> ort.InferenceSession:
|
|
def ort_session(model_name: str) -> ort.InferenceSession:
|
|
- path = os.environ.get(
|
|
|
|
- "U2NETP_PATH",
|
|
|
|
- os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
if model_name == "u2netp":
|
|
if model_name == "u2netp":
|
|
md5 = "8e83ca70e441ab06c318d82300c84806"
|
|
md5 = "8e83ca70e441ab06c318d82300c84806"
|
|
url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
|
|
url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
|
|
@@ -27,18 +24,21 @@ def ort_session(model_name: str) -> ort.InferenceSession:
|
|
else:
|
|
else:
|
|
assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
|
|
assert AssertionError("Choose between u2net, u2netp or u2net_human_seg")
|
|
|
|
|
|
- with redirect_stdout(sys.stderr):
|
|
|
|
- gdown.cached_download(url, path, md5=md5)
|
|
|
|
|
|
+ home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
|
|
|
|
+ path = Path(home).expanduser() / f"{model_name}.onnx"
|
|
|
|
+ path.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
|
+
|
|
|
|
+ if not (path.exists() and hashlib.md5(path.read_bytes()).hexdigest() == md5):
|
|
|
|
+ with redirect_stdout(sys.stderr):
|
|
|
|
+ gdown.download(url, str(path), use_cookies=False)
|
|
|
|
|
|
- return ort.InferenceSession(path, providers=ort.get_available_providers())
|
|
|
|
|
|
+ return ort.InferenceSession(str(path), providers=ort.get_available_providers())
|
|
|
|
|
|
|
|
|
|
def norm_pred(d: np.ndarray) -> np.ndarray:
|
|
def norm_pred(d: np.ndarray) -> np.ndarray:
|
|
ma = np.max(d)
|
|
ma = np.max(d)
|
|
mi = np.min(d)
|
|
mi = np.min(d)
|
|
- dn = (d - mi) / (ma - mi)
|
|
|
|
-
|
|
|
|
- return dn
|
|
|
|
|
|
+ return (d - mi) / (ma - mi)
|
|
|
|
|
|
|
|
|
|
def rescale(sample: dict, output_size: int) -> dict:
|
|
def rescale(sample: dict, output_size: int) -> dict:
|
|
@@ -80,6 +80,7 @@ def color(sample: dict) -> dict:
|
|
|
|
|
|
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
|
|
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
|
|
image = image / np.max(image)
|
|
image = image / np.max(image)
|
|
|
|
+
|
|
if image.shape[2] == 1:
|
|
if image.shape[2] == 1:
|
|
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
|
|
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
|
|
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
|
|
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
|