2
0
Daniel Gatis 2 жил өмнө
parent
commit
41409875cd

+ 3 - 1
rembg/session_factory.py

@@ -8,7 +8,9 @@ from .sessions.base import BaseSession
 from .sessions.u2net import U2netSession
 
 
-def new_session(model_name: str = "u2net", providers=None, *args, **kwargs) -> BaseSession:
+def new_session(
+    model_name: str = "u2net", providers=None, *args, **kwargs
+) -> BaseSession:
     session_class: Type[BaseSession] = U2netSession
 
     for sc in sessions_class:

+ 11 - 4
rembg/sessions/base.py

@@ -8,11 +8,18 @@ from PIL.Image import Image as PILImage
 
 
 class BaseSession:
-    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, providers=None, *args, **kwargs):
+    def __init__(
+        self,
+        model_name: str,
+        sess_opts: ort.SessionOptions,
+        providers=None,
+        *args,
+        **kwargs
+    ):
         self.model_name = model_name
-        
+
         self.providers = []
-        
+
         _providers = ort.get_available_providers()
         if providers:
             for provider in providers:
@@ -20,7 +27,7 @@ class BaseSession:
                     self.providers.append(provider)
         else:
             self.providers.extend(_providers)
-            
+
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models()),
             providers=self.providers,