2
0
Daniel Gatis 2 жил өмнө
parent
commit
b527af6af5

+ 0 - 5
rembg/sessions/base.py

@@ -28,11 +28,6 @@ class BaseSession:
         else:
             self.providers.extend(_providers)
 
-        model_path = kwargs.get("model_path")
-
-        if model_path is None:
-            raise ValueError("model_path is required")
-
         self.inner_session = ort.InferenceSession(
             str(self.__class__.download_models(*args, **kwargs)),
             providers=self.providers,

+ 20 - 0
rembg/sessions/u2net_custom.py

@@ -10,6 +10,26 @@ from .base import BaseSession
 
 
 class U2netCustomSession(BaseSession):
+    def __init__(
+        self,
+        model_name: str,
+        sess_opts: ort.SessionOptions,
+        providers=None,
+        *args,
+        **kwargs
+    ):
+        model_path = kwargs.get("model_path")
+        if model_path is None:
+            raise ValueError("model_path is required")
+
+        super().__init__(
+            model_name,
+            sess_opts,
+            providers,
+            *args,
+            **kwargs
+        )
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         ort_outs = self.inner_session.run(
             None,