Daniel Gatis hace 1 año
padre
commit
4beb98dbc7

+ 1 - 1
.github/workflows/lint_python.yml

@@ -15,5 +15,5 @@ jobs:
             - run: mypy --install-types --non-interactive --ignore-missing-imports ./rembg
             - run: bandit --recursive --skip B101,B104,B310,B311,B303,B110 --exclude ./rembg/_version.py ./rembg
             - run: black --force-exclude rembg/_version.py --check --diff ./rembg
-            - run: flake8 ./rembg --count --ignore=B008,C901,E203,E266,E731,F401,F811,F841,W503,E501 --show-source --statistics --exclude ./rembg/_version.py
+            - run: flake8 ./rembg --count --ignore=B008,C901,E203,E266,E731,F401,F811,F841,W503,E501,E402 --show-source --statistics --exclude ./rembg/_version.py
             - run: isort --check-only --profile black ./rembg

+ 0 - 1
.gitignore

@@ -3,7 +3,6 @@ build/
 dist/
 .venv/
 .direnv/
-*.spec
 *.egg-info/
 *.egg
 *.py[cod]

+ 2 - 2
Dockerfile

@@ -7,8 +7,8 @@ RUN pip install --upgrade pip
 COPY . .
 
 RUN python -m pip install ".[cli]"
-RUN python -c 'from rembg.bg import download_models; download_models()'
+RUN rembg d
 
-EXPOSE 5000
+EXPOSE 7000
 ENTRYPOINT ["rembg"]
 CMD ["--help"]

+ 4 - 4
README.md

@@ -191,21 +191,21 @@ rembg p -w path/to/input path/to/output
 Used to start http server.
 
 ```
-rembg s --host 0.0.0.0 --port 5000 --log_level info
+rembg s --host 0.0.0.0 --port 7000 --log_level info
 ```
 
-To see the complete endpoints documentation, go to: `http://localhost:5000/api`.
+To see the complete endpoints documentation, go to: `http://localhost:7000/api`.
 
 Remove the background from an image url
 
 ```
-curl -s "http://localhost:5000/api/remove?url=http://input.png" -o output.png
+curl -s "http://localhost:7000/api/remove?url=http://input.png" -o output.png
 ```
 
 Remove the background from an uploaded image
 
 ```
-curl -s -F file=@/path/to/input.jpg "http://localhost:5000/api/remove"  -o output.png
+curl -s -F file=@/path/to/input.jpg "http://localhost:7000/api/remove"  -o output.png
 ```
 
 ### rembg `b`

+ 0 - 3
build-exe

@@ -1,3 +0,0 @@
-#!/bin/sh
-
-pyinstaller -y -p ./rembg rembg.py

+ 8 - 0
build-exe.ps1

@@ -0,0 +1,8 @@
+# Install required packages
+# pip install -e ".[cli]"
+
+# Create PyInstaller spec file with specified data collections
+# pyi-makespec --collect-data=gradio_client --collect-data=gradio rembg.py
+
+# Run PyInstaller with the generated spec file
+pyinstaller rembg.spec

+ 51 - 0
rembg.spec

@@ -0,0 +1,51 @@
+# -*- mode: python ; coding: utf-8 -*-
+from PyInstaller.utils.hooks import collect_data_files
+
+datas = []
+datas += collect_data_files('gradio_client')
+datas += collect_data_files('gradio')
+
+
+a = Analysis(
+    ['rembg.py'],
+    pathex=[],
+    binaries=[],
+    datas=datas,
+    hiddenimports=[],
+    hookspath=[],
+    hooksconfig={},
+    runtime_hooks=[],
+    excludes=[],
+    noarchive=False,
+    module_collection_mode={
+        'gradio': 'py',
+    },
+)
+pyz = PYZ(a.pure)
+
+exe = EXE(
+    pyz,
+    a.scripts,
+    [],
+    exclude_binaries=True,
+    name='rembg',
+    debug=False,
+    bootloader_ignore_signals=False,
+    strip=False,
+    upx=True,
+    console=True,
+    disable_windowed_traceback=False,
+    argv_emulation=False,
+    target_arch=None,
+    codesign_identity=None,
+    entitlements_file=None,
+)
+coll = COLLECT(
+    exe,
+    a.binaries,
+    a.datas,
+    strip=False,
+    upx=True,
+    upx_exclude=[],
+    name='rembg',
+)

+ 10 - 27
rembg/cli.py

@@ -1,33 +1,16 @@
-import pkg_resources
+import click
 
+from . import _version
+from .commands import command_functions
 
-def main() -> None:
-    package_distribution = pkg_resources.get_distribution("rembg")
 
-    for extra in package_distribution.extras:
-        if extra == "cli":
-            requirements = package_distribution.requires(extras=(extra,))
-            for requirement in requirements:
-                try:
-                    pkg_resources.require(requirement.project_name)
-                except pkg_resources.DistributionNotFound:
-                    print(f"Missing dependency: '{requirement.project_name}'")
-                    print(
-                        "Please, install rembg with the cli feature: pip install rembg[cli]"
-                    )
-                    exit(1)
[email protected]()
[email protected]_option(version=_version.get_versions()["version"])
+def _main() -> None:
+    pass
 
-    import click
 
-    from . import _version
-    from .commands import command_functions
+for command in command_functions:
+    _main.add_command(command)
 
-    @click.group()  # type: ignore
-    @click.version_option(version=_version.get_versions()["version"])
-    def _main() -> None:
-        pass
-
-    for command in command_functions:
-        _main.add_command(command)  # type: ignore
-
-    _main()  # type: ignore
+_main()

+ 11 - 11
rembg/commands/__init__.py

@@ -1,13 +1,13 @@
-from importlib import import_module
-from pathlib import Path
-from pkgutil import iter_modules
-
 command_functions = []
 
-package_dir = Path(__file__).resolve().parent
-for _b, module_name, _p in iter_modules([str(package_dir)]):
-    module = import_module(f"{__name__}.{module_name}")
-    for attribute_name in dir(module):
-        attribute = getattr(module, attribute_name)
-        if attribute_name.endswith("_command"):
-            command_functions.append(attribute)
+from .b_command import b_command
+from .d_command import d_command
+from .i_command import i_command
+from .p_command import p_command
+from .s_command import s_command
+
+command_functions.append(b_command)
+command_functions.append(d_command)
+command_functions.append(i_command)
+command_functions.append(p_command)
+command_functions.append(s_command)

+ 1 - 1
rembg/commands/b_command.py

@@ -94,7 +94,7 @@ from ..sessions import sessions_names
     "image_height",
     type=int,
 )
-def rs_command(
+def b_command(
     model: str,
     extras: str,
     image_width: int,

+ 14 - 0
rembg/commands/d_command.py

@@ -0,0 +1,14 @@
+import click
+
+from ..bg import download_models
+
+
[email protected](  # type: ignore
+    name="d",
+    help="download all models",
+)
+def d_command(*args, **kwargs) -> None:
+    """
+    Download all models
+    """
+    download_models()

+ 1 - 1
rembg/commands/s_command.py

@@ -26,7 +26,7 @@ from ..sessions.base import BaseSession
 @click.option(
     "-p",
     "--port",
-    default=5000,
+    default=7000,
     type=int,
     show_default=True,
     help="port",

+ 50 - 19
rembg/sessions/__init__.py

@@ -1,22 +1,53 @@
-from importlib import import_module
-from inspect import isclass
-from pathlib import Path
-from pkgutil import iter_modules
+from __future__ import annotations
+
+from typing import List
 
 from .base import BaseSession
 
-sessions_class = []
-sessions_names = []
-
-package_dir = Path(__file__).resolve().parent
-for _b, module_name, _p in iter_modules([str(package_dir)]):
-    module = import_module(f"{__name__}.{module_name}")
-    for attribute_name in dir(module):
-        attribute = getattr(module, attribute_name)
-        if (
-            isclass(attribute)
-            and issubclass(attribute, BaseSession)
-            and attribute != BaseSession
-        ):
-            sessions_class.append(attribute)
-            sessions_names.append(attribute.name())
+sessions_class: List[type[BaseSession]] = []
+sessions_names: List[str] = []
+
+from .dis_anime import DisSession
+
+sessions_class.append(DisSession)
+sessions_names.append(DisSession.name())
+
+from .dis_general_use import DisSession as DisSessionGeneralUse
+
+sessions_class.append(DisSessionGeneralUse)
+sessions_names.append(DisSessionGeneralUse.name())
+
+from .sam import SamSession
+
+sessions_class.append(SamSession)
+sessions_names.append(SamSession.name())
+
+from .silueta import SiluetaSession
+
+sessions_class.append(SiluetaSession)
+sessions_names.append(SiluetaSession.name())
+
+from .u2net_cloth_seg import Unet2ClothSession
+
+sessions_class.append(Unet2ClothSession)
+sessions_names.append(Unet2ClothSession.name())
+
+from .u2net_custom import U2netCustomSession
+
+sessions_class.append(U2netCustomSession)
+sessions_names.append(U2netCustomSession.name())
+
+from .u2net_human_seg import U2netHumanSegSession
+
+sessions_class.append(U2netHumanSegSession)
+sessions_names.append(U2netHumanSegSession.name())
+
+from .u2net import U2netSession
+
+sessions_class.append(U2netSession)
+sessions_names.append(U2netSession.name())
+
+from .u2netp import U2netpSession
+
+sessions_class.append(U2netpSession)
+sessions_names.append(U2netpSession.name())