|
@@ -1,8 +1,11 @@
|
|
|
import json
|
|
|
+import os
|
|
|
+import webbrowser
|
|
|
from typing import Optional, Tuple, cast
|
|
|
|
|
|
import aiohttp
|
|
|
import click
|
|
|
+import gradio as gr
|
|
|
import uvicorn
|
|
|
from asyncer import asyncify
|
|
|
from fastapi import Depends, FastAPI, File, Form, Query
|
|
@@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
|
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
|
|
},
|
|
|
openapi_tags=tags_metadata,
|
|
|
+ docs_url="/api",
|
|
|
)
|
|
|
|
|
|
app.add_middleware(
|
|
@@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
|
only_mask=commons.om,
|
|
|
post_process_mask=commons.ppm,
|
|
|
bgcolor=commons.bgc,
|
|
|
- **kwargs
|
|
|
+ **kwargs,
|
|
|
),
|
|
|
media_type="image/png",
|
|
|
)
|
|
|
|
|
|
@app.on_event("startup")
|
|
|
def startup():
|
|
|
+ try:
|
|
|
+ webbrowser.open(f"http://localhost:{port}")
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
if threads is not None:
|
|
|
from anyio import CapacityLimiter
|
|
|
from anyio.lowlevel import RunVar
|
|
@@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
|
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
|
|
|
|
|
@app.get(
|
|
|
- path="/",
|
|
|
+ path="/api/remove",
|
|
|
tags=["Background Removal"],
|
|
|
summary="Remove from URL",
|
|
|
description="Removes the background from an image obtained by retrieving an URL.",
|
|
@@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
|
return await asyncify(im_without_bg)(file, commons)
|
|
|
|
|
|
@app.post(
|
|
|
- path="/",
|
|
|
+ path="/api/remove",
|
|
|
tags=["Background Removal"],
|
|
|
summary="Remove from Stream",
|
|
|
description="Removes the background from an image sent within the request itself.",
|
|
@@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
|
):
|
|
|
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
|
|
|
|
|
- uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
|
|
+ def gr_app(app):
|
|
|
+ def inference(input_path, model):
|
|
|
+ output_path = "output.png"
|
|
|
+ with open(input_path, "rb") as i:
|
|
|
+ with open(output_path, "wb") as o:
|
|
|
+ input = i.read()
|
|
|
+ output = remove(input, session=new_session(model))
|
|
|
+ o.write(output)
|
|
|
+ return os.path.join(output_path)
|
|
|
+
|
|
|
+ interface = gr.Interface(
|
|
|
+ inference,
|
|
|
+ [
|
|
|
+ gr.components.Image(type="filepath", label="Input"),
|
|
|
+ gr.components.Dropdown(
|
|
|
+ [
|
|
|
+ "u2net",
|
|
|
+ "u2netp",
|
|
|
+ "u2net_human_seg",
|
|
|
+ "u2net_cloth_seg",
|
|
|
+ "silueta",
|
|
|
+ "isnet-general-use",
|
|
|
+ "isnet-anime",
|
|
|
+ ],
|
|
|
+ value="u2net",
|
|
|
+ label="Models",
|
|
|
+ ),
|
|
|
+ ],
|
|
|
+ gr.components.Image(type="filepath", label="Output"),
|
|
|
+ )
|
|
|
+
|
|
|
+ interface.queue(concurrency_count=3)
|
|
|
+ app = gr.mount_gradio_app(app, interface, path="/")
|
|
|
+ return app
|
|
|
+
|
|
|
+ print(f"To access the API documentation, go to http://localhost:{port}/api")
|
|
|
+ print(f"To access the UI, go to http://localhost:{port}")
|
|
|
+
|
|
|
+ uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)
|