Parcourir la source

Merge pull request #742 from pgr0ss/default-download-only-u2net

Only download u2net model in docker image
Daniel Gatis il y a 5 mois
Parent
commit
860b39d000
4 fichiers modifiés avec 44 ajouts et 47 suppressions
  1. 1 1
      Dockerfile
  2. 16 5
      rembg/bg.py
  3. 5 4
      rembg/commands/d_command.py
  4. 22 37
      rembg/sessions/__init__.py

+ 1 - 1
Dockerfile

@@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/li
 COPY . .
 
 RUN python -m pip install ".[cpu,cli]"
-RUN rembg d
+RUN rembg d u2net
 
 EXPOSE 7000
 ENTRYPOINT ["rembg"]

+ 16 - 5
rembg/bg.py

@@ -1,4 +1,5 @@
 import io
+import sys
 from enum import Enum
 from typing import Any, List, Optional, Tuple, Union, cast
 
@@ -20,7 +21,7 @@ from pymatting.util.util import stack_images
 from scipy.ndimage import binary_erosion
 
 from .session_factory import new_session
-from .sessions import sessions_class
+from .sessions import sessions, sessions_names
 from .sessions.base import BaseSession
 
 ort.set_default_logger_severity(3)
@@ -194,12 +195,22 @@ def fix_image_orientation(img: PILImage) -> PILImage:
     return cast(PILImage, ImageOps.exif_transpose(img))
 
 
-def download_models() -> None:
+def download_models(models: tuple[str, ...]) -> None:
     """
     Download models for image processing.
     """
-    for session in sessions_class:
-        session.download_models()
+    if len(models) == 0:
+        print("No models specified, downloading all models")
+        models = tuple(sessions_names)
+
+    for model in models:
+        session = sessions.get(model)
+        if session is None:
+            print(f"Error: no model found: {model}")
+            sys.exit(1)
+        else:
+            print(f"Downloading model: {model}")
+            session.download_models()
 
 
 def remove(
@@ -214,7 +225,7 @@ def remove(
     bgcolor: Optional[Tuple[int, int, int, int]] = None,
     force_return_bytes: bool = False,
     *args: Optional[Any],
-    **kwargs: Optional[Any]
+    **kwargs: Optional[Any],
 ) -> Union[bytes, PILImage, np.ndarray]:
     """
     Remove the background from an input image.

+ 5 - 4
rembg/commands/d_command.py

@@ -5,10 +5,11 @@ from ..bg import download_models
 
 @click.command(  # type: ignore
     name="d",
-    help="download all models",
+    help="download models",
 )
-def d_command(*args, **kwargs) -> None:
[email protected]("models", nargs=-1)
+def d_command(models: tuple[str, ...]) -> None:
     """
-    Download all models
+    Download models
     """
-    download_models()
+    download_models(models)

+ 22 - 37
rembg/sessions/__init__.py

@@ -1,93 +1,78 @@
 from __future__ import annotations
 
-from typing import List
+from typing import Dict, List
 
 from .base import BaseSession
 
-sessions_class: List[type[BaseSession]] = []
-sessions_names: List[str] = []
+sessions: Dict[str, type[BaseSession]] = {}
 
 from .birefnet_general import BiRefNetSessionGeneral
 
-sessions_class.append(BiRefNetSessionGeneral)
-sessions_names.append(BiRefNetSessionGeneral.name())
+sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral
 
 from .birefnet_general_lite import BiRefNetSessionGeneralLite
 
-sessions_class.append(BiRefNetSessionGeneralLite)
-sessions_names.append(BiRefNetSessionGeneralLite.name())
+sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite
 
 from .birefnet_portrait import BiRefNetSessionPortrait
 
-sessions_class.append(BiRefNetSessionPortrait)
-sessions_names.append(BiRefNetSessionPortrait.name())
+sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait
 
 from .birefnet_dis import BiRefNetSessionDIS
 
-sessions_class.append(BiRefNetSessionDIS)
-sessions_names.append(BiRefNetSessionDIS.name())
+sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS
 
 from .birefnet_hrsod import BiRefNetSessionHRSOD
 
-sessions_class.append(BiRefNetSessionHRSOD)
-sessions_names.append(BiRefNetSessionHRSOD.name())
+sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD
 
 from .birefnet_cod import BiRefNetSessionCOD
 
-sessions_class.append(BiRefNetSessionCOD)
-sessions_names.append(BiRefNetSessionCOD.name())
+sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD
 
 from .birefnet_massive import BiRefNetSessionMassive
 
-sessions_class.append(BiRefNetSessionMassive)
-sessions_names.append(BiRefNetSessionMassive.name())
+sessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive
 
 from .dis_anime import DisSession
 
-sessions_class.append(DisSession)
-sessions_names.append(DisSession.name())
+sessions[DisSession.name()] = DisSession
 
 from .dis_general_use import DisSession as DisSessionGeneralUse
 
-sessions_class.append(DisSessionGeneralUse)
-sessions_names.append(DisSessionGeneralUse.name())
+sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
 
 from .sam import SamSession
 
-sessions_class.append(SamSession)
-sessions_names.append(SamSession.name())
+sessions[SamSession.name()] = SamSession
 
 from .silueta import SiluetaSession
 
-sessions_class.append(SiluetaSession)
-sessions_names.append(SiluetaSession.name())
+sessions[SiluetaSession.name()] = SiluetaSession
 
 from .u2net_cloth_seg import Unet2ClothSession
 
-sessions_class.append(Unet2ClothSession)
-sessions_names.append(Unet2ClothSession.name())
+sessions[Unet2ClothSession.name()] = Unet2ClothSession
 
 from .u2net_custom import U2netCustomSession
 
-sessions_class.append(U2netCustomSession)
-sessions_names.append(U2netCustomSession.name())
+sessions[U2netCustomSession.name()] = U2netCustomSession
 
 from .u2net_human_seg import U2netHumanSegSession
 
-sessions_class.append(U2netHumanSegSession)
-sessions_names.append(U2netHumanSegSession.name())
+sessions[U2netHumanSegSession.name()] = U2netHumanSegSession
 
 from .u2net import U2netSession
 
-sessions_class.append(U2netSession)
-sessions_names.append(U2netSession.name())
+sessions[U2netSession.name()] = U2netSession
 
 from .u2netp import U2netpSession
 
-sessions_class.append(U2netpSession)
-sessions_names.append(U2netpSession.name())
+sessions[U2netpSession.name()] = U2netpSession
 
 from .bria_rmbg import BriaRmBgSession
 
-sessions_class.append(BriaRmBgSession)
-sessions_names.append(BriaRmBgSession.name())
+sessions[BriaRmBgSession.name()] = BriaRmBgSession
+
+sessions_names = list(sessions.keys())
+sessions_class = list(sessions.values())