瀏覽代碼

add bynary stream support based on #421

Daniel Gatis 2 年之前
父節點
當前提交
c287d1c3d4
共有 1 個文件被更改,包括 161 次插入0 次删除
  1. 161 0
      rembg/commands/b_command.py

+ 161 - 0
rembg/commands/b_command.py

@@ -0,0 +1,161 @@
+import asyncio
+import io
+import json
+import os
+import sys
+from typing import IO
+
+import click
+from PIL import Image
+
+from ..bg import remove
+from ..session_factory import new_session
+from ..sessions import sessions_names
+
+
[email protected](
+    name="b",
+    help="for a byte stream as input",
+)
[email protected](
+    "-m",
+    "--model",
+    default="u2net",
+    type=click.Choice(sessions_names),
+    show_default=True,
+    show_choices=True,
+    help="model name",
+)
[email protected](
+    "-a",
+    "--alpha-matting",
+    is_flag=True,
+    show_default=True,
+    help="use alpha matting",
+)
[email protected](
+    "-af",
+    "--alpha-matting-foreground-threshold",
+    default=240,
+    type=int,
+    show_default=True,
+    help="trimap fg threshold",
+)
[email protected](
+    "-ab",
+    "--alpha-matting-background-threshold",
+    default=10,
+    type=int,
+    show_default=True,
+    help="trimap bg threshold",
+)
[email protected](
+    "-ae",
+    "--alpha-matting-erode-size",
+    default=10,
+    type=int,
+    show_default=True,
+    help="erode size",
+)
[email protected](
+    "-om",
+    "--only-mask",
+    is_flag=True,
+    show_default=True,
+    help="output only the mask",
+)
[email protected](
+    "-ppm",
+    "--post-process-mask",
+    is_flag=True,
+    show_default=True,
+    help="post process the mask",
+)
[email protected](
+    "-bgc",
+    "--bgcolor",
+    default=None,
+    type=(int, int, int, int),
+    nargs=4,
+    help="Background color (R G B A) to replace the removed background with",
+)
[email protected]("-x", "--extras", type=str)
[email protected](
+    "-o",
+    "--output_specifier",
+    type=str,
+    help="printf-style specifier for output filenames (e.g. 'output-%d.png'))",
+)
[email protected](
+    "image_width",
+    type=int,
+)
[email protected](
+    "image_height",
+    type=int,
+)
+def rs_command(
+    model: str,
+    extras: str,
+    image_width: int,
+    image_height: int,
+    output_specifier: str,
+    **kwargs
+) -> None:
+    try:
+        kwargs.update(json.loads(extras))
+    except Exception:
+        pass
+
+    session = new_session(model)
+    bytes_per_img = image_width * image_height * 3
+
+    if output_specifier:
+        output_dir = os.path.dirname(
+            os.path.abspath(os.path.expanduser(output_specifier))
+        )
+
+        if not os.path.isdir(output_dir):
+            os.makedirs(output_dir, exist_ok=True)
+
+    def img_to_byte_array(img: Image) -> bytes:
+        buff = io.BytesIO()
+        img.save(buff, format="PNG")
+        return buff.getvalue()
+
+    async def connect_stdin_stdout():
+        loop = asyncio.get_event_loop()
+        reader = asyncio.StreamReader()
+        protocol = asyncio.StreamReaderProtocol(reader)
+
+        await loop.connect_read_pipe(lambda: protocol, sys.stdin)
+        w_transport, w_protocol = await loop.connect_write_pipe(
+            asyncio.streams.FlowControlMixin, sys.stdout
+        )
+
+        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
+        return reader, writer
+
+    async def main():
+        reader, writer = await connect_stdin_stdout()
+
+        idx = 0
+        while True:
+            try:
+                img_bytes = await reader.readexactly(bytes_per_img)
+                if not img_bytes:
+                    break
+
+                img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
+                output = remove(img, session=session, **kwargs)
+
+                if output_specifier:
+                    output.save((output_specifier % idx), format="PNG")
+                else:
+                    writer.write(img_to_byte_array(output))
+
+                idx += 1
+            except asyncio.IncompleteReadError:
+                break
+
+    asyncio.run(main())