import json import os import webbrowser from typing import Optional, Tuple, cast import aiohttp import click import gradio as gr import uvicorn from asyncer import asyncify from fastapi import Depends, FastAPI, File, Form, Query from fastapi.middleware.cors import CORSMiddleware from starlette.responses import Response from .._version import get_versions from ..bg import remove from ..session_factory import new_session from ..sessions import sessions_names from ..sessions.base import BaseSession @click.command( name="s", help="for a http server", ) @click.option( "-p", "--port", default=5000, type=int, show_default=True, help="port", ) @click.option( "-l", "--log_level", default="info", type=str, show_default=True, help="log level", ) @click.option( "-t", "--threads", default=None, type=int, show_default=True, help="number of worker threads", ) def s_command(port: int, log_level: str, threads: int) -> None: sessions: dict[str, BaseSession] = {} tags_metadata = [ { "name": "Background Removal", "description": "Endpoints that perform background removal with different image sources.", "externalDocs": { "description": "GitHub Source", "url": "https://github.com/danielgatis/rembg", }, }, ] app = FastAPI( title="Rembg", description="Rembg is a tool to remove images background. That is it.", version=get_versions()["version"], contact={ "name": "Daniel Gatis", "url": "https://github.com/danielgatis", "email": "danielgatis@gmail.com", }, license_info={ "name": "MIT License", "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt", }, openapi_tags=tags_metadata, docs_url="/api", ) app.add_middleware( CORSMiddleware, allow_credentials=True, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) class CommonQueryParams: def __init__( self, model: str = Query( description="Model to use when processing image", regex=r"(" + "|".join(sessions_names) + ")", default="u2net", ), a: bool = Query(default=False, description="Enable Alpha Matting"), af: int = Query( default=240, ge=0, le=255, description="Alpha Matting (Foreground Threshold)", ), ab: int = Query( default=10, ge=0, le=255, description="Alpha Matting (Background Threshold)", ), ae: int = Query( default=10, ge=0, description="Alpha Matting (Erode Structure Size)" ), om: bool = Query(default=False, description="Only Mask"), ppm: bool = Query(default=False, description="Post Process Mask"), bgc: Optional[str] = Query(default=None, description="Background Color"), extras: Optional[str] = Query( default=None, description="Extra parameters as JSON" ), ): self.model = model self.a = a self.af = af self.ab = ab self.ae = ae self.om = om self.ppm = ppm self.extras = extras self.bgc = ( cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(",")))) if bgc else None ) class CommonQueryPostParams: def __init__( self, model: str = Form( description="Model to use when processing image", regex=r"(" + "|".join(sessions_names) + ")", default="u2net", ), a: bool = Form(default=False, description="Enable Alpha Matting"), af: int = Form( default=240, ge=0, le=255, description="Alpha Matting (Foreground Threshold)", ), ab: int = Form( default=10, ge=0, le=255, description="Alpha Matting (Background Threshold)", ), ae: int = Form( default=10, ge=0, description="Alpha Matting (Erode Structure Size)" ), om: bool = Form(default=False, description="Only Mask"), ppm: bool = Form(default=False, description="Post Process Mask"), bgc: Optional[str] = Query(default=None, description="Background Color"), extras: Optional[str] = Query( default=None, description="Extra parameters as JSON" ), ): self.model = model self.a = a self.af = af self.ab = ab self.ae = ae self.om = om self.ppm = ppm self.extras = extras self.bgc = ( cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(",")))) if bgc else None ) def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response: kwargs = {} if commons.extras: try: kwargs.update(json.loads(commons.extras)) except Exception: pass return Response( remove( content, session=sessions.setdefault(commons.model, new_session(commons.model)), alpha_matting=commons.a, alpha_matting_foreground_threshold=commons.af, alpha_matting_background_threshold=commons.ab, alpha_matting_erode_size=commons.ae, only_mask=commons.om, post_process_mask=commons.ppm, bgcolor=commons.bgc, **kwargs, ), media_type="image/png", ) @app.on_event("startup") def startup(): try: webbrowser.open(f"http://localhost:{port}") except Exception: pass if threads is not None: from anyio import CapacityLimiter from anyio.lowlevel import RunVar RunVar("_default_thread_limiter").set(CapacityLimiter(threads)) @app.get( path="/api/remove", tags=["Background Removal"], summary="Remove from URL", description="Removes the background from an image obtained by retrieving an URL.", ) async def get_index( url: str = Query( default=..., description="URL of the image that has to be processed." ), commons: CommonQueryParams = Depends(), ): async with aiohttp.ClientSession() as session: async with session.get(url) as response: file = await response.read() return await asyncify(im_without_bg)(file, commons) @app.post( path="/api/remove", tags=["Background Removal"], summary="Remove from Stream", description="Removes the background from an image sent within the request itself.", ) async def post_index( file: bytes = File( default=..., description="Image file (byte stream) that has to be processed.", ), commons: CommonQueryPostParams = Depends(), ): return await asyncify(im_without_bg)(file, commons) # type: ignore def gr_app(app): def inference(input_path, model): output_path = "output.png" with open(input_path, "rb") as i: with open(output_path, "wb") as o: input = i.read() output = remove(input, session=new_session(model)) o.write(output) return os.path.join(output_path) interface = gr.Interface( inference, [ gr.components.Image(type="filepath", label="Input"), gr.components.Dropdown( [ "u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta", "isnet-general-use", "isnet-anime", ], value="u2net", label="Models", ), ], gr.components.Image(type="filepath", label="Output"), ) interface.queue(concurrency_count=3) app = gr.mount_gradio_app(app, interface, path="/") return app print(f"To access the API documentation, go to http://localhost:{port}/api") print(f"To access the UI, go to http://localhost:{port}") uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)