|
@@ -0,0 +1,161 @@
|
|
|
+import asyncio
|
|
|
+import io
|
|
|
+import json
|
|
|
+import os
|
|
|
+import sys
|
|
|
+from typing import IO
|
|
|
+
|
|
|
+import click
|
|
|
+from PIL import Image
|
|
|
+
|
|
|
+from ..bg import remove
|
|
|
+from ..session_factory import new_session
|
|
|
+from ..sessions import sessions_names
|
|
|
+
|
|
|
+
|
|
|
[email protected](
|
|
|
+ name="b",
|
|
|
+ help="for a byte stream as input",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-m",
|
|
|
+ "--model",
|
|
|
+ default="u2net",
|
|
|
+ type=click.Choice(sessions_names),
|
|
|
+ show_default=True,
|
|
|
+ show_choices=True,
|
|
|
+ help="model name",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-a",
|
|
|
+ "--alpha-matting",
|
|
|
+ is_flag=True,
|
|
|
+ show_default=True,
|
|
|
+ help="use alpha matting",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-af",
|
|
|
+ "--alpha-matting-foreground-threshold",
|
|
|
+ default=240,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="trimap fg threshold",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-ab",
|
|
|
+ "--alpha-matting-background-threshold",
|
|
|
+ default=10,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="trimap bg threshold",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-ae",
|
|
|
+ "--alpha-matting-erode-size",
|
|
|
+ default=10,
|
|
|
+ type=int,
|
|
|
+ show_default=True,
|
|
|
+ help="erode size",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-om",
|
|
|
+ "--only-mask",
|
|
|
+ is_flag=True,
|
|
|
+ show_default=True,
|
|
|
+ help="output only the mask",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-ppm",
|
|
|
+ "--post-process-mask",
|
|
|
+ is_flag=True,
|
|
|
+ show_default=True,
|
|
|
+ help="post process the mask",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "-bgc",
|
|
|
+ "--bgcolor",
|
|
|
+ default=None,
|
|
|
+ type=(int, int, int, int),
|
|
|
+ nargs=4,
|
|
|
+ help="Background color (R G B A) to replace the removed background with",
|
|
|
+)
|
|
|
[email protected]("-x", "--extras", type=str)
|
|
|
[email protected](
|
|
|
+ "-o",
|
|
|
+ "--output_specifier",
|
|
|
+ type=str,
|
|
|
+ help="printf-style specifier for output filenames (e.g. 'output-%d.png'))",
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "image_width",
|
|
|
+ type=int,
|
|
|
+)
|
|
|
[email protected](
|
|
|
+ "image_height",
|
|
|
+ type=int,
|
|
|
+)
|
|
|
+def rs_command(
|
|
|
+ model: str,
|
|
|
+ extras: str,
|
|
|
+ image_width: int,
|
|
|
+ image_height: int,
|
|
|
+ output_specifier: str,
|
|
|
+ **kwargs
|
|
|
+) -> None:
|
|
|
+ try:
|
|
|
+ kwargs.update(json.loads(extras))
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ session = new_session(model)
|
|
|
+ bytes_per_img = image_width * image_height * 3
|
|
|
+
|
|
|
+ if output_specifier:
|
|
|
+ output_dir = os.path.dirname(
|
|
|
+ os.path.abspath(os.path.expanduser(output_specifier))
|
|
|
+ )
|
|
|
+
|
|
|
+ if not os.path.isdir(output_dir):
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
+
|
|
|
+ def img_to_byte_array(img: Image) -> bytes:
|
|
|
+ buff = io.BytesIO()
|
|
|
+ img.save(buff, format="PNG")
|
|
|
+ return buff.getvalue()
|
|
|
+
|
|
|
+ async def connect_stdin_stdout():
|
|
|
+ loop = asyncio.get_event_loop()
|
|
|
+ reader = asyncio.StreamReader()
|
|
|
+ protocol = asyncio.StreamReaderProtocol(reader)
|
|
|
+
|
|
|
+ await loop.connect_read_pipe(lambda: protocol, sys.stdin)
|
|
|
+ w_transport, w_protocol = await loop.connect_write_pipe(
|
|
|
+ asyncio.streams.FlowControlMixin, sys.stdout
|
|
|
+ )
|
|
|
+
|
|
|
+ writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
|
|
|
+ return reader, writer
|
|
|
+
|
|
|
+ async def main():
|
|
|
+ reader, writer = await connect_stdin_stdout()
|
|
|
+
|
|
|
+ idx = 0
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ img_bytes = await reader.readexactly(bytes_per_img)
|
|
|
+ if not img_bytes:
|
|
|
+ break
|
|
|
+
|
|
|
+ img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
|
|
|
+ output = remove(img, session=session, **kwargs)
|
|
|
+
|
|
|
+ if output_specifier:
|
|
|
+ output.save((output_specifier % idx), format="PNG")
|
|
|
+ else:
|
|
|
+ writer.write(img_to_byte_array(output))
|
|
|
+
|
|
|
+ idx += 1
|
|
|
+ except asyncio.IncompleteReadError:
|
|
|
+ break
|
|
|
+
|
|
|
+ asyncio.run(main())
|