|
@@ -1,25 +1,7 @@
|
|
-import pathlib
|
|
|
|
-import sys
|
|
|
|
-import time
|
|
|
|
-from enum import Enum
|
|
|
|
-from typing import IO, Optional, Tuple, cast
|
|
|
|
-
|
|
|
|
-import aiohttp
|
|
|
|
import click
|
|
import click
|
|
-import filetype
|
|
|
|
-import uvicorn
|
|
|
|
-from asyncer import asyncify
|
|
|
|
-from fastapi import Depends, FastAPI, File, Form, Query
|
|
|
|
-from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
-from starlette.responses import Response
|
|
|
|
-from tqdm import tqdm
|
|
|
|
-from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
|
|
|
-from watchdog.observers import Observer
|
|
|
|
|
|
|
|
from . import _version
|
|
from . import _version
|
|
-from .bg import remove
|
|
|
|
-from .session_base import BaseSession
|
|
|
|
-from .session_factory import new_session
|
|
|
|
|
|
+from .commands import command_functions
|
|
|
|
|
|
|
|
|
|
@click.group()
|
|
@click.group()
|
|
@@ -28,457 +10,5 @@ def main() -> None:
|
|
pass
|
|
pass
|
|
|
|
|
|
|
|
|
|
[email protected](help="for a file as input")
|
|
|
|
[email protected](
|
|
|
|
- "-m",
|
|
|
|
- "--model",
|
|
|
|
- default="u2net",
|
|
|
|
- type=click.Choice(
|
|
|
|
- [
|
|
|
|
- "u2net",
|
|
|
|
- "u2netp",
|
|
|
|
- "u2net_human_seg",
|
|
|
|
- "u2net_cloth_seg",
|
|
|
|
- "silueta",
|
|
|
|
- "isnet-general-use",
|
|
|
|
- ]
|
|
|
|
- ),
|
|
|
|
- show_default=True,
|
|
|
|
- show_choices=True,
|
|
|
|
- help="model name",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-a",
|
|
|
|
- "--alpha-matting",
|
|
|
|
- is_flag=True,
|
|
|
|
- show_default=True,
|
|
|
|
- help="use alpha matting",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-af",
|
|
|
|
- "--alpha-matting-foreground-threshold",
|
|
|
|
- default=240,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="trimap fg threshold",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-ab",
|
|
|
|
- "--alpha-matting-background-threshold",
|
|
|
|
- default=10,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="trimap bg threshold",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-ae",
|
|
|
|
- "--alpha-matting-erode-size",
|
|
|
|
- default=10,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="erode size",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-om",
|
|
|
|
- "--only-mask",
|
|
|
|
- is_flag=True,
|
|
|
|
- show_default=True,
|
|
|
|
- help="output only the mask",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-ppm",
|
|
|
|
- "--post-process-mask",
|
|
|
|
- is_flag=True,
|
|
|
|
- show_default=True,
|
|
|
|
- help="post process the mask",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-bgc",
|
|
|
|
- "--bgcolor",
|
|
|
|
- default=None,
|
|
|
|
- type=(int, int, int, int),
|
|
|
|
- nargs=4,
|
|
|
|
- help="Background color (R G B A) to replace the removed background with",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "output",
|
|
|
|
- default=(None if sys.stdin.isatty() else "-"),
|
|
|
|
- type=click.File("wb", lazy=True),
|
|
|
|
-)
|
|
|
|
-def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
|
|
|
- output.write(remove(input.read(), session=new_session(model), **kwargs))
|
|
|
|
-
|
|
|
|
-
|
|
|
|
[email protected](help="for a folder as input")
|
|
|
|
[email protected](
|
|
|
|
- "-m",
|
|
|
|
- "--model",
|
|
|
|
- default="u2net",
|
|
|
|
- type=click.Choice(
|
|
|
|
- [
|
|
|
|
- "u2net",
|
|
|
|
- "u2netp",
|
|
|
|
- "u2net_human_seg",
|
|
|
|
- "u2net_cloth_seg",
|
|
|
|
- "silueta",
|
|
|
|
- "isnet-general-use",
|
|
|
|
- ]
|
|
|
|
- ),
|
|
|
|
- show_default=True,
|
|
|
|
- show_choices=True,
|
|
|
|
- help="model name",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-a",
|
|
|
|
- "--alpha-matting",
|
|
|
|
- is_flag=True,
|
|
|
|
- show_default=True,
|
|
|
|
- help="use alpha matting",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-af",
|
|
|
|
- "--alpha-matting-foreground-threshold",
|
|
|
|
- default=240,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="trimap fg threshold",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-ab",
|
|
|
|
- "--alpha-matting-background-threshold",
|
|
|
|
- default=10,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="trimap bg threshold",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-ae",
|
|
|
|
- "--alpha-matting-erode-size",
|
|
|
|
- default=10,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="erode size",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-om",
|
|
|
|
- "--only-mask",
|
|
|
|
- is_flag=True,
|
|
|
|
- show_default=True,
|
|
|
|
- help="output only the mask",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-ppm",
|
|
|
|
- "--post-process-mask",
|
|
|
|
- is_flag=True,
|
|
|
|
- show_default=True,
|
|
|
|
- help="post process the mask",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-w",
|
|
|
|
- "--watch",
|
|
|
|
- default=False,
|
|
|
|
- is_flag=True,
|
|
|
|
- show_default=True,
|
|
|
|
- help="watches a folder for changes",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-bgc",
|
|
|
|
- "--bgcolor",
|
|
|
|
- default=None,
|
|
|
|
- type=(int, int, int, int),
|
|
|
|
- nargs=4,
|
|
|
|
- help="Background color (R G B A) to replace the removed background with",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "input",
|
|
|
|
- type=click.Path(
|
|
|
|
- exists=True,
|
|
|
|
- path_type=pathlib.Path,
|
|
|
|
- file_okay=False,
|
|
|
|
- dir_okay=True,
|
|
|
|
- readable=True,
|
|
|
|
- ),
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "output",
|
|
|
|
- type=click.Path(
|
|
|
|
- exists=False,
|
|
|
|
- path_type=pathlib.Path,
|
|
|
|
- file_okay=False,
|
|
|
|
- dir_okay=True,
|
|
|
|
- writable=True,
|
|
|
|
- ),
|
|
|
|
-)
|
|
|
|
-def p(
|
|
|
|
- model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
|
|
|
|
-) -> None:
|
|
|
|
- session = new_session(model)
|
|
|
|
-
|
|
|
|
- def process(each_input: pathlib.Path) -> None:
|
|
|
|
- try:
|
|
|
|
- mimetype = filetype.guess(each_input)
|
|
|
|
- if mimetype is None:
|
|
|
|
- return
|
|
|
|
- if mimetype.mime.find("image") < 0:
|
|
|
|
- return
|
|
|
|
-
|
|
|
|
- each_output = (output / each_input.name).with_suffix(".png")
|
|
|
|
- each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
|
-
|
|
|
|
- if not each_output.exists():
|
|
|
|
- each_output.write_bytes(
|
|
|
|
- cast(
|
|
|
|
- bytes,
|
|
|
|
- remove(each_input.read_bytes(), session=session, **kwargs),
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- if watch:
|
|
|
|
- print(
|
|
|
|
- f"processed: {each_input.absolute()} -> {each_output.absolute()}"
|
|
|
|
- )
|
|
|
|
- except Exception as e:
|
|
|
|
- print(e)
|
|
|
|
-
|
|
|
|
- inputs = list(input.glob("**/*"))
|
|
|
|
- if not watch:
|
|
|
|
- inputs = tqdm(inputs)
|
|
|
|
-
|
|
|
|
- for each_input in inputs:
|
|
|
|
- if not each_input.is_dir():
|
|
|
|
- process(each_input)
|
|
|
|
-
|
|
|
|
- if watch:
|
|
|
|
- observer = Observer()
|
|
|
|
-
|
|
|
|
- class EventHandler(FileSystemEventHandler):
|
|
|
|
- def on_any_event(self, event: FileSystemEvent) -> None:
|
|
|
|
- if not (
|
|
|
|
- event.is_directory or event.event_type in ["deleted", "closed"]
|
|
|
|
- ):
|
|
|
|
- process(pathlib.Path(event.src_path))
|
|
|
|
-
|
|
|
|
- event_handler = EventHandler()
|
|
|
|
- observer.schedule(event_handler, input, recursive=False)
|
|
|
|
- observer.start()
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- while True:
|
|
|
|
- time.sleep(1)
|
|
|
|
-
|
|
|
|
- finally:
|
|
|
|
- observer.stop()
|
|
|
|
- observer.join()
|
|
|
|
-
|
|
|
|
-
|
|
|
|
[email protected](help="for a http server")
|
|
|
|
[email protected](
|
|
|
|
- "-p",
|
|
|
|
- "--port",
|
|
|
|
- default=5000,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="port",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-l",
|
|
|
|
- "--log_level",
|
|
|
|
- default="info",
|
|
|
|
- type=str,
|
|
|
|
- show_default=True,
|
|
|
|
- help="log level",
|
|
|
|
-)
|
|
|
|
[email protected](
|
|
|
|
- "-t",
|
|
|
|
- "--threads",
|
|
|
|
- default=None,
|
|
|
|
- type=int,
|
|
|
|
- show_default=True,
|
|
|
|
- help="number of worker threads",
|
|
|
|
-)
|
|
|
|
-def s(port: int, log_level: str, threads: int) -> None:
|
|
|
|
- sessions: dict[str, BaseSession] = {}
|
|
|
|
- tags_metadata = [
|
|
|
|
- {
|
|
|
|
- "name": "Background Removal",
|
|
|
|
- "description": "Endpoints that perform background removal with different image sources.",
|
|
|
|
- "externalDocs": {
|
|
|
|
- "description": "GitHub Source",
|
|
|
|
- "url": "https://github.com/danielgatis/rembg",
|
|
|
|
- },
|
|
|
|
- },
|
|
|
|
- ]
|
|
|
|
- app = FastAPI(
|
|
|
|
- title="Rembg",
|
|
|
|
- description="Rembg is a tool to remove images background. That is it.",
|
|
|
|
- version=_version.get_versions()["version"],
|
|
|
|
- contact={
|
|
|
|
- "name": "Daniel Gatis",
|
|
|
|
- "url": "https://github.com/danielgatis",
|
|
|
|
- "email": "[email protected]",
|
|
|
|
- },
|
|
|
|
- license_info={
|
|
|
|
- "name": "MIT License",
|
|
|
|
- "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
|
|
|
- },
|
|
|
|
- openapi_tags=tags_metadata,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- app.add_middleware(
|
|
|
|
- CORSMiddleware,
|
|
|
|
- allow_credentials=True,
|
|
|
|
- allow_origins=["*"],
|
|
|
|
- allow_methods=["*"],
|
|
|
|
- allow_headers=["*"],
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- class ModelType(str, Enum):
|
|
|
|
- u2net = "u2net"
|
|
|
|
- u2netp = "u2netp"
|
|
|
|
- u2net_human_seg = "u2net_human_seg"
|
|
|
|
- u2net_cloth_seg = "u2net_cloth_seg"
|
|
|
|
- silueta = "silueta"
|
|
|
|
- isnet_general_use = "isnet-general-use"
|
|
|
|
-
|
|
|
|
- class CommonQueryParams:
|
|
|
|
- def __init__(
|
|
|
|
- self,
|
|
|
|
- model: ModelType = Query(
|
|
|
|
- default=ModelType.u2net,
|
|
|
|
- description="Model to use when processing image",
|
|
|
|
- ),
|
|
|
|
- a: bool = Query(default=False, description="Enable Alpha Matting"),
|
|
|
|
- af: int = Query(
|
|
|
|
- default=240,
|
|
|
|
- ge=0,
|
|
|
|
- le=255,
|
|
|
|
- description="Alpha Matting (Foreground Threshold)",
|
|
|
|
- ),
|
|
|
|
- ab: int = Query(
|
|
|
|
- default=10,
|
|
|
|
- ge=0,
|
|
|
|
- le=255,
|
|
|
|
- description="Alpha Matting (Background Threshold)",
|
|
|
|
- ),
|
|
|
|
- ae: int = Query(
|
|
|
|
- default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
|
|
|
- ),
|
|
|
|
- om: bool = Query(default=False, description="Only Mask"),
|
|
|
|
- ppm: bool = Query(default=False, description="Post Process Mask"),
|
|
|
|
- bgc: Optional[str] = Query(default=None, description="Background Color"),
|
|
|
|
- ):
|
|
|
|
- self.model = model
|
|
|
|
- self.a = a
|
|
|
|
- self.af = af
|
|
|
|
- self.ab = ab
|
|
|
|
- self.ae = ae
|
|
|
|
- self.om = om
|
|
|
|
- self.ppm = ppm
|
|
|
|
- self.bgc = (
|
|
|
|
- cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
|
|
|
|
- if bgc
|
|
|
|
- else None
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- class CommonQueryPostParams:
|
|
|
|
- def __init__(
|
|
|
|
- self,
|
|
|
|
- model: ModelType = Form(
|
|
|
|
- default=ModelType.u2net,
|
|
|
|
- description="Model to use when processing image",
|
|
|
|
- ),
|
|
|
|
- a: bool = Form(default=False, description="Enable Alpha Matting"),
|
|
|
|
- af: int = Form(
|
|
|
|
- default=240,
|
|
|
|
- ge=0,
|
|
|
|
- le=255,
|
|
|
|
- description="Alpha Matting (Foreground Threshold)",
|
|
|
|
- ),
|
|
|
|
- ab: int = Form(
|
|
|
|
- default=10,
|
|
|
|
- ge=0,
|
|
|
|
- le=255,
|
|
|
|
- description="Alpha Matting (Background Threshold)",
|
|
|
|
- ),
|
|
|
|
- ae: int = Form(
|
|
|
|
- default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
|
|
|
- ),
|
|
|
|
- om: bool = Form(default=False, description="Only Mask"),
|
|
|
|
- ppm: bool = Form(default=False, description="Post Process Mask"),
|
|
|
|
- bgc: Optional[str] = Query(default=None, description="Background Color"),
|
|
|
|
- ):
|
|
|
|
- self.model = model
|
|
|
|
- self.a = a
|
|
|
|
- self.af = af
|
|
|
|
- self.ab = ab
|
|
|
|
- self.ae = ae
|
|
|
|
- self.om = om
|
|
|
|
- self.ppm = ppm
|
|
|
|
- self.bgc = (
|
|
|
|
- cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
|
|
|
|
- if bgc
|
|
|
|
- else None
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
|
|
|
- return Response(
|
|
|
|
- remove(
|
|
|
|
- content,
|
|
|
|
- session=sessions.setdefault(
|
|
|
|
- commons.model.value, new_session(commons.model.value)
|
|
|
|
- ),
|
|
|
|
- alpha_matting=commons.a,
|
|
|
|
- alpha_matting_foreground_threshold=commons.af,
|
|
|
|
- alpha_matting_background_threshold=commons.ab,
|
|
|
|
- alpha_matting_erode_size=commons.ae,
|
|
|
|
- only_mask=commons.om,
|
|
|
|
- post_process_mask=commons.ppm,
|
|
|
|
- bgcolor=commons.bgc,
|
|
|
|
- ),
|
|
|
|
- media_type="image/png",
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- @app.on_event("startup")
|
|
|
|
- def startup():
|
|
|
|
- if threads is not None:
|
|
|
|
- from anyio import CapacityLimiter
|
|
|
|
- from anyio.lowlevel import RunVar
|
|
|
|
-
|
|
|
|
- RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
|
|
|
-
|
|
|
|
- @app.get(
|
|
|
|
- path="/",
|
|
|
|
- tags=["Background Removal"],
|
|
|
|
- summary="Remove from URL",
|
|
|
|
- description="Removes the background from an image obtained by retrieving an URL.",
|
|
|
|
- )
|
|
|
|
- async def get_index(
|
|
|
|
- url: str = Query(
|
|
|
|
- default=..., description="URL of the image that has to be processed."
|
|
|
|
- ),
|
|
|
|
- commons: CommonQueryParams = Depends(),
|
|
|
|
- ):
|
|
|
|
- async with aiohttp.ClientSession() as session:
|
|
|
|
- async with session.get(url) as response:
|
|
|
|
- file = await response.read()
|
|
|
|
- return await asyncify(im_without_bg)(file, commons)
|
|
|
|
-
|
|
|
|
- @app.post(
|
|
|
|
- path="/",
|
|
|
|
- tags=["Background Removal"],
|
|
|
|
- summary="Remove from Stream",
|
|
|
|
- description="Removes the background from an image sent within the request itself.",
|
|
|
|
- )
|
|
|
|
- async def post_index(
|
|
|
|
- file: bytes = File(
|
|
|
|
- default=...,
|
|
|
|
- description="Image file (byte stream) that has to be processed.",
|
|
|
|
- ),
|
|
|
|
- commons: CommonQueryPostParams = Depends(),
|
|
|
|
- ):
|
|
|
|
- return await asyncify(im_without_bg)(file, commons) # type: ignore
|
|
|
|
-
|
|
|
|
- uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
|
|
|
|
|
+for command in command_functions:
|
|
|
|
+ main.add_command(command)
|