浏览代码

add providers member to BaseSession

MCYBA 2 年之前
父节点
当前提交
eb1796898f
共有 1 个文件被更改,包括 16 次插入2 次删除
  1. 16 2
      rembg/sessions/base.py

+ 16 - 2
rembg/sessions/base.py

@@ -8,11 +8,25 @@ from PIL.Image import Image as PILImage
 
 
 class BaseSession:
-    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, providers=None, *args, **kwargs):
         self.model_name = model_name
+        
+        self.providers = []
+        
+        _providers = ort.get_available_providers()
+        if providers:
+            for provider in providers:
+                if provider in _providers:
+                    self.providers.append(provider)
+        else:
+            self.providers.extend(_providers)
+            
+        
+        self.providers=
+        
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models()),
-            providers=ort.get_available_providers(),
+            providers=self.providers,
             sess_options=sess_opts,
         )