|
@@ -16,6 +16,7 @@ from tqdm import tqdm
|
|
|
|
|
|
from .bg import remove
|
|
|
from .detect import ort_session
|
|
|
+from . import _version
|
|
|
|
|
|
|
|
|
@click.group()
|
|
@@ -191,7 +192,31 @@ def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs):
|
|
|
)
|
|
|
def s(port: int, log_level: str):
|
|
|
sessions: dict[str, ort.InferenceSession] = {}
|
|
|
- app = FastAPI()
|
|
|
+ tags_metadata = [
|
|
|
+ {
|
|
|
+ "name": "Background Removal",
|
|
|
+ "description": "Endpoints that perform background removal with different image sources.",
|
|
|
+ "externalDocs": {
|
|
|
+ "description": "GitHub Source",
|
|
|
+ "url": "https://github.com/danielgatis/rembg"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ ]
|
|
|
+ app = FastAPI(
|
|
|
+ title="Rembg",
|
|
|
+ description="Rembg is a tool to remove images background. That is it.",
|
|
|
+ version=_version.get_versions()["version"],
|
|
|
+ contact={
|
|
|
+ "name": "Daniel Gatis",
|
|
|
+ "url": "https://github.com/danielgatis",
|
|
|
+ "email": "[email protected]"
|
|
|
+ },
|
|
|
+ license_info={
|
|
|
+ "name": "MIT License",
|
|
|
+ "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt"
|
|
|
+ },
|
|
|
+ openapi_tags=tags_metadata,
|
|
|
+ )
|
|
|
|
|
|
class ModelType(str, Enum):
|
|
|
u2net = "u2net"
|
|
@@ -201,12 +226,12 @@ def s(port: int, log_level: str):
|
|
|
class CommonQueryParams:
|
|
|
def __init__(
|
|
|
self,
|
|
|
- model: ModelType = Query(ModelType.u2net),
|
|
|
- a: bool = Query(False),
|
|
|
- af: int = Query(240, ge=0),
|
|
|
- ab: int = Query(10, ge=0),
|
|
|
- ae: int = Query(10, ge=0),
|
|
|
- om: bool = Query(False),
|
|
|
+ model: ModelType = Query(default=ModelType.u2net, description="Model to use when processing image"),
|
|
|
+ a: bool = Query(default=False, description="Enable Alpha Matting"),
|
|
|
+ af: int = Query(default=240, ge=0, description="Alpha Matting (Foreground Threshold)"),
|
|
|
+ ab: int = Query(default=10, ge=0, description="Alpha Matting (Background Threshold)"),
|
|
|
+ ae: int = Query(default=10, ge=0, description="Alpha Matting (Erode Structure Size)"),
|
|
|
+ om: bool = Query(default=False, description="Only Mask"),
|
|
|
):
|
|
|
self.model = model
|
|
|
self.a = a
|
|
@@ -231,16 +256,30 @@ def s(port: int, log_level: str):
|
|
|
media_type="image/png",
|
|
|
)
|
|
|
|
|
|
- @app.get("/")
|
|
|
- async def get_index(url: str, commons: CommonQueryParams = Depends()):
|
|
|
+ @app.get(
|
|
|
+ path="/",
|
|
|
+ tags=["Background Removal"],
|
|
|
+ summary="Remove from URL",
|
|
|
+ description="Removes the background from an image obtained by retrieving an URL.",
|
|
|
+ )
|
|
|
+ async def get_index(
|
|
|
+ url: str = Query(default=..., description="URL of the image that has to be processed."),
|
|
|
+ commons: CommonQueryParams = Depends()
|
|
|
+ ):
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
async with session.get(url) as response:
|
|
|
file = await response.read()
|
|
|
return await asyncify(im_without_bg)(file, commons)
|
|
|
|
|
|
- @app.post("/")
|
|
|
+ @app.post(
|
|
|
+ path="/",
|
|
|
+ tags=["Background Removal"],
|
|
|
+ summary="Remove from Stream",
|
|
|
+ description="Removes the background from an image sent within the request itself.",
|
|
|
+ )
|
|
|
async def post_index(
|
|
|
- file: bytes = File(...), commons: CommonQueryParams = Depends()
|
|
|
+ file: bytes = File(default=..., description="Image file (byte stream) that has to be processed."),
|
|
|
+ commons: CommonQueryParams = Depends()
|
|
|
):
|
|
|
return await asyncify(im_without_bg)(file, commons)
|
|
|
|