2
0
Эх сурвалжийг харах

Merge pull request #456 from MCYBA/main

Add onnxruntime providers selection feature
Daniel Gatis 2 жил өмнө
parent
commit
8b6abef6cd

+ 2 - 2
rembg/session_factory.py

@@ -8,7 +8,7 @@ from .sessions.base import BaseSession
 from .sessions.u2net import U2netSession
 
 
-def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
+def new_session(model_name: str = "u2net", providers=None, *args, **kwargs) -> BaseSession:
     session_class: Type[BaseSession] = U2netSession
 
     for sc in sessions_class:
@@ -21,4 +21,4 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
     if "OMP_NUM_THREADS" in os.environ:
         sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
 
-    return session_class(model_name, sess_opts, *args, **kwargs)
+    return session_class(model_name, sess_opts, providers, *args, **kwargs)

+ 13 - 2
rembg/sessions/base.py

@@ -8,11 +8,22 @@ from PIL.Image import Image as PILImage
 
 
 class BaseSession:
-    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, providers=None, *args, **kwargs):
         self.model_name = model_name
+        
+        self.providers = []
+        
+        _providers = ort.get_available_providers()
+        if providers:
+            for provider in providers:
+                if provider in _providers:
+                    self.providers.append(provider)
+        else:
+            self.providers.extend(_providers)
+            
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models()),
-            providers=ort.get_available_providers(),
+            providers=self.providers,
             sess_options=sess_opts,
         )