|
@@ -58,6 +58,7 @@ class SamSession(BaseSession):
|
|
|
*args: Variable length argument list.
|
|
|
**kwargs: Arbitrary keyword arguments.
|
|
|
"""
|
|
|
+
|
|
|
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
|
|
"""
|
|
|
Initialize a new SamSession with the given model name and session options.
|
|
@@ -181,7 +182,7 @@ class SamSession(BaseSession):
|
|
|
|
|
|
@classmethod
|
|
|
def download_models(cls, *args, **kwargs):
|
|
|
- '''
|
|
|
+ """
|
|
|
Class method to download ONNX model files.
|
|
|
|
|
|
This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
|
|
@@ -193,7 +194,7 @@ class SamSession(BaseSession):
|
|
|
|
|
|
Returns:
|
|
|
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
|
|
|
- '''
|
|
|
+ """
|
|
|
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
|
|
|
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
|
|
|
|
|
@@ -224,7 +225,7 @@ class SamSession(BaseSession):
|
|
|
|
|
|
@classmethod
|
|
|
def name(cls, *args, **kwargs):
|
|
|
- '''
|
|
|
+ """
|
|
|
Class method to return a string value.
|
|
|
|
|
|
This method returns the string value 'sam'.
|
|
@@ -236,5 +237,5 @@ class SamSession(BaseSession):
|
|
|
|
|
|
Returns:
|
|
|
str: The string value 'sam'.
|
|
|
- '''
|
|
|
+ """
|
|
|
return "sam"
|