Browse Source

Merge pull request #725 from danielgatis/remove-explicity-providers

refactor: remove unused providers parameter from session constructors
Daniel Gatis 6 months ago
parent
commit
9079508935
4 changed files with 5 additions and 48 deletions
  1. 2 5
      rembg/session_factory.py
  2. 1 20
      rembg/sessions/base.py
  3. 0 13
      rembg/sessions/sam.py
  4. 2 10
      rembg/sessions/u2net_custom.py

+ 2 - 5
rembg/session_factory.py

@@ -8,9 +8,7 @@ from .sessions.base import BaseSession
 from .sessions.u2net import U2netSession
 
 
-def new_session(
-    model_name: str = "u2net", providers=None, *args, **kwargs
-) -> BaseSession:
+def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
     """
     Create a new session object based on the specified model name.
 
@@ -21,7 +19,6 @@ def new_session(
 
     Parameters:
         model_name (str): The name of the model.
-        providers: The providers for the session.
         *args: Additional positional arguments.
         **kwargs: Additional keyword arguments.
 
@@ -41,4 +38,4 @@ def new_session(
         sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
         sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
 
-    return session_class(model_name, sess_opts, providers, *args, **kwargs)
+    return session_class(model_name, sess_opts, *args, **kwargs)

+ 1 - 20
rembg/sessions/base.py

@@ -10,30 +10,11 @@ from PIL.Image import Image as PILImage
 class BaseSession:
     """This is a base class for managing a session with a machine learning model."""
 
-    def __init__(
-        self,
-        model_name: str,
-        sess_opts: ort.SessionOptions,
-        providers=None,
-        *args,
-        **kwargs
-    ):
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
         """Initialize an instance of the BaseSession class."""
         self.model_name = model_name
-
-        self.providers = []
-
-        _providers = ort.get_available_providers()
-        if providers:
-            for provider in providers:
-                if provider in _providers:
-                    self.providers.append(provider)
-        else:
-            self.providers.extend(_providers)
-
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models(*args, **kwargs)),
-            providers=self.providers,
             sess_options=sess_opts,
         )
 

+ 0 - 13
rembg/sessions/sam.py

@@ -87,7 +87,6 @@ class SamSession(BaseSession):
         self,
         model_name: str,
         sess_opts: ort.SessionOptions,
-        providers=None,
         *args,
         **kwargs,
     ):
@@ -102,25 +101,13 @@ class SamSession(BaseSession):
         """
         self.model_name = model_name
 
-        valid_providers = []
-        available_providers = ort.get_available_providers()
-
-        if providers:
-            for provider in providers or []:
-                if provider in available_providers:
-                    valid_providers.append(provider)
-        else:
-            valid_providers.extend(available_providers)
-
         paths = self.__class__.download_models(*args, **kwargs)
         self.encoder = ort.InferenceSession(
             str(paths[0]),
-            providers=valid_providers,
             sess_options=sess_opts,
         )
         self.decoder = ort.InferenceSession(
             str(paths[1]),
-            providers=valid_providers,
             sess_options=sess_opts,
         )
 

+ 2 - 10
rembg/sessions/u2net_custom.py

@@ -13,21 +13,13 @@ from .base import BaseSession
 class U2netCustomSession(BaseSession):
     """This is a class representing a custom session for the U2net model."""
 
-    def __init__(
-        self,
-        model_name: str,
-        sess_opts: ort.SessionOptions,
-        providers=None,
-        *args,
-        **kwargs
-    ):
+    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
         """
         Initialize a new U2netCustomSession object.
 
         Parameters:
             model_name (str): The name of the model.
             sess_opts (ort.SessionOptions): The session options.
-            providers: The providers.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
 
@@ -38,7 +30,7 @@ class U2netCustomSession(BaseSession):
         if model_path is None:
             raise ValueError("model_path is required")
 
-        super().__init__(model_name, sess_opts, providers, *args, **kwargs)
+        super().__init__(model_name, sess_opts, *args, **kwargs)
 
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         """