# RemBG_super / rembg / commands / s_command.py
# Author: KenjieDec -- commit 5f57808 ("Update"), 9.12 kB.
import json
import os
import tempfile
import webbrowser
from typing import Optional, Tuple, cast

import aiohttp
import click
import gradio as gr
import uvicorn
from asyncer import asyncify
from fastapi import Depends, FastAPI, File, Form, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response

from .._version import get_versions
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
from ..sessions.base import BaseSession
@click.command(
    name="s",
    help="for a http server",
)
@click.option(
    "-p",
    "--port",
    default=5000,
    type=int,
    show_default=True,
    help="port",
)
@click.option(
    "-l",
    "--log_level",
    default="info",
    type=str,
    show_default=True,
    help="log level",
)
@click.option(
    "-t",
    "--threads",
    default=None,
    type=int,
    show_default=True,
    help="number of worker threads",
)
def s_command(port: int, log_level: str, threads: Optional[int]) -> None:
    """Serve the rembg REST API under ``/api`` and a Gradio UI at ``/``.

    Args:
        port: TCP port uvicorn listens on (bound to ``0.0.0.0``).
        log_level: uvicorn log level name.
        threads: if given, capacity of anyio's default thread limiter, which
            bounds the worker threads used by ``asyncify``; ``None`` keeps
            anyio's default.
    """
    # Model name -> loaded session. The REST endpoints and the Gradio UI share
    # it, so each model is loaded once per process and then reused.
    sessions: dict[str, BaseSession] = {}

    # Anchored so that only exact model names pass validation. Without the
    # anchors, any value containing (or starting with) a model name was
    # accepted, and each such distinct string got its own cached session.
    model_pattern = r"^(" + "|".join(sessions_names) + ")$"

    def get_session(model: str) -> BaseSession:
        """Return the cached session for *model*, creating it on first use.

        Unlike ``sessions.setdefault(model, new_session(model))``, this does
        not load a new model when one is already cached.
        """
        session = sessions.get(model)
        if session is None:
            # setdefault (not plain assignment) makes threads that raced to
            # load the same model all end up using the instance stored first.
            session = sessions.setdefault(model, new_session(model))
        return session

    def parse_bgc(bgc: Optional[str]) -> Optional[Tuple[int, int, int, int]]:
        """Parse a comma-separated integer color such as ``"255,255,255,255"``.

        Returns ``None`` for a missing or empty value.

        Raises:
            HTTPException: 422 when a component is not an integer, so the
                client gets a validation error rather than a server error.
        """
        if not bgc:
            return None
        try:
            return cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
        except ValueError:
            raise HTTPException(
                status_code=422,
                detail="bgc must be comma-separated integers, e.g. 255,255,255,255",
            ) from None

    tags_metadata = [
        {
            "name": "Background Removal",
            "description": "Endpoints that perform background removal with different image sources.",
            "externalDocs": {
                "description": "GitHub Source",
                "url": "https://github.com/danielgatis/rembg",
            },
        },
    ]

    app = FastAPI(
        title="Rembg",
        description="Rembg is a tool to remove images background. That is it.",
        version=get_versions()["version"],
        contact={
            "name": "Daniel Gatis",
            "url": "https://github.com/danielgatis",
            "email": "[email protected]",
        },
        license_info={
            "name": "MIT License",
            "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
        },
        openapi_tags=tags_metadata,
        docs_url="/api",
    )

    # NOTE(review): wildcard origins combined with allow_credentials=True lets
    # any site make credentialed requests; confirm this is intended.
    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    class CommonQueryParams:
        """Processing options for ``GET /api/remove``, read from the query string."""

        def __init__(
            self,
            model: str = Query(
                description="Model to use when processing image",
                regex=model_pattern,
                default="u2net",
            ),
            a: bool = Query(default=False, description="Enable Alpha Matting"),
            af: int = Query(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Query(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Query(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Query(default=False, description="Only Mask"),
            ppm: bool = Query(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = parse_bgc(bgc)

    class CommonQueryPostParams:
        """Processing options for ``POST /api/remove``.

        Most options are multipart form fields. ``bgc`` and ``extras`` are
        still read from the query string (kept as-is for compatibility with
        existing clients).
        """

        def __init__(
            self,
            model: str = Form(
                description="Model to use when processing image",
                regex=model_pattern,
                default="u2net",
            ),
            a: bool = Form(default=False, description="Enable Alpha Matting"),
            af: int = Form(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Form(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Form(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Form(default=False, description="Only Mask"),
            ppm: bool = Form(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = parse_bgc(bgc)

    def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
        """Remove the background from image bytes and wrap the result as a PNG response.

        This is synchronous; the endpoints call it through ``asyncify`` so the
        work runs in a worker thread.
        """
        kwargs = {}
        if commons.extras:
            # Best effort: malformed JSON, or JSON that is not an object, is
            # ignored instead of failing the request.
            try:
                extras = json.loads(commons.extras)
            except (ValueError, RecursionError):
                extras = None
            if isinstance(extras, dict):
                kwargs.update(extras)

        return Response(
            remove(
                content,
                session=get_session(commons.model),
                alpha_matting=commons.a,
                alpha_matting_foreground_threshold=commons.af,
                alpha_matting_background_threshold=commons.ab,
                alpha_matting_erode_size=commons.ae,
                only_mask=commons.om,
                post_process_mask=commons.ppm,
                bgcolor=commons.bgc,
                **kwargs,
            ),
            media_type="image/png",
        )

    @app.on_event("startup")
    def startup():
        """Open a browser on the UI (best effort); optionally resize the thread pool."""
        try:
            webbrowser.open(f"http://localhost:{port}")
        except Exception:
            # Failing to open a browser (e.g. headless host) must not stop the server.
            pass

        if threads is not None:
            from anyio import CapacityLimiter
            from anyio.lowlevel import RunVar

            # Replaces anyio's default thread limiter for this event loop,
            # bounding how many asyncify calls run concurrently.
            RunVar("_default_thread_limiter").set(CapacityLimiter(threads))

    @app.get(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from URL",
        description="Removes the background from an image obtained by retrieving an URL.",
    )
    async def get_index(
        url: str = Query(
            default=..., description="URL of the image that has to be processed."
        ),
        commons: CommonQueryParams = Depends(),
    ):
        # NOTE(review): fetches an arbitrary client-supplied URL (SSRF risk if
        # the server can reach internal hosts); the HTTP status is not checked.
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                file = await response.read()
                return await asyncify(im_without_bg)(file, commons)

    @app.post(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from Stream",
        description="Removes the background from an image sent within the request itself.",
    )
    async def post_index(
        file: bytes = File(
            default=...,
            description="Image file (byte stream) that has to be processed.",
        ),
        commons: CommonQueryPostParams = Depends(),
    ):
        return await asyncify(im_without_bg)(file, commons)  # type: ignore

    def gr_app(app):
        """Mount a Gradio UI at ``/`` on *app* and return the combined ASGI app."""

        def inference(input_path, model):
            """Remove the background from the image file at *input_path*.

            Returns the path of a newly created PNG. The file is not deleted
            here; it is left for Gradio to serve.
            """
            with open(input_path, "rb") as src:
                data = src.read()
            output = remove(data, session=get_session(model))
            # A unique file per call: the queue runs up to 3 requests at once,
            # and a shared fixed "output.png" let them overwrite each other.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as dst:
                dst.write(output)
            return dst.name

        interface = gr.Interface(
            inference,
            [
                gr.components.Image(type="filepath", label="Input"),
                gr.components.Dropdown(
                    [
                        "u2net",
                        "u2netp",
                        "u2net_human_seg",
                        "u2net_cloth_seg",
                        "silueta",
                        "isnet-general-use",
                        "isnet-anime",
                    ],
                    value="u2net",
                    label="Models",
                ),
            ],
            gr.components.Image(type="filepath", label="Output"),
        )

        interface.queue(concurrency_count=3)
        app = gr.mount_gradio_app(app, interface, path="/")
        return app

    print(f"To access the API documentation, go to http://localhost:{port}/api")
    print(f"To access the UI, go to http://localhost:{port}")

    uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)