Spaces:
Sleeping
Sleeping
import json | |
import os | |
import webbrowser | |
from typing import Optional, Tuple, cast | |
import aiohttp | |
import click | |
import gradio as gr | |
import uvicorn | |
from asyncer import asyncify | |
from fastapi import Depends, FastAPI, File, Form, Query | |
from fastapi.middleware.cors import CORSMiddleware | |
from starlette.responses import Response | |
from .._version import get_versions | |
from ..bg import remove | |
from ..session_factory import new_session | |
from ..sessions import sessions_names | |
from ..sessions.base import BaseSession | |
# NOTE(review): this module imports ``click`` but no ``@click`` decorators are
# visible on this function; if it is meant to be a CLI command, confirm its
# decorators were not lost.
def s_command(port: int, log_level: str, threads: Optional[int]) -> None:
    """Start the rembg HTTP server: a FastAPI image API plus a Gradio UI.

    The OpenAPI documentation is served at ``/api`` and the Gradio UI is
    mounted at ``/``. The server listens on all interfaces (``0.0.0.0``) and
    this call blocks until uvicorn stops.

    Args:
        port: TCP port to listen on.
        log_level: uvicorn log level, passed through unchanged.
        threads: If not ``None``, the capacity installed as anyio's default
            thread limiter at startup (presumably bounding how many
            ``asyncify``-wrapped ``remove`` calls run at once).
    """
    # Kept local so this command's extra needs stay next to their only users.
    import tempfile
    from typing import Union

    from fastapi import HTTPException

    # Model sessions cached by name, shared by the API handlers and the UI.
    sessions: dict[str, BaseSession] = {}

    def get_session(model: str) -> BaseSession:
        """Return the cached session for *model*, creating it on first use."""
        # Look up before constructing: ``sessions.setdefault(m, new_session(m))``
        # evaluates ``new_session`` on every call, defeating the cache.
        session = sessions.get(model)
        if session is None:
            session = new_session(model)
            sessions[model] = session
        return session

    def parse_bgc(bgc: Optional[str]) -> Optional[Tuple[int, int, int, int]]:
        """Parse an ``"R,G,B,A"`` background-color string into a 4-tuple.

        Returns ``None`` when *bgc* is omitted or empty.

        Raises:
            HTTPException: 400 if *bgc* is not exactly four comma-separated
                integers.
        """
        if not bgc:
            return None
        try:
            color = tuple(int(part) for part in bgc.split(","))
        except ValueError:
            color = ()
        if len(color) != 4:
            raise HTTPException(
                status_code=400,
                detail="bgc must be four comma-separated integers, e.g. 255,255,255,255",
            )
        return cast(Tuple[int, int, int, int], color)

    def parse_extras(extras: Optional[str]) -> dict:
        """Decode the optional JSON object of extra keyword arguments for ``remove``.

        Deliberately best-effort: malformed JSON, or JSON that is not an
        object, is ignored instead of failing the request.
        """
        if not extras:
            return {}
        try:
            decoded = json.loads(extras)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return {}
        return decoded if isinstance(decoded, dict) else {}

    # Anchored so only exact model names validate. Unanchored, strings such as
    # "u2net-anything" also pass, and each distinct one would get its own
    # entry (and loaded model) in ``sessions``.
    model_pattern = r"^(" + "|".join(sessions_names) + r")$"

    tags_metadata = [
        {
            "name": "Background Removal",
            "description": "Endpoints that perform background removal with different image sources.",
            "externalDocs": {
                "description": "GitHub Source",
                "url": "https://github.com/danielgatis/rembg",
            },
        },
    ]

    app = FastAPI(
        title="Rembg",
        description="Rembg is a tool to remove images background. That is it.",
        version=get_versions()["version"],
        contact={
            "name": "Daniel Gatis",
            "url": "https://github.com/danielgatis",
            "email": "[email protected]",
        },
        license_info={
            "name": "MIT License",
            "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
        },
        openapi_tags=tags_metadata,
        docs_url="/api",
    )

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    class CommonQueryParams:
        """Processing options for the GET endpoint, read from the query string."""

        def __init__(
            self,
            model: str = Query(
                description="Model to use when processing image",
                regex=model_pattern,
                default="u2net",
            ),
            a: bool = Query(default=False, description="Enable Alpha Matting"),
            af: int = Query(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Query(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Query(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Query(default=False, description="Only Mask"),
            ppm: bool = Query(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = parse_bgc(bgc)

    class CommonQueryPostParams:
        """Processing options for the POST endpoint, mostly read from form fields."""

        def __init__(
            self,
            model: str = Form(
                description="Model to use when processing image",
                regex=model_pattern,
                default="u2net",
            ),
            a: bool = Form(default=False, description="Enable Alpha Matting"),
            af: int = Form(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Form(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Form(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Form(default=False, description="Only Mask"),
            ppm: bool = Form(default=False, description="Post Process Mask"),
            # NOTE(review): unlike the options above, bgc/extras come from the
            # query string even on POST. Kept as-is so existing clients keep
            # working; confirm whether Form was intended.
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = parse_bgc(bgc)

    def im_without_bg(
        content: bytes, commons: Union[CommonQueryParams, CommonQueryPostParams]
    ) -> Response:
        """Remove the background from the image bytes and return a PNG response.

        This is blocking model work; the endpoints wrap it with ``asyncify``
        so it runs in a worker thread.
        """
        return Response(
            remove(
                content,
                session=get_session(commons.model),
                alpha_matting=commons.a,
                alpha_matting_foreground_threshold=commons.af,
                alpha_matting_background_threshold=commons.ab,
                alpha_matting_erode_size=commons.ae,
                only_mask=commons.om,
                post_process_mask=commons.ppm,
                bgcolor=commons.bgc,
                **parse_extras(commons.extras),
            ),
            media_type="image/png",
        )

    # The handlers below were previously defined without being attached to
    # ``app``, so the API exposed no endpoints and ``startup`` never ran.
    # NOTE(review): ``/api/remove`` follows rembg's published API; confirm.
    @app.on_event("startup")
    def startup():
        """Open the UI in a browser (best-effort) and size the worker-thread pool."""
        try:
            webbrowser.open(f"http://localhost:{port}")
        except Exception:
            # Opening a browser is a convenience only; it must never stop startup.
            pass

        if threads is not None:
            from anyio import CapacityLimiter
            from anyio.lowlevel import RunVar

            # Replace anyio's default thread limiter with one of ``threads`` slots.
            RunVar("_default_thread_limiter").set(CapacityLimiter(threads))

    @app.get(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from URL",
        description="Removes the background from an image obtained by retrieving an URL.",
    )
    async def get_index(
        url: str = Query(
            default=..., description="URL of the image that has to be processed."
        ),
        commons: CommonQueryParams = Depends(),
    ):
        """Download the image at *url* and return it with the background removed.

        Raises:
            HTTPException: 400 if the image cannot be downloaded.
        """
        # NOTE(review): the server fetches any client-supplied URL, including
        # internal addresses (SSRF). Restrict or disable this endpoint before
        # exposing the server publicly.
        try:
            async with aiohttp.ClientSession() as http:
                async with http.get(url) as response:
                    if response.status >= 400:
                        raise HTTPException(
                            status_code=400,
                            detail=f"Could not fetch image: HTTP {response.status}",
                        )
                    file = await response.read()
        except aiohttp.ClientError as err:
            raise HTTPException(
                status_code=400, detail=f"Could not fetch image: {err}"
            ) from err
        return await asyncify(im_without_bg)(file, commons)

    @app.post(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from Stream",
        description="Removes the background from an image sent within the request itself.",
    )
    async def post_index(
        file: bytes = File(
            default=...,
            description="Image file (byte stream) that has to be processed.",
        ),
        commons: CommonQueryPostParams = Depends(),
    ):
        """Return the uploaded image with its background removed."""
        return await asyncify(im_without_bg)(file, commons)

    def gr_app(app):
        """Mount a Gradio background-removal UI at ``/`` on *app* and return it."""
        ui_models = [
            "u2net",
            "u2netp",
            "u2net_human_seg",
            "u2net_cloth_seg",
            "silueta",
            "isnet-general-use",
            "isnet-anime",
        ]

        def inference(input_path, model):
            """Remove the background from the image file at *input_path*.

            Returns the path of a newly written PNG holding the result.
            """
            with open(input_path, "rb") as src:
                data = src.read()
            # Reuse cached sessions for the dropdown's models. Any other value
            # submitted through the Gradio API gets an uncached session, so it
            # cannot grow the shared cache.
            session = get_session(model) if model in ui_models else new_session(model)
            result = remove(data, session=session)
            # A unique file per call: the queue runs up to three jobs at once,
            # and a shared fixed path let concurrent users receive each other's
            # results. The file is not deleted here because Gradio reads the
            # returned path after this function returns.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as dst:
                dst.write(result)
            return dst.name

        interface = gr.Interface(
            inference,
            [
                gr.components.Image(type="filepath", label="Input"),
                gr.components.Dropdown(
                    ui_models,
                    value="u2net",
                    label="Models",
                ),
            ],
            gr.components.Image(type="filepath", label="Output"),
        )
        interface.queue(concurrency_count=3)
        return gr.mount_gradio_app(app, interface, path="/")

    print(f"To access the API documentation, go to http://localhost:{port}/api")
    print(f"To access the UI, go to http://localhost:{port}")
    uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)