Explorar el Código

add watch mode

Daniel Gatis hace 3 años
padre
commit
ca97ce4e59
Se han modificado 2 ficheros con 68 adiciones y 17 borrados
  1. 67 17
      rembg/cli.py
  2. 1 0
      requirements.txt

+ 67 - 17
rembg/cli.py

@@ -1,5 +1,6 @@
 import pathlib
 import sys
+import time
 from enum import Enum
 from typing import IO, Optional
 
@@ -13,6 +14,8 @@ from asyncer import asyncify
 from fastapi import Depends, FastAPI, File, Query
 from starlette.responses import Response
 from tqdm import tqdm
+from watchdog.events import FileSystemEvent, FileSystemEventHandler
+from watchdog.observers import Observer
 
 from . import _version
 from .bg import remove
@@ -21,7 +24,7 @@ from .detect import ort_session
 
 @click.group()
 @click.version_option()
-def main():
+def main() -> None:
     pass
 
 
@@ -81,7 +84,7 @@ def main():
     default=(None if sys.stdin.isatty() else "-"),
     type=click.File("wb", lazy=True),
 )
-def i(model: str, input: IO, output: IO, **kwargs):
+def i(model: str, input: IO, output: IO, **kwargs) -> None:
     output.write(remove(input.read(), session=ort_session(model), **kwargs))
 
 
@@ -133,6 +136,14 @@ def i(model: str, input: IO, output: IO, **kwargs):
     show_default=True,
     help="output only the mask",
 )
[email protected](
+    "-w",
+    "--watch",
+    default=False,
+    is_flag=True,
+    show_default=True,
+    help="watches a folder for changes",
+)
 @click.argument(
     "input",
     type=click.Path(
@@ -153,24 +164,63 @@ def i(model: str, input: IO, output: IO, **kwargs):
         writable=True,
     ),
 )
-def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs):
+def p(
+    model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
+) -> None:
     session = ort_session(model)
-    for each_input in tqdm(list(input.glob("**/*"))):
-        if each_input.is_dir():
-            continue
 
-        mimetype = filetype.guess(each_input)
-        if mimetype is None:
-            continue
-        if mimetype.mime.find("image") < 0:
-            continue
+    def process(each_input: pathlib.Path) -> None:
+        try:
+            mimetype = filetype.guess(each_input)
+            if mimetype is None:
+                return
+            if mimetype.mime.find("image") < 0:
+                return
 
-        each_output = (output / each_input.name).with_suffix(".png")
-        each_output.parents[0].mkdir(parents=True, exist_ok=True)
+            each_output = (output / each_input.name).with_suffix(".png")
+            each_output.parents[0].mkdir(parents=True, exist_ok=True)
 
-        each_output.write_bytes(
-            remove(each_input.read_bytes(), session=session, **kwargs)
-        )
+            if not each_output.exists():
+                each_output.write_bytes(
+                    remove(each_input.read_bytes(), session=session, **kwargs)
+                )
+
+                if watch:
+                    print(
+                        f"processed: {each_input.absolute()} -> {each_output.absolute()}"
+                    )
+        except e:
+            print(e)
+
+    inputs = list(input.glob("**/*"))
+    if not watch:
+        inputs = tqdm(inputs)
+
+    for each_input in inputs:
+        if not each_input.is_dir():
+            process(each_input)
+
+    if watch:
+        observer = Observer()
+
+        class EventHandler(FileSystemEventHandler):
+            def on_any_event(self, event: FileSystemEvent) -> None:
+                if not (
+                    event.is_directory or event.event_type in ["deleted", "closed"]
+                ):
+                    process(pathlib.Path(event.src_path))
+
+        event_handler = EventHandler()
+        observer.schedule(event_handler, input, recursive=False)
+        observer.start()
+
+        try:
+            while True:
+                time.sleep(1)
+
+        finally:
+            observer.stop()
+            observer.join()
 
 
 @main.command(help="for a http server")
@@ -190,7 +240,7 @@ def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs):
     show_default=True,
     help="log level",
 )
-def s(port: int, log_level: str):
+def s(port: int, log_level: str) -> None:
     sessions: dict[str, ort.InferenceSession] = {}
     tags_metadata = [
         {

+ 1 - 0
requirements.txt

@@ -13,3 +13,4 @@ scikit-image==0.19.1
 scipy==1.7.3
 tqdm==4.62.3
 uvicorn==0.17.0
+watchdog==2.1.7