import asyncio
import io
import json
import os
import sys
from typing import IO, Iterator, Optional

import click
from PIL import Image

from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names


def _merge_extras(kwargs: dict, extras: Optional[str]) -> None:
    """Merge the JSON given via ``--extras`` into *kwargs* in place.

    Extras are best-effort: an absent value is a no-op, and text that is not
    valid JSON (or whose top level cannot be merged into a dict) is reported
    on stderr and otherwise ignored instead of aborting the run.
    """
    if not extras:
        return
    try:
        kwargs.update(json.loads(extras))
    except (TypeError, ValueError) as err:
        # json.JSONDecodeError subclasses ValueError; dict.update raises
        # TypeError/ValueError for JSON that is not a mapping (numbers, lists).
        click.echo(f"Warning: ignoring invalid --extras value: {err}", err=True)


def _read_exactly(stream: IO[bytes], size: int) -> bytes:
    """Read *size* bytes from *stream*, returning fewer only at end of stream.

    Buffered ``read(n)`` is allowed to return short reads on some streams, so
    keep reading until *size* bytes are collected or a read returns nothing.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = stream.read(size - len(buf))
        if not chunk:
            break
        buf += chunk
    return bytes(buf)


def _iter_frames(stream: IO[bytes], frame_size: int) -> Iterator[bytes]:
    """Yield consecutive *frame_size*-byte frames from *stream* until EOF.

    A trailing chunk shorter than *frame_size* cannot form a whole frame; it
    is dropped with a warning on stderr.
    """
    while True:
        frame = _read_exactly(stream, frame_size)
        if not frame:
            return
        if len(frame) < frame_size:
            click.echo(
                f"Warning: discarding {len(frame)} trailing bytes "
                f"(incomplete frame of {frame_size} bytes)",
                err=True,
            )
            return
        yield frame


def _png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the encoded bytes."""
    buff = io.BytesIO()
    img.save(buff, format="PNG")
    return buff.getvalue()


@click.command(
    name="b",
    help="for a byte stream as input",
)
@click.option(
    "-m",
    "--model",
    default="u2net",
    type=click.Choice(sessions_names),
    show_default=True,
    show_choices=True,
    help="model name",
)
@click.option(
    "-a",
    "--alpha-matting",
    is_flag=True,
    show_default=True,
    help="use alpha matting",
)
@click.option(
    "-af",
    "--alpha-matting-foreground-threshold",
    default=240,
    type=int,
    show_default=True,
    help="trimap fg threshold",
)
@click.option(
    "-ab",
    "--alpha-matting-background-threshold",
    default=10,
    type=int,
    show_default=True,
    help="trimap bg threshold",
)
@click.option(
    "-ae",
    "--alpha-matting-erode-size",
    default=10,
    type=int,
    show_default=True,
    help="erode size",
)
@click.option(
    "-om",
    "--only-mask",
    is_flag=True,
    show_default=True,
    help="output only the mask",
)
@click.option(
    "-ppm",
    "--post-process-mask",
    is_flag=True,
    show_default=True,
    help="post process the mask",
)
@click.option(
    "-bgc",
    "--bgcolor",
    default=None,
    type=(int, int, int, int),
    nargs=4,
    help="Background color (R G B A) to replace the removed background with",
)
@click.option("-x", "--extras", type=str)
@click.option(
    "-o",
    "--output_specifier",
    type=str,
    help="printf-style specifier for output filenames (e.g. 'output-%d.png')",
)
@click.argument(
    "image_width",
    type=click.IntRange(min=1),
)
@click.argument(
    "image_height",
    type=click.IntRange(min=1),
)
def rs_command(
    model: str,
    extras: Optional[str],
    image_width: int,
    image_height: int,
    output_specifier: Optional[str],
    **kwargs,
) -> None:
    """Remove backgrounds from a stream of raw RGB frames read from stdin.

    Each frame is ``image_width * image_height * 3`` bytes, decoded as a PIL
    "RGB" image. The result of ``remove`` is written as PNG either to a file
    named ``output_specifier % frame_index`` or, when no specifier is given,
    to stdout, with PNG images concatenated back to back.

    Plain blocking binary I/O is used on purpose: frames are processed one at
    a time, and unlike asyncio pipe transports it also works when stdin or
    stdout is a regular file. Remaining ``**kwargs`` (the alpha-matting/mask/
    bgcolor options plus any ``--extras`` JSON) are forwarded to ``remove``.
    """
    _merge_extras(kwargs, extras)

    if output_specifier:
        # Expand "~" once so the directory we create and the paths we save to
        # refer to the same location.
        output_specifier = os.path.expanduser(output_specifier)
        # Fail fast, before loading the model, if the pattern cannot be
        # formatted with a single integer frame index.
        try:
            _ = output_specifier % 0
        except (TypeError, ValueError) as err:
            raise click.BadParameter(
                "expected a printf-style pattern with one integer placeholder, "
                f"e.g. 'output-%d.png' ({err})",
                param_hint="'-o' / '--output_specifier'",
            ) from err
        os.makedirs(
            os.path.dirname(os.path.abspath(output_specifier)), exist_ok=True
        )

    session = new_session(model)
    # 3 bytes per pixel, matching the "RGB" mode passed to Image.frombytes.
    frame_size = image_width * image_height * 3
    stdout = sys.stdout.buffer

    for idx, frame in enumerate(_iter_frames(sys.stdin.buffer, frame_size)):
        img = Image.frombytes("RGB", (image_width, image_height), frame)
        output = remove(img, session=session, **kwargs)

        if output_specifier:
            output.save(output_specifier % idx, format="PNG")
        else:
            stdout.write(_png_bytes(output))
            # Flush per frame so a downstream consumer receives each image as
            # soon as it is ready and nothing is left buffered at exit.
            stdout.flush()