|
@@ -10,6 +10,26 @@ from .base import BaseSession
|
|
|
|
|
|
|
|
|
class U2netCustomSession(BaseSession):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ model_name: str,
|
|
|
+ sess_opts: ort.SessionOptions,
|
|
|
+ providers=None,
|
|
|
+ *args,
|
|
|
+ **kwargs
|
|
|
+ ):
|
|
|
+ model_path = kwargs.get("model_path")
|
|
|
+ if model_path is None:
|
|
|
+ raise ValueError("model_path is required")
|
|
|
+
|
|
|
+ super().__init__(
|
|
|
+ model_name,
|
|
|
+ sess_opts,
|
|
|
+ providers,
|
|
|
+ *args,
|
|
|
+ **kwargs
|
|
|
+ )
|
|
|
+
|
|
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
|
|
ort_outs = self.inner_session.run(
|
|
|
None,
|