Procházet zdrojové kódy

Merge pull request #768 from fruitful-ai/main

Daniel Gatis před 1 měsícem
rodič
revize
41511fc69a
1 změnil soubory, kde provedl 16 přidání a 13 odebrání
  1. 16 13
      rembg/sessions/base.py

+ 16 - 13
rembg/sessions/base.py

@@ -13,20 +13,23 @@ class BaseSession:
     def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
         """Initialize an instance of the BaseSession class."""
         self.model_name = model_name
-
-        device_type = ort.get_device()
-        if (
-            device_type == "GPU"
-            and "CUDAExecutionProvider" in ort.get_available_providers()
-        ):
-            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
-        elif (
-            device_type[0:3] == "GPU"
-            and "ROCMExecutionProvider" in ort.get_available_providers()
-        ):
-            providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
+        
+        if "providers" in kwargs and isinstance(kwargs["providers"], list):
+            providers = kwargs.pop("providers")
         else:
-            providers = ["CPUExecutionProvider"]
+            device_type = ort.get_device()
+            if (
+                device_type == "GPU"
+                and "CUDAExecutionProvider" in ort.get_available_providers()
+            ):
+                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+            elif (
+                device_type[0:3] == "GPU"
+                and "ROCMExecutionProvider" in ort.get_available_providers()
+            ):
+                providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
+            else:
+                providers = ["CPUExecutionProvider"]
 
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models(*args, **kwargs)),