Daniel Gatis 6 mesi fa
parent
commit
66ffce4894
1 ha cambiato i file con 6 aggiunte e 3 eliminazioni
  1. 6 3
      rembg/sessions/base.py

+ 6 - 3
rembg/sessions/base.py

@@ -15,10 +15,13 @@ class BaseSession:
         self.model_name = model_name
         self.model_name = model_name
 
 
         device_type = ort.get_device()
         device_type = ort.get_device()
-        if device_type == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
-            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        if (
+            device_type == "GPU"
+            and "CUDAExecutionProvider" in ort.get_available_providers()
+        ):
+            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
         else:
         else:
-            providers = ['CPUExecutionProvider']
+            providers = ["CPUExecutionProvider"]
 
 
         self.inner_session = ort.InferenceSession(
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models(*args, **kwargs)),
             str(self.__class__.download_models(*args, **kwargs)),