|
@@ -8,6 +8,7 @@ from PIL.Image import Image as PILImage
|
|
|
|
|
|
from .base import BaseSession
|
|
|
|
|
|
+
|
|
|
class DisCustomSession(BaseSession):
|
|
|
"""This is a class representing a custom session for the Dis model."""
|
|
|
|
|
@@ -27,7 +28,6 @@ class DisCustomSession(BaseSession):
|
|
|
|
|
|
super().__init__(model_name, sess_opts, *args, **kwargs)
|
|
|
|
|
|
-
|
|
|
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
|
|
"""
|
|
|
Predicts the mask image for the input image.
|