浏览代码

Make sam model respect providers argument (#634)

James Alsop 1 年之前
父节点
当前提交
a6a94a4bad
共有 1 个文件被更改,包括 21 次插入3 次删除
  1. 21 3
      rembg/sessions/sam.py

+ 21 - 3
rembg/sessions/sam.py

@@ -83,7 +83,14 @@ class SamSession(BaseSession):
         **kwargs: Arbitrary keyword arguments.
     """
 
-    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
+    def __init__(
+        self,
+        model_name: str,
+        sess_opts: ort.SessionOptions,
+        providers=None,
+        *args,
+        **kwargs,
+    ):
         """
         Initialize a new SamSession with the given model name and session options.
 
@@ -94,15 +101,26 @@ class SamSession(BaseSession):
             **kwargs: Arbitrary keyword arguments.
         """
         self.model_name = model_name
+
+        self.providers = []
+
+        _providers = ort.get_available_providers()
+        if providers:
+            for provider in providers:
+                if provider in _providers:
+                    self.providers.append(provider)
+        else:
+            self.providers.extend(_providers)
+
         paths = self.__class__.download_models(*args, **kwargs)
         self.encoder = ort.InferenceSession(
             str(paths[0]),
-            providers=ort.get_available_providers(),
+            providers=self.providers,
             sess_options=sess_opts,
         )
         self.decoder = ort.InferenceSession(
             str(paths[1]),
-            providers=ort.get_available_providers(),
+            providers=self.providers,
             sess_options=sess_opts,
         )