Daniel Gatis 3 ani în urmă
părinte
comite
e67c6407dd
1 a modificat fișierele cu 9 adăugiri și 8 ștergeri
  1. 9 8
      rembg/detect.py

+ 9 - 8
rembg/detect.py

@@ -36,15 +36,16 @@ def ort_session(model_name: str) -> ort.InferenceSession:
         if hashing.hexdigest() != md5:
             with redirect_stdout(sys.stderr):
                 gdown.download(url, str(path), use_cookies=False)
-                
+
+    sess_opts = ort.SessionOptions()
+
     if "OMP_NUM_THREADS" in os.environ:
-        sess_opts = ort.SessionOptions()
-        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])  
-        sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])  
-    else:
-        sess_opts=None
-    
-    return ort.InferenceSession(str(path), providers=ort.get_available_providers(), sess_options=sess_opts)
+        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+        sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+
+    return ort.InferenceSession(
+        str(path), providers=ort.get_available_providers(), sess_options=sess_opts
+    )
 
 
 def norm_pred(d: np.ndarray) -> np.ndarray: