فهرست منبع

ensure CUDAExecutionProvider is called in session when using nvidia gpu.

catscarlet 6 ماه پیش
والد
کامیت
aa6fb76e0f
1فایلهای تغییر یافته به همراه8 افزوده شده و 0 حذف شده
  1. 8 0
      rembg/sessions/base.py

+ 8 - 0
rembg/sessions/base.py

@@ -13,9 +13,17 @@ class BaseSession:
     def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
         """Initialize an instance of the BaseSession class."""
         self.model_name = model_name
+
+        device_type = ort.get_device()
+        if device_type == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
+            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        else:
+            providers = ['CPUExecutionProvider']
+
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models(*args, **kwargs)),
             sess_options=sess_opts,
+            providers=providers,
         )
 
     def normalize(