Jelajahi Sumber

a fix for number of threads in cluster infrastructure

runa91 3 tahun lalu
induk
melakukan
17340088d4
1 mengubah file dengan 6 tambahan dan 2 penghapusan
  1. 6 2
      rembg/detect.py

+ 6 - 2
rembg/detect.py

@@ -36,8 +36,12 @@ def ort_session(model_name: str) -> ort.InferenceSession:
         if hashing.hexdigest() != md5:
         if hashing.hexdigest() != md5:
             with redirect_stdout(sys.stderr):
             with redirect_stdout(sys.stderr):
                 gdown.download(url, str(path), use_cookies=False)
                 gdown.download(url, str(path), use_cookies=False)
-
-    return ort.InferenceSession(str(path), providers=ort.get_available_providers())
+                
+    sess_opts = ort.SessionOptions()
+    sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])  
+    sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"]) 
+    
+    return ort.InferenceSession(str(path), providers=ort.get_available_providers(), sess_options=sess_opts)
 
 
 
 
 def norm_pred(d: np.ndarray) -> np.ndarray:
 def norm_pred(d: np.ndarray) -> np.ndarray: