|
@@ -13,21 +13,13 @@ from .base import BaseSession
|
|
|
class U2netCustomSession(BaseSession):
|
|
|
"""This is a class representing a custom session for the U2net model."""
|
|
|
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- model_name: str,
|
|
|
- sess_opts: ort.SessionOptions,
|
|
|
- providers=None,
|
|
|
- *args,
|
|
|
- **kwargs
|
|
|
- ):
|
|
|
+ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
|
|
"""
|
|
|
Initialize a new U2netCustomSession object.
|
|
|
|
|
|
Parameters:
|
|
|
model_name (str): The name of the model.
|
|
|
sess_opts (ort.SessionOptions): The session options.
|
|
|
- providers: The providers.
|
|
|
*args: Additional positional arguments.
|
|
|
**kwargs: Additional keyword arguments.
|
|
|
|
|
@@ -38,7 +30,7 @@ class U2netCustomSession(BaseSession):
|
|
|
if model_path is None:
|
|
|
raise ValueError("model_path is required")
|
|
|
|
|
|
- super().__init__(model_name, sess_opts, providers, *args, **kwargs)
|
|
|
+ super().__init__(model_name, sess_opts, *args, **kwargs)
|
|
|
|
|
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
|
|
"""
|