|
@@ -1,5 +1,6 @@
|
|
|
import pathlib
|
|
|
import sys
|
|
|
+import time
|
|
|
from enum import Enum
|
|
|
from typing import IO, Optional
|
|
|
|
|
@@ -13,6 +14,8 @@ from asyncer import asyncify
|
|
|
from fastapi import Depends, FastAPI, File, Query
|
|
|
from starlette.responses import Response
|
|
|
from tqdm import tqdm
|
|
|
+from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
|
|
+from watchdog.observers import Observer
|
|
|
|
|
|
from . import _version
|
|
|
from .bg import remove
|
|
@@ -21,7 +24,7 @@ from .detect import ort_session
|
|
|
|
|
|
@click.group()
|
|
|
@click.version_option()
|
|
|
-def main():
|
|
|
+def main() -> None:
|
|
|
pass
|
|
|
|
|
|
|
|
@@ -81,7 +84,7 @@ def main():
|
|
|
default=(None if sys.stdin.isatty() else "-"),
|
|
|
type=click.File("wb", lazy=True),
|
|
|
)
|
|
|
-def i(model: str, input: IO, output: IO, **kwargs):
|
|
|
+def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
|
|
output.write(remove(input.read(), session=ort_session(model), **kwargs))
|
|
|
|
|
|
|
|
@@ -133,6 +136,14 @@ def i(model: str, input: IO, output: IO, **kwargs):
|
|
|
show_default=True,
|
|
|
help="output only the mask",
|
|
|
)
|
|
|
[email protected](
|
|
|
+ "-w",
|
|
|
+ "--watch",
|
|
|
+ default=False,
|
|
|
+ is_flag=True,
|
|
|
+ show_default=True,
|
|
|
+ help="watches a folder for changes",
|
|
|
+)
|
|
|
@click.argument(
|
|
|
"input",
|
|
|
type=click.Path(
|
|
@@ -153,24 +164,63 @@ def i(model: str, input: IO, output: IO, **kwargs):
|
|
|
writable=True,
|
|
|
),
|
|
|
)
|
|
|
-def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs):
|
|
|
+def p(
|
|
|
+ model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
|
|
|
+) -> None:
|
|
|
session = ort_session(model)
|
|
|
- for each_input in tqdm(list(input.glob("**/*"))):
|
|
|
- if each_input.is_dir():
|
|
|
- continue
|
|
|
|
|
|
- mimetype = filetype.guess(each_input)
|
|
|
- if mimetype is None:
|
|
|
- continue
|
|
|
- if mimetype.mime.find("image") < 0:
|
|
|
- continue
|
|
|
+ def process(each_input: pathlib.Path) -> None:
|
|
|
+ try:
|
|
|
+ mimetype = filetype.guess(each_input)
|
|
|
+ if mimetype is None:
|
|
|
+ return
|
|
|
+ if mimetype.mime.find("image") < 0:
|
|
|
+ return
|
|
|
|
|
|
- each_output = (output / each_input.name).with_suffix(".png")
|
|
|
- each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
+ each_output = (output / each_input.name).with_suffix(".png")
|
|
|
+ each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
- each_output.write_bytes(
|
|
|
- remove(each_input.read_bytes(), session=session, **kwargs)
|
|
|
- )
|
|
|
+ if not each_output.exists():
|
|
|
+ each_output.write_bytes(
|
|
|
+ remove(each_input.read_bytes(), session=session, **kwargs)
|
|
|
+ )
|
|
|
+
|
|
|
+ if watch:
|
|
|
+ print(
|
|
|
+ f"processed: {each_input.absolute()} -> {each_output.absolute()}"
|
|
|
+ )
|
|
|
+ except e:
|
|
|
+ print(e)
|
|
|
+
|
|
|
+ inputs = list(input.glob("**/*"))
|
|
|
+ if not watch:
|
|
|
+ inputs = tqdm(inputs)
|
|
|
+
|
|
|
+ for each_input in inputs:
|
|
|
+ if not each_input.is_dir():
|
|
|
+ process(each_input)
|
|
|
+
|
|
|
+ if watch:
|
|
|
+ observer = Observer()
|
|
|
+
|
|
|
+ class EventHandler(FileSystemEventHandler):
|
|
|
+ def on_any_event(self, event: FileSystemEvent) -> None:
|
|
|
+ if not (
|
|
|
+ event.is_directory or event.event_type in ["deleted", "closed"]
|
|
|
+ ):
|
|
|
+ process(pathlib.Path(event.src_path))
|
|
|
+
|
|
|
+ event_handler = EventHandler()
|
|
|
+ observer.schedule(event_handler, input, recursive=False)
|
|
|
+ observer.start()
|
|
|
+
|
|
|
+ try:
|
|
|
+ while True:
|
|
|
+ time.sleep(1)
|
|
|
+
|
|
|
+ finally:
|
|
|
+ observer.stop()
|
|
|
+ observer.join()
|
|
|
|
|
|
|
|
|
@main.command(help="for a http server")
|
|
@@ -190,7 +240,7 @@ def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs):
|
|
|
show_default=True,
|
|
|
help="log level",
|
|
|
)
|
|
|
-def s(port: int, log_level: str):
|
|
|
+def s(port: int, log_level: str) -> None:
|
|
|
sessions: dict[str, ort.InferenceSession] = {}
|
|
|
tags_metadata = [
|
|
|
{
|