|
@@ -23,13 +23,7 @@ class U2netCustomSession(BaseSession):
|
|
|
if model_path is None:
|
|
|
raise ValueError("model_path is required")
|
|
|
|
|
|
- super().__init__(
|
|
|
- model_name,
|
|
|
- sess_opts,
|
|
|
- providers,
|
|
|
- *args,
|
|
|
- **kwargs
|
|
|
- )
|
|
|
+ super().__init__(model_name, sess_opts, providers, *args, **kwargs)
|
|
|
|
|
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
|
|
ort_outs = self.inner_session.run(
|