Spaces:
Sleeping
Sleeping
import io | |
from enum import Enum | |
from typing import Any, List, Optional, Tuple, Union | |
import numpy as np | |
from cv2 import ( | |
BORDER_DEFAULT, | |
MORPH_ELLIPSE, | |
MORPH_OPEN, | |
GaussianBlur, | |
getStructuringElement, | |
morphologyEx, | |
) | |
from PIL import Image, ImageOps | |
from PIL.Image import Image as PILImage | |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf | |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml | |
from pymatting.util.util import stack_images | |
from scipy.ndimage import binary_erosion | |
from .session_factory import new_session | |
from .sessions import sessions_class | |
from .sessions.base import BaseSession | |
# Module-level 3x3 elliptical structuring element. Nothing in this view uses
# it; presumably it is the kernel for morphologyEx(..., MORPH_OPEN, kernel)
# in the mask post-processing helper defined elsewhere — verify.
kernel = getStructuringElement(
    MORPH_ELLIPSE,  # elliptical (disk-like) footprint
    (3, 3),  # 3x3 pixels
)
class ReturnType(Enum):
    """Output format chosen by ``remove``; it mirrors the type of the input ``data``."""

    BYTES = 0  # encoded image bytes in -> PNG-encoded bytes out
    PILLOW = 1  # PIL image in -> PIL image out
    NDARRAY = 2  # numpy array in -> numpy array out
# Default values for the user-configurable options. The original author notes
# these are intended to be adjustable from the Gradio interface.
# NOTE: `remove` uses these names as default-argument values, and Python binds
# those once, when `remove` is defined. Reassigning a name here afterwards does
# not change `remove`'s defaults.
foreground_threshold: int = 240  # alpha matting: mask values above this are definite foreground
background_threshold: int = 10  # alpha matting: mask values below this are definite background
erode_structure_size: int = 10  # alpha matting: side of the square erosion element (<= 0: scipy default)
alpha_matting: bool = False  # refine cutouts with alpha matting instead of a naive cutout
only_mask: bool = False  # return the predicted mask instead of the cutout
post_process_mask: bool = False  # run mask post-processing before cutting out
bgcolor: Optional[Tuple[int, int, int, int]] = None  # optional background fill colour
def alpha_matting_cutout(
    img: PILImage,
    mask: PILImage,
    *,
    fg_threshold: Optional[int] = None,
    bg_threshold: Optional[int] = None,
    erode_size: Optional[int] = None,
) -> PILImage:
    """Cut *img* out with closed-form alpha matting guided by *mask*.

    A trimap is built from the mask:
    - pixels above ``fg_threshold`` are definite foreground (255);
    - pixels below ``bg_threshold`` are definite background (0);
    - everything else is unknown (128).

    Both definite regions are eroded first, which widens the unknown band the
    matting solver works on. pymatting then estimates the alpha channel and
    foreground colours, and the two are stacked into one image.

    Args:
        img: Source image. Any non-RGB mode is converted to RGB first.
        mask: Mask with the same height/width as *img* (assumed single-channel
            8-bit, 0-255 — TODO confirm for every session type).
        fg_threshold: Mask values strictly greater than this are foreground.
            ``None`` uses the module-level ``foreground_threshold`` at call time.
        bg_threshold: Mask values strictly less than this are background.
            ``None`` uses the module-level ``background_threshold`` at call time.
        erode_size: Side length of the square erosion element. Values ``<= 0``
            fall back to scipy's default structuring element. ``None`` uses the
            module-level ``erode_structure_size`` at call time.

    Returns:
        A PIL image built from pymatting's stacked foreground + alpha
        (presumably RGBA).

    Raises:
        ValueError: Any ValueError raised while matting propagates. ``remove``
            catches it and falls back to ``naive_cutout``.
    """
    # Resolve lazily so runtime changes to the module globals are still honoured.
    if fg_threshold is None:
        fg_threshold = foreground_threshold
    if bg_threshold is None:
        bg_threshold = background_threshold
    if erode_size is None:
        erode_size = erode_structure_size

    # The matting solver works on 3-channel colour images. Convert every other
    # mode (RGBA, CMYK, L, P, LA, ...) rather than only RGBA/CMYK.
    if img.mode != "RGB":
        img = img.convert("RGB")

    img_arr = np.asarray(img)
    mask_arr = np.asarray(mask)

    is_foreground = mask_arr > fg_threshold
    is_background = mask_arr < bg_threshold

    structure = None
    if erode_size > 0:
        structure = np.ones((erode_size, erode_size), dtype=np.uint8)

    is_foreground = binary_erosion(is_foreground, structure=structure)
    # border_value=1 for the background, presumably so background touching the
    # image edge is not eroded in from the border.
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    trimap = np.full(mask_arr.shape, fill_value=128, dtype=np.uint8)
    trimap[is_foreground] = 255
    trimap[is_background] = 0  # background wins if the two regions overlap

    img_normalized = img_arr / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(cutout)


def remove(
    data: Union[bytes, PILImage, np.ndarray],
    alpha_matting: bool = alpha_matting,
    alpha_matting_foreground_threshold: int = foreground_threshold,
    alpha_matting_background_threshold: int = background_threshold,
    alpha_matting_erode_size: int = erode_structure_size,
    only_mask: bool = only_mask,
    post_process_mask: bool = post_process_mask,
    bgcolor: Optional[Tuple[int, int, int, int]] = bgcolor,
    *args: Optional[Any],
    session: Optional[BaseSession] = None,
    **kwargs: Optional[Any],
) -> Union[bytes, PILImage, np.ndarray]:
    """Remove the background from an image.

    The output type mirrors the input type:
    - PIL image in -> PIL image out;
    - ndarray in -> ndarray out;
    - bytes in -> PNG-encoded bytes out.

    Args:
        data: A PIL image, encoded image bytes, or a numpy array accepted by
            ``Image.fromarray``.
        alpha_matting: Refine each cutout with ``alpha_matting_cutout``. On
            ValueError it falls back to ``naive_cutout``.
        alpha_matting_foreground_threshold: Foreground threshold forwarded to
            ``alpha_matting_cutout``.
        alpha_matting_background_threshold: Background threshold forwarded to
            ``alpha_matting_cutout``.
        alpha_matting_erode_size: Erosion element size forwarded to
            ``alpha_matting_cutout``.
        only_mask: Return the (optionally post-processed) mask(s) instead of
            cutouts. This also skips ``bgcolor``.
        post_process_mask: Pass each mask through ``post_process`` first.
        bgcolor: If given, and ``only_mask`` is false, applied to the result
            with ``apply_background_color``.
        *args: Forwarded to ``new_session`` (only when no session is given)
            and to ``session.predict``.
        session: Model session to use. Keyword-only, so existing positional
            callers are unaffected. When omitted, a new ``"u2net"`` session is
            created on every call.
        **kwargs: Forwarded the same way as ``*args``.

    Returns:
        The cutouts combined by ``get_concat_v_multi``, or the
        orientation-fixed input image if the session yields no masks.

    Raises:
        ValueError: If ``data`` is not bytes, a PIL image, or a numpy array.
    """
    if isinstance(data, PILImage):
        return_type = ReturnType.PILLOW
        img = data
    elif isinstance(data, bytes):
        return_type = ReturnType.BYTES
        img = Image.open(io.BytesIO(data))
    elif isinstance(data, np.ndarray):
        return_type = ReturnType.NDARRAY
        img = Image.fromarray(data)
    else:
        raise ValueError("Input type {} is not supported.".format(type(data)))

    # Fix image orientation
    img = fix_image_orientation(img)

    if session is None:
        # Presumably loads a model each time; pass `session=` to reuse one.
        session = new_session("u2net", *args, **kwargs)

    masks = session.predict(img, *args, **kwargs)

    cutouts = []
    for mask in masks:
        if post_process_mask:
            mask = Image.fromarray(post_process(np.array(mask)))

        if only_mask:
            cutout = mask
        elif alpha_matting:
            try:
                cutout = alpha_matting_cutout(
                    img,
                    mask,
                    fg_threshold=alpha_matting_foreground_threshold,
                    bg_threshold=alpha_matting_background_threshold,
                    erode_size=alpha_matting_erode_size,
                )
            except ValueError:
                cutout = naive_cutout(img, mask)
        else:
            cutout = naive_cutout(img, mask)

        cutouts.append(cutout)

    cutout = img
    if cutouts:
        cutout = get_concat_v_multi(cutouts)

    if bgcolor is not None and not only_mask:
        cutout = apply_background_color(cutout, bgcolor)

    if return_type == ReturnType.PILLOW:
        return cutout
    if return_type == ReturnType.NDARRAY:
        return np.asarray(cutout)

    bio = io.BytesIO()
    cutout.save(bio, "PNG")
    return bio.getvalue()