2
0
Daniel Gatis 3 жил өмнө
parent
commit
ca97ce4e59
2 өөрчлөгдсөн 68 нэмэгдсэн , 17 устгасан
  1. 67 17
      rembg/cli.py
  2. 1 0
      requirements.txt

+ 67 - 17
rembg/cli.py

@@ -1,5 +1,6 @@
 import pathlib
 import pathlib
 import sys
 import sys
+import time
 from enum import Enum
 from enum import Enum
 from typing import IO, Optional
 from typing import IO, Optional
 
 
@@ -13,6 +14,8 @@ from asyncer import asyncify
 from fastapi import Depends, FastAPI, File, Query
 from fastapi import Depends, FastAPI, File, Query
 from starlette.responses import Response
 from starlette.responses import Response
 from tqdm import tqdm
 from tqdm import tqdm
+from watchdog.events import FileSystemEvent, FileSystemEventHandler
+from watchdog.observers import Observer
 
 
 from . import _version
 from . import _version
 from .bg import remove
 from .bg import remove
@@ -21,7 +24,7 @@ from .detect import ort_session
 
 
 @click.group()
 @click.group()
 @click.version_option()
 @click.version_option()
-def main():
+def main() -> None:
     pass
     pass
 
 
 
 
@@ -81,7 +84,7 @@ def main():
     default=(None if sys.stdin.isatty() else "-"),
     default=(None if sys.stdin.isatty() else "-"),
     type=click.File("wb", lazy=True),
     type=click.File("wb", lazy=True),
 )
 )
-def i(model: str, input: IO, output: IO, **kwargs):
+def i(model: str, input: IO, output: IO, **kwargs) -> None:
     output.write(remove(input.read(), session=ort_session(model), **kwargs))
     output.write(remove(input.read(), session=ort_session(model), **kwargs))
 
 
 
 
@@ -133,6 +136,14 @@ def i(model: str, input: IO, output: IO, **kwargs):
     show_default=True,
     show_default=True,
     help="output only the mask",
     help="output only the mask",
 )
 )
[email protected](
+    "-w",
+    "--watch",
+    default=False,
+    is_flag=True,
+    show_default=True,
+    help="watches a folder for changes",
+)
 @click.argument(
 @click.argument(
     "input",
     "input",
     type=click.Path(
     type=click.Path(
@@ -153,24 +164,63 @@ def i(model: str, input: IO, output: IO, **kwargs):
         writable=True,
         writable=True,
     ),
     ),
 )
 )
-def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs):
+def p(
+    model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
+) -> None:
     session = ort_session(model)
     session = ort_session(model)
-    for each_input in tqdm(list(input.glob("**/*"))):
-        if each_input.is_dir():
-            continue
 
 
-        mimetype = filetype.guess(each_input)
-        if mimetype is None:
-            continue
-        if mimetype.mime.find("image") < 0:
-            continue
+    def process(each_input: pathlib.Path) -> None:
+        try:
+            mimetype = filetype.guess(each_input)
+            if mimetype is None:
+                return
+            if mimetype.mime.find("image") < 0:
+                return
 
 
-        each_output = (output / each_input.name).with_suffix(".png")
-        each_output.parents[0].mkdir(parents=True, exist_ok=True)
+            each_output = (output / each_input.name).with_suffix(".png")
+            each_output.parents[0].mkdir(parents=True, exist_ok=True)
 
 
-        each_output.write_bytes(
-            remove(each_input.read_bytes(), session=session, **kwargs)
-        )
+            if not each_output.exists():
+                each_output.write_bytes(
+                    remove(each_input.read_bytes(), session=session, **kwargs)
+                )
+
+                if watch:
+                    print(
+                        f"processed: {each_input.absolute()} -> {each_output.absolute()}"
+                    )
+        except e:
+            print(e)
+
+    inputs = list(input.glob("**/*"))
+    if not watch:
+        inputs = tqdm(inputs)
+
+    for each_input in inputs:
+        if not each_input.is_dir():
+            process(each_input)
+
+    if watch:
+        observer = Observer()
+
+        class EventHandler(FileSystemEventHandler):
+            def on_any_event(self, event: FileSystemEvent) -> None:
+                if not (
+                    event.is_directory or event.event_type in ["deleted", "closed"]
+                ):
+                    process(pathlib.Path(event.src_path))
+
+        event_handler = EventHandler()
+        observer.schedule(event_handler, input, recursive=False)
+        observer.start()
+
+        try:
+            while True:
+                time.sleep(1)
+
+        finally:
+            observer.stop()
+            observer.join()
 
 
 
 
 @main.command(help="for a http server")
 @main.command(help="for a http server")
@@ -190,7 +240,7 @@ def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs):
     show_default=True,
     show_default=True,
     help="log level",
     help="log level",
 )
 )
-def s(port: int, log_level: str):
+def s(port: int, log_level: str) -> None:
     sessions: dict[str, ort.InferenceSession] = {}
     sessions: dict[str, ort.InferenceSession] = {}
     tags_metadata = [
     tags_metadata = [
         {
         {

+ 1 - 0
requirements.txt

@@ -13,3 +13,4 @@ scikit-image==0.19.1
 scipy==1.7.3
 scipy==1.7.3
 tqdm==4.62.3
 tqdm==4.62.3
 uvicorn==0.17.0
 uvicorn==0.17.0
+watchdog==2.1.7