|
@@ -36,15 +36,16 @@ def ort_session(model_name: str) -> ort.InferenceSession:
|
|
|
if hashing.hexdigest() != md5:
|
|
|
with redirect_stdout(sys.stderr):
|
|
|
gdown.download(url, str(path), use_cookies=False)
|
|
|
-
|
|
|
+
|
|
|
+ sess_opts = ort.SessionOptions()
|
|
|
+
|
|
|
if "OMP_NUM_THREADS" in os.environ:
|
|
|
- sess_opts = ort.SessionOptions()
|
|
|
- sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
|
|
- sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
|
|
- else:
|
|
|
- sess_opts=None
|
|
|
-
|
|
|
- return ort.InferenceSession(str(path), providers=ort.get_available_providers(), sess_options=sess_opts)
|
|
|
+ sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
|
|
+ sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
|
|
+
|
|
|
+ return ort.InferenceSession(
|
|
|
+ str(path), providers=ort.get_available_providers(), sess_options=sess_opts
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def norm_pred(d: np.ndarray) -> np.ndarray:
|