|
@@ -1,76 +1,85 @@
|
|
|
import argparse
|
|
|
+from enum import Enum
|
|
|
from io import BytesIO
|
|
|
-from urllib.parse import quote, unquote_plus
|
|
|
-from urllib.request import urlopen
|
|
|
+from typing import Optional
|
|
|
|
|
|
-from flask import Flask, request, send_file
|
|
|
-from waitress import serve
|
|
|
+import requests
|
|
|
+import uvicorn
|
|
|
+from fastapi import FastAPI, Form, Query, UploadFile
|
|
|
+from PIL import Image
|
|
|
+from starlette.responses import StreamingResponse
|
|
|
|
|
|
from .bg import remove
|
|
|
+from .detect import ort_session
|
|
|
+
|
|
|
+sessions = {}
|
|
|
+app = FastAPI()
|
|
|
+
|
|
|
+
|
|
|
+class ModelType(str, Enum):
|
|
|
+ u2net = "u2net"
|
|
|
+ u2netp = "u2netp"
|
|
|
+ u2net_human_seg = "u2net_human_seg"
|
|
|
+
|
|
|
+
|
|
|
[email protected]("/")
|
|
|
+def get_index(
|
|
|
+ url: str,
|
|
|
+ model: Optional[ModelType] = ModelType.u2net,
|
|
|
+ width: Optional[int] = Query(None, gt=0),
|
|
|
+ height: Optional[int] = Query(None, gt=0),
|
|
|
+ a: Optional[bool] = Query(False),
|
|
|
+ af: Optional[int] = Query(240, ge=0),
|
|
|
+ ab: Optional[int] = Query(10, ge=0),
|
|
|
+ ae: Optional[int] = Query(10, ge=0),
|
|
|
+ az: Optional[int] = Query(1000, ge=0),
|
|
|
+):
|
|
|
+ return StreamingResponse(
|
|
|
+ BytesIO(
|
|
|
+ remove(
|
|
|
+ requests.get(url).content,
|
|
|
+ session=sessions.setdefault(model, ort_session(model)),
|
|
|
+ width=width,
|
|
|
+ height=height,
|
|
|
+ alpha_matting=a,
|
|
|
+ alpha_matting_foreground_threshold=af,
|
|
|
+ alpha_matting_background_threshold=ab,
|
|
|
+ alpha_matting_erode_structure_size=ae,
|
|
|
+ alpha_matting_base_size=az,
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ media_type="image/png",
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
-app = Flask(__name__)
|
|
|
-
|
|
|
-
|
|
|
[email protected]("/", methods=["GET", "POST"])
|
|
|
-def index():
|
|
|
- file_content = ""
|
|
|
-
|
|
|
- if request.method == "POST":
|
|
|
- if "file" not in request.files:
|
|
|
- return {"error": "missing post form param 'file'"}, 400
|
|
|
-
|
|
|
- file_content = request.files["file"].read()
|
|
|
-
|
|
|
- if request.method == "GET":
|
|
|
- url = request.args.get("url", type=str)
|
|
|
- if url is None:
|
|
|
- return {"error": "missing query param 'url'"}, 400
|
|
|
-
|
|
|
- url = unquote_plus(url)
|
|
|
- if " " in url:
|
|
|
- url = quote(url, safe="/:")
|
|
|
-
|
|
|
- file_content = urlopen(url).read()
|
|
|
-
|
|
|
- if file_content == "":
|
|
|
- return {"error": "File content is empty"}, 400
|
|
|
-
|
|
|
- alpha_matting = "a" in request.values
|
|
|
- af = request.values.get("af", type=int, default=240)
|
|
|
- ab = request.values.get("ab", type=int, default=10)
|
|
|
- ae = request.values.get("ae", type=int, default=10)
|
|
|
- az = request.values.get("az", type=int, default=1000)
|
|
|
- width = request.args.get("width", type=int)
|
|
|
- height = request.args.get("height", type=int)
|
|
|
-
|
|
|
- model = request.values.get("model", type=str, default="u2net")
|
|
|
- model_choices = ["u2net", "u2netp", "u2net_human_seg"]
|
|
|
-
|
|
|
- if model not in model_choices:
|
|
|
- return {
|
|
|
- "error": f"invalid query param 'model'. Available options are {model_choices}"
|
|
|
- }, 400
|
|
|
-
|
|
|
- try:
|
|
|
- return send_file(
|
|
|
- BytesIO(
|
|
|
- remove(
|
|
|
- file_content,
|
|
|
- width=width,
|
|
|
- height=height,
|
|
|
- model_name=model,
|
|
|
- alpha_matting=alpha_matting,
|
|
|
- alpha_matting_foreground_threshold=af,
|
|
|
- alpha_matting_background_threshold=ab,
|
|
|
- alpha_matting_erode_structure_size=ae,
|
|
|
- alpha_matting_base_size=az,
|
|
|
- )
|
|
|
- ),
|
|
|
- mimetype="image/png",
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- app.logger.exception(e, exc_info=True)
|
|
|
- return {"error": "oops, something went wrong!"}, 500
|
|
|
[email protected]("/")
|
|
|
+def post_index(
|
|
|
+ file: UploadFile = File(...),
|
|
|
+ model: Optional[ModelType] = ModelType.u2net,
|
|
|
+ width: Optional[int] = Query(None, gt=0),
|
|
|
+ height: Optional[int] = Query(None, gt=0),
|
|
|
+ a: Optional[bool] = Query(False),
|
|
|
+ af: Optional[int] = Query(240, ge=0),
|
|
|
+ ab: Optional[int] = Query(10, ge=0),
|
|
|
+ ae: Optional[int] = Query(10, ge=0),
|
|
|
+ az: Optional[int] = Query(1000, ge=0),
|
|
|
+):
|
|
|
+ return StreamingResponse(
|
|
|
+ BytesIO(
|
|
|
+ remove(
|
|
|
+ file.read(),
|
|
|
+ session=sessions.setdefault(model, ort_session(model)),
|
|
|
+ width=width,
|
|
|
+ height=height,
|
|
|
+ alpha_matting=a,
|
|
|
+ alpha_matting_foreground_threshold=af,
|
|
|
+ alpha_matting_background_threshold=ab,
|
|
|
+ alpha_matting_erode_structure_size=ae,
|
|
|
+ alpha_matting_base_size=az,
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ media_type="image/png",
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def main():
|
|
@@ -92,8 +101,18 @@ def main():
|
|
|
help="The port to bind to.",
|
|
|
)
|
|
|
|
|
|
+ ap.add_argument(
|
|
|
+ "-l",
|
|
|
+ "--log_level",
|
|
|
+ default="info",
|
|
|
+ type=str,
|
|
|
+ help="The log level.",
|
|
|
+ )
|
|
|
+
|
|
|
args = ap.parse_args()
|
|
|
- serve(app, host=args.addr, port=args.port)
|
|
|
+ uvicorn.run(
|
|
|
+ "rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level
|
|
|
+ )
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|