Ver Fonte

handle case that OMP_NUM_THREADS does not exist

runa91 há 3 anos atrás
pai
commit
c0eb38cdc4
1 ficheiros alterados com 6 adições e 3 exclusões
  1. 6 3
      rembg/detect.py

+ 6 - 3
rembg/detect.py

@@ -37,9 +37,12 @@ def ort_session(model_name: str) -> ort.InferenceSession:
             with redirect_stdout(sys.stderr):
                 gdown.download(url, str(path), use_cookies=False)
                 
-    sess_opts = ort.SessionOptions()
-    sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])  
-    sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) 
+    if "OMP_NUM_THREADS" in os.environ:
+        sess_opts = ort.SessionOptions()
+        sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])  
+        sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])  
+    else:
+        sess_opts=None
     
     return ort.InferenceSession(str(path), providers=ort.get_available_providers(), sess_options=sess_opts)