|
@@ -1,156 +1,284 @@
|
|
|
-import argparse
|
|
|
-import glob
|
|
|
-import os
|
|
|
-from distutils.util import strtobool
|
|
|
-from typing import BinaryIO
|
|
|
+import pathlib
|
|
|
import sys
|
|
|
-from pathlib import Path
|
|
|
+from enum import Enum
|
|
|
+from typing import IO, Optional
|
|
|
|
|
|
+import click
|
|
|
import filetype
|
|
|
-from tqdm import tqdm
|
|
|
import onnxruntime as ort
|
|
|
+import requests
|
|
|
+import uvicorn
|
|
|
+from fastapi import Depends, FastAPI, File, Query
|
|
|
+from starlette.responses import Response
|
|
|
+from tqdm import tqdm
|
|
|
|
|
|
from .bg import remove
|
|
|
from .detect import ort_session
|
|
|
|
|
|
-sessions: dict[str, ort.InferenceSession] = {}
|
|
|
-
|
|
|
|
|
|
[email protected]()
|
|
|
[email protected]_option()
|
|
|
def main():
|
|
|
- ap = argparse.ArgumentParser()
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "-m",
|
|
|
- "--model",
|
|
|
- default="u2net",
|
|
|
- type=str,
|
|
|
- choices=["u2net", "u2netp", "u2net_human_seg"],
|
|
|
- help="The model name.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "-a",
|
|
|
- "--alpha-matting",
|
|
|
- nargs="?",
|
|
|
- const=True,
|
|
|
- default=False,
|
|
|
- type=lambda x: bool(strtobool(x)),
|
|
|
- help="When true use alpha matting cutout.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "-af",
|
|
|
- "--alpha-matting-foreground-threshold",
|
|
|
- default=240,
|
|
|
- type=int,
|
|
|
- help="The trimap foreground threshold.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "-ab",
|
|
|
- "--alpha-matting-background-threshold",
|
|
|
- default=10,
|
|
|
- type=int,
|
|
|
- help="The trimap background threshold.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "-ae",
|
|
|
- "--alpha-matting-erode-size",
|
|
|
- default=10,
|
|
|
- type=int,
|
|
|
- help="Size of element used for the erosion.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "-az",
|
|
|
- "--alpha-matting-base-size",
|
|
|
- default=1000,
|
|
|
- type=int,
|
|
|
- help="The image base size.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "-p",
|
|
|
- "--path",
|
|
|
- nargs=2,
|
|
|
- help="An input folder and an output folder.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "input",
|
|
|
- nargs=(None if sys.stdin.isatty() else "?"),
|
|
|
- default=(None if sys.stdin.isatty() else sys.stdin.buffer),
|
|
|
- type=argparse.FileType("rb"),
|
|
|
- help="Path to the input image.",
|
|
|
- )
|
|
|
-
|
|
|
- ap.add_argument(
|
|
|
- "output",
|
|
|
- nargs=(None if sys.stdin.isatty() else "?"),
|
|
|
- default=(None if sys.stdin.isatty() else sys.stdout.buffer),
|
|
|
- type=argparse.FileType("wb"),
|
|
|
- help="Path to the output png image.",
|
|
|
- )
|
|
|
-
|
|
|
- args = ap.parse_args()
|
|
|
- session = sessions.setdefault(args.model, ort_session(args.model))
|
|
|
-
|
|
|
- if args.path:
|
|
|
- full_paths = [os.path.abspath(path) for path in args.path]
|
|
|
-
|
|
|
- input_paths = [full_paths[0]]
|
|
|
- output_path = full_paths[1]
|
|
|
-
|
|
|
- if not os.path.exists(output_path):
|
|
|
- os.makedirs(output_path)
|
|
|
-
|
|
|
- input_files = set()
|
|
|
-
|
|
|
- for input_path in input_paths:
|
|
|
- if os.path.isfile(path):
|
|
|
- input_files.add(path)
|
|
|
- else:
|
|
|
- input_paths += set(glob.glob(input_path + "/*"))
|
|
|
-
|
|
|
- for input_file in tqdm(input_files):
|
|
|
- input_file_type = filetype.guess(input_file)
|
|
|
-
|
|
|
- if input_file_type is None:
|
|
|
- continue
|
|
|
-
|
|
|
- if input_file_type.mime.find("image") < 0:
|
|
|
- continue
|
|
|
-
|
|
|
- out_file = os.path.join(
|
|
|
- output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png"
|
|
|
- )
|
|
|
-
|
|
|
- Path(out_file).write_bytes(
|
|
|
- remove(
|
|
|
- Path(input_file).read_bytes(),
|
|
|
- session=session,
|
|
|
- alpha_matting=args.alpha_matting,
|
|
|
- alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
|
|
- alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
|
|
- alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
|
|
- alpha_matting_base_size=args.alpha_matting_base_size,
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- else:
|
|
|
- args.output.write(
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
[email protected](help="for a file as input")
|
|
|
[email protected](
|
|
|
+ "-m",
|
|
|
+ "--model",
|
|
|
+ default="u2net",
|
|
|
+ type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
|
|
|
+ show_default=True,
|
|
|
+ show_choices=True,
|
|
|
+ help="model name",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-a",
|
|
|
+ "--alpha-matting",
|
|
|
+ is_flag=True,
|
|
|
+ show_default=True,
|
|
|
+ help="use alpha matting",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-af",
|
|
|
+ "--alpha-matting-foreground-threshold",
|
|
|
+ default=240,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="trimap fg threshold",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-ab",
|
|
|
+ "--alpha-matting-background-threshold",
|
|
|
+ default=10,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="trimap bg threshold",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-ae",
|
|
|
+ "--alpha-matting-erode-size",
|
|
|
+ default=10,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="erode size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-az",
|
|
|
+ "--alpha-matting-base-size",
|
|
|
+ default=1000,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="image base size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-w",
|
|
|
+ "--width",
|
|
|
+ default=None,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="output image size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-h",
|
|
|
+ "--height",
|
|
|
+ default=None,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="output image size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "output",
|
|
|
+ default=(None if sys.stdin.isatty() else "-"),
|
|
|
+ type=click.File("wb", lazy=True),
|
|
|
+)
|
|
|
+def i(model: str, input: IO, output: IO, **kwargs: dict):
|
|
|
+ output.write(remove(input.read(), session=ort_session(model), **kwargs))
|
|
|
+
|
|
|
+
|
|
|
[email protected](help="for a folder as input")
|
|
|
[email protected](
|
|
|
+ "-m",
|
|
|
+ "--model",
|
|
|
+ default="u2net",
|
|
|
+ type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]),
|
|
|
+ show_default=True,
|
|
|
+ show_choices=True,
|
|
|
+ help="model name",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-a",
|
|
|
+ "--alpha-matting",
|
|
|
+ is_flag=True,
|
|
|
+ show_default=True,
|
|
|
+ help="use alpha matting",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-af",
|
|
|
+ "--alpha-matting-foreground-threshold",
|
|
|
+ default=240,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="trimap fg threshold",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-ab",
|
|
|
+ "--alpha-matting-background-threshold",
|
|
|
+ default=10,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="trimap bg threshold",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-ae",
|
|
|
+ "--alpha-matting-erode-size",
|
|
|
+ default=10,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="erode size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-az",
|
|
|
+ "--alpha-matting-base-size",
|
|
|
+ default=1000,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="image base size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-w",
|
|
|
+ "--width",
|
|
|
+ default=None,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="output image size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-h",
|
|
|
+ "--height",
|
|
|
+ default=None,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="output image size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "input",
|
|
|
+ type=click.Path(
|
|
|
+ exists=True,
|
|
|
+ path_type=pathlib.Path,
|
|
|
+ file_okay=False,
|
|
|
+ dir_okay=True,
|
|
|
+ readable=True,
|
|
|
+ ),
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "output",
|
|
|
+ type=click.Path(
|
|
|
+ exists=False,
|
|
|
+ path_type=pathlib.Path,
|
|
|
+ file_okay=False,
|
|
|
+ dir_okay=True,
|
|
|
+ writable=True,
|
|
|
+ ),
|
|
|
+)
|
|
|
+def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs: dict):
|
|
|
+ session = ort_session(model)
|
|
|
+ for each_input in tqdm(list(input.glob("**/*"))):
|
|
|
+ if each_input.is_dir():
|
|
|
+ continue
|
|
|
+
|
|
|
+ mimetype = filetype.guess(each_input)
|
|
|
+ if mimetype is None:
|
|
|
+ continue
|
|
|
+ if mimetype.mime.find("image") < 0:
|
|
|
+ continue
|
|
|
+
|
|
|
+ each_output = (output / each_input.name).with_suffix(".png")
|
|
|
+ each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ each_output.write_bytes(
|
|
|
+ remove(each_input.read_bytes(), session=session, **kwargs)
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
[email protected](help="for a http server")
|
|
|
[email protected](
|
|
|
+ "-p",
|
|
|
+ "--port",
|
|
|
+ default=5000,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="port",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-l",
|
|
|
+ "--log_level",
|
|
|
+ default="info",
|
|
|
+ type=str,
|
|
|
+ show_default=True,
|
|
|
+ help="log level",
|
|
|
+)
|
|
|
+def s(port: int, log_level: str):
|
|
|
+ sessions: dict[str, ort.InferenceSession] = {}
|
|
|
+ app = FastAPI()
|
|
|
+
|
|
|
+ class ModelType(str, Enum):
|
|
|
+ u2net = "u2net"
|
|
|
+ u2netp = "u2netp"
|
|
|
+ u2net_human_seg = "u2net_human_seg"
|
|
|
+
|
|
|
+ class CommonQueryParams:
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ model: ModelType = Query(ModelType.u2net),
|
|
|
+ a: bool = Query(False),
|
|
|
+ af: int = Query(240, ge=0),
|
|
|
+ ab: int = Query(10, ge=0),
|
|
|
+ ae: int = Query(10, ge=0),
|
|
|
+ az: int = Query(1000, ge=0),
|
|
|
+ width: Optional[int] = Query(None, gt=0),
|
|
|
+ height: Optional[int] = Query(None, gt=0),
|
|
|
+ ):
|
|
|
+ self.model = model
|
|
|
+ self.width = width
|
|
|
+ self.height = height
|
|
|
+ self.a = a
|
|
|
+ self.af = af
|
|
|
+ self.ab = ab
|
|
|
+ self.ae = ae
|
|
|
+ self.az = az
|
|
|
+
|
|
|
+ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
|
|
+ return Response(
|
|
|
remove(
|
|
|
- args.input.read(),
|
|
|
- session=session,
|
|
|
- alpha_matting=args.alpha_matting,
|
|
|
- alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold,
|
|
|
- alpha_matting_background_threshold=args.alpha_matting_background_threshold,
|
|
|
- alpha_matting_erode_structure_size=args.alpha_matting_erode_size,
|
|
|
- alpha_matting_base_size=args.alpha_matting_base_size,
|
|
|
- )
|
|
|
+ content,
|
|
|
+ session=sessions.setdefault(
|
|
|
+ commons.model.value, ort_session(commons.model.value)
|
|
|
+ ),
|
|
|
+ width=commons.width,
|
|
|
+ height=commons.height,
|
|
|
+ alpha_matting=commons.a,
|
|
|
+ alpha_matting_foreground_threshold=commons.af,
|
|
|
+ alpha_matting_background_threshold=commons.ab,
|
|
|
+ alpha_matting_erode_size=commons.ae,
|
|
|
+ alpha_matting_base_size=commons.az,
|
|
|
+ ),
|
|
|
+ media_type="image/png",
|
|
|
)
|
|
|
|
|
|
+ @app.get("/")
|
|
|
+ def get_index(url: str, commons: CommonQueryParams = Depends()):
|
|
|
+ return im_without_bg(requests.get(url).content, commons)
|
|
|
+
|
|
|
+ @app.post("/")
|
|
|
+ def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()):
|
|
|
+ return im_without_bg(file, commons)
|
|
|
+
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
|
|
+
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main()
|