Browse Source

fix linters

Daniel Gatis 1 year ago
parent
commit
5c65374c9e

+ 1 - 0
rembg/sessions/base.py

@@ -9,6 +9,7 @@ from PIL.Image import Image as PILImage
 
 class BaseSession:
     """This is a base class for managing a session with a machine learning model."""
+
     def __init__(
         self,
         model_name: str,

+ 1 - 0
rembg/sessions/dis_anime.py

@@ -13,6 +13,7 @@ class DisSession(BaseSession):
     """
     This class represents a session for object detection.
     """
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         """
         Use a pre-trained model to predict the object in the given image.

+ 5 - 4
rembg/sessions/sam.py

@@ -58,6 +58,7 @@ class SamSession(BaseSession):
         *args: Variable length argument list.
         **kwargs: Arbitrary keyword arguments.
     """
+
     def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
         """
         Initialize a new SamSession with the given model name and session options.
@@ -181,7 +182,7 @@ class SamSession(BaseSession):
 
     @classmethod
     def download_models(cls, *args, **kwargs):
-        '''
+        """
         Class method to download ONNX model files.
 
         This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
@@ -193,7 +194,7 @@ class SamSession(BaseSession):
 
         Returns:
             tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
-        '''
+        """
         fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
         fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
 
@@ -224,7 +225,7 @@ class SamSession(BaseSession):
 
     @classmethod
     def name(cls, *args, **kwargs):
-        '''
+        """
         Class method to return a string value.
 
         This method returns the string value 'sam'.
@@ -236,5 +237,5 @@ class SamSession(BaseSession):
 
         Returns:
             str: The string value 'sam'.
-        '''
+        """
         return "sam"

+ 1 - 0
rembg/sessions/silueta.py

@@ -11,6 +11,7 @@ from .base import BaseSession
 
 class SiluetaSession(BaseSession):
     """This is a class representing a SiluetaSession object."""
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         """
         Predict the mask of the input image.

+ 1 - 0
rembg/sessions/u2net.py

@@ -13,6 +13,7 @@ class U2netSession(BaseSession):
     """
     This class represents a U2net session, which is a subclass of BaseSession.
     """
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         """
         Predicts the output masks for the input image using the inner session.

+ 1 - 0
rembg/sessions/u2net_custom.py

@@ -12,6 +12,7 @@ from .base import BaseSession
 
 class U2netCustomSession(BaseSession):
     """This is a class representing a custom session for the U2net model."""
+
     def __init__(
         self,
         model_name: str,

+ 1 - 0
rembg/sessions/u2net_human_seg.py

@@ -13,6 +13,7 @@ class U2netHumanSegSession(BaseSession):
     """
     This class represents a session for performing human segmentation using the U2Net model.
     """
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         """
         Predicts human segmentation masks for the input image.

+ 1 - 0
rembg/sessions/u2netp.py

@@ -11,6 +11,7 @@ from .base import BaseSession
 
 class U2netpSession(BaseSession):
     """This class represents a session for using the U2netp model."""
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
         """
         Predicts the mask for the given image using the U2netp model.