File size: 2,846 Bytes
5cb1539
3a6f1f2
 
5f89db3
3a6f1f2
782da61
3a6f1f2
 
 
3faa99b
3a6f1f2
 
 
3eecd07
3a6f1f2
 
 
3eecd07
 
 
 
 
 
5f89db3
 
3eecd07
 
 
782da61
3eecd07
782da61
3eecd07
 
3a6f1f2
 
3eecd07
 
 
 
 
 
 
 
 
3a6f1f2
ef928a1
 
df195bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782da61
 
 
 
df195bf
5ee83f7
782da61
ef928a1
3eecd07
5f89db3
 
3eecd07
3a6f1f2
 
 
782da61
3a6f1f2
3eecd07
 
 
5f89db3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import os
import cv2
import numpy as np

# rembg sessions keyed by model name. Creating a session loads the model
# weights, so reusing sessions avoids reloading the model on every request.
# In practice the cache is bounded by the model names offered in the UI.
_SESSIONS = {}


def _get_session(model):
    """Return a rembg session for *model*, creating and caching it on first use."""
    session = _SESSIONS.get(model)
    if session is None:
        from rembg import new_session

        session = new_session(model)
        _SESSIONS[model] = session
    return session


def _mask_png(output, only_mask):
    """Return PNG bytes of the foreground mask for rembg's *output* PNG bytes."""
    if only_mask:
        # rembg already produced the mask itself.
        return output
    decoded = cv2.imdecode(np.frombuffer(output, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    if decoded is None or decoded.ndim != 3 or decoded.shape[2] != 4:
        # No alpha channel to extract: fall back to writing the output unchanged.
        return output
    # rembg's cut-out carries the foreground mask in its alpha channel, so
    # extracting it avoids running the model a second time with only_mask=True.
    ok, encoded = cv2.imencode(".png", decoded[:, :, 3])
    return encoded.tobytes() if ok else output


def inference(file, mask, model, alpha_influence, segmentation_strength):
    """Remove the background of an uploaded image with rembg.

    Args:
        file: Path of the uploaded image (the Gradio input uses type="filepath").
        mask: "Mask only" to make the first output the foreground mask itself;
            any other value produces the RGBA cut-out.
        model: rembg model name used to create (or reuse) the session.
        alpha_influence: UI slider value. NOTE(review): currently not applied.
            rembg's remove() has no ``alpha`` parameter, so the value never
            reached the model. TODO: map it to a real rembg option.
        segmentation_strength: UI slider value. NOTE(review): currently not
            applied. remove() takes ``bgcolor`` (not ``bg_color``), and the
            previous keyword was never used. TODO: map it to a real option.

    Returns:
        Tuple ``(output_path, mask_path)`` of the written PNG files: the
        processed image and its foreground mask.

    Raises:
        ValueError: if no file was given or it cannot be read as an image.
    """
    if not file:
        raise ValueError("No input image provided.")

    image = cv2.imread(file, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Could not read image: {file}")

    # Re-encode to PNG in memory. This normalizes any upload to a 3-channel
    # image (as the previous round-trip through input.png did) without
    # leaving a stray file in the working directory.
    ok, encoded = cv2.imencode(".png", image)
    if not ok:
        raise ValueError(f"Could not encode image: {file}")

    from rembg import remove

    only_mask = mask == "Mask only"
    output = remove(
        encoded.tobytes(),
        session=_get_session(model),
        only_mask=only_mask,
    )

    output_path = "output.png"
    mask_path = "mask.png"
    with open(output_path, "wb") as out_file:
        out_file.write(output)
    with open(mask_path, "wb") as mask_file:
        mask_file.write(_mask_png(output, only_mask))

    return output_path, mask_path

# Text rendered by the Gradio interface: page title, intro blurb, footer link.
title = "RemBG"
description = (
    "Gradio demo for RemBG. To use it, simply upload your image "
    "and adjust the alpha influence and segmentation strength."
)
article = (
    "<p style='text-align: center;'>"
    "<a href='https://github.com/danielgatis/rembg' target='_blank'>Github Repo</a>"
    "</p>"
)

def show_processed_image(output_image_path):
    """Load the image at *output_image_path* with OpenCV and return it.

    Uses cv2.imread's default flags (3-channel, BGR order); the result is
    None when the file cannot be read.
    """
    return cv2.imread(output_image_path)

def show_processed_mask(mask_image_path):
    """Load the mask image at *mask_image_path* with OpenCV and return it.

    Uses cv2.imread's default flags (3-channel, BGR order); the result is
    None when the file cannot be read.
    """
    return cv2.imread(mask_image_path)

# Gradio UI. Components come from the top-level gr.* API: the gr.inputs /
# gr.outputs namespaces are deprecated and were removed in Gradio 4, and their
# ``default=`` keyword is ``value=`` on the top-level components.
iface = gr.Interface(
    inference,
    [
        gr.Image(type="filepath", label="Input"),
        gr.Radio(
            ["Default", "Mask only"],
            type="value",
            value="Default",
            label="Choices",
        ),
        gr.Dropdown(
            [
                "u2net",
                "u2netp",
                "u2net_human_seg",
                "u2net_cloth_seg",
                "silueta",
                "isnet-general-use",
                "isnet-anime",
                "sam",
            ],
            type="value",
            value="isnet-general-use",
            label="Models",
        ),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Alpha Influence"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Segmentation Strength"),
    ],
    [
        # inference() returns PNG file paths, so both outputs display files.
        # (The previous ``output=`` keyword and ``type="plot"`` are not valid
        # image-output options.)
        gr.Image(type="filepath", label="Processed Image"),
        gr.Image(type="filepath", label="Processed Mask"),
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["lion.png", "Default", "u2net", 0.5, 0.5],
        ["girl.jpg", "Default", "u2net", 0.5, 0.5],
        ["anime-girl.jpg", "Default", "isnet-anime", 0.5, 0.5],
    ],
)

# queue() replaces the removed ``enable_queue=True`` Interface argument.
iface.queue().launch()