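"""Gradio demo: compare the original GTA-trained SegFormer (guimCC/segformer-v0-gta)
with its Cityscapes fine-tuned variant (guimCC/segformer-v0-gta-cityscapes) on a
handful of Cityscapes validation images, reporting per-image mean IoU."""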
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from torchvision.transforms import ColorJitter, functional as F
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
import evaluate

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the models
original_model_id = "guimCC/segformer-v0-gta"
lora_model_id = "guimCC/segformer-v0-gta-cityscapes"

original_model = SegformerForSemanticSegmentation.from_pretrained(original_model_id).to(device)
lora_model = SegformerForSemanticSegmentation.from_pretrained(lora_model_id).to(device)
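# Note: the fine-tuned checkpoint is loaded as a plain SegformerForSemanticSegmentation,
# which assumes any LoRA weights were already merged into the base model; a raw PEFT
# adapter checkpoint would instead need peft's PeftModel.from_pretrained.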

# Load the dataset and slice it
dataset = load_dataset("Chris1/cityscapes", split="validation")
sampled_dataset = [dataset[i] for i in range(10)]  # Select the first 10 examples

# Color jitter augmentation applied to images before inference
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

# Initialize mIoU metric
metric = evaluate.load("mean_iou")
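# evaluate.load fetches the mean_iou metric script from the Hub on first use,
# so the first run needs network access (or a pre-populated cache).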

# Cityscapes class id -> name mapping (train ids, plus an explicit 'ignore' class) and the SegFormer image processor
id2label = {
    0: 'road', 1: 'sidewalk', 2: 'building', 3: 'wall', 4: 'fence', 5: 'pole',
    6: 'traffic light', 7: 'traffic sign', 8: 'vegetation', 9: 'terrain',
    10: 'sky', 11: 'person', 12: 'rider', 13: 'car', 14: 'truck', 15: 'bus',
    16: 'train', 17: 'motorcycle', 18: 'bicycle', 19: 'ignore'
}
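# Class 19 ('ignore') is excluded from the metric via ignore_index=19 in compute_miou below.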
processor = SegformerImageProcessor()

# Cityscapes color palette: row i is the RGB color for class id i (last row is the 'ignore' class)
palette = np.array([
    [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
    [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
    [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
    [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32], [0, 0, 0]
])

def handle_grayscale_image(image):
    np_image = np.array(image)
    if np_image.ndim == 2:  # Grayscale image
        np_image = np.tile(np.expand_dims(np_image, -1), (1, 1, 3))
    return Image.fromarray(np_image)

def preprocess_image(image):
    image = handle_grayscale_image(image)
    image = jitter(image)  # Apply color jitter
    pixel_values = F.to_tensor(image).unsqueeze(0)  # Convert to tensor and add batch dimension
    return pixel_values.to(device)
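
# Note: preprocess_image feeds raw [0, 1] tensors to the models. If the checkpoints expect
# ImageNet-style normalization, the SegformerImageProcessor defined above could be used
# instead, e.g. processor(images=image, return_tensors="pt").pixel_values.to(device).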

def postprocess_predictions(logits):
    logits = logits.squeeze().detach().cpu().numpy()
    segmentation = np.argmax(logits, axis=0).astype(np.uint8)  # Convert to 8-bit integer
    return segmentation
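
# Note: SegFormer emits logits at 1/4 of the input resolution, so the returned mask is
# smaller than the original image; Gradio simply scales it up for display.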

def compute_miou(logits, labels):
    """Compute mean IoU between upsampled model predictions and the ground-truth mask."""
    with torch.no_grad():
        logits_tensor = torch.from_numpy(logits)
        # Upsample the logits to the label resolution before taking the per-pixel argmax
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.squeeze(0).cpu().numpy()

        # Crude fallback: if the reference mask shape still differs from the predictions,
        # resize it so the metric can be computed at all
        if pred_labels.shape != labels.shape:
            labels = np.resize(labels, pred_labels.shape)

        # mean_iou expects lists of per-image prediction/reference masks
        metrics = metric.compute(
            predictions=[pred_labels],
            references=[labels],
            num_labels=len(id2label),
            ignore_index=19,
            reduce_labels=processor.do_reduce_labels,
        )

        return metrics["mean_iou"]


def apply_color_palette(segmentation):
    colored_segmentation = palette[segmentation]
    return Image.fromarray(colored_segmentation.astype(np.uint8))

def create_legend():
    # Define font and its size
    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        font = ImageFont.load_default()

    # Calculate legend dimensions
    num_classes = len(id2label)
    legend_height = 20 * ((num_classes + 1) // 2)  # Two items per row
    legend_width = 250

    # Create a blank image for the legend
    legend = Image.new("RGB", (legend_width, legend_height), (255, 255, 255))
    draw = ImageDraw.Draw(legend)

    # Draw each color and its label
    for i, (class_id, class_name) in enumerate(id2label.items()):
        color = tuple(palette[class_id])
        x = (i % 2) * 120
        y = (i // 2) * 20
        draw.rectangle([x, y, x + 20, y + 20], fill=color)
        draw.text((x + 30, y + 5), class_name, fill=(0, 0, 0), font=font)

    return legend



def inference(index, _legend):
    """Run both models on the selected image; the legend input is display-only and ignored."""
    image = sampled_dataset[index]['image']  # Fetch image from the sampled dataset
    pixel_values = preprocess_image(image)

    # Original model inference
    with torch.no_grad():
        original_outputs = original_model(pixel_values=pixel_values)
        original_segmentation = postprocess_predictions(original_outputs.logits)

    # LoRA model inference
    with torch.no_grad():
        lora_outputs = lora_model(pixel_values=pixel_values)
        lora_segmentation = postprocess_predictions(lora_outputs.logits)

    # Compute mIoU
    true_labels = np.array(sampled_dataset[index]['semantic_segmentation'])
    original_miou = compute_miou(original_outputs.logits.detach().cpu().numpy(), true_labels)
    lora_miou = compute_miou(lora_outputs.logits.detach().cpu().numpy(), true_labels)

    # Apply color palette
    original_segmentation_image = apply_color_palette(original_segmentation)
    lora_segmentation_image = apply_color_palette(lora_segmentation)

    # Create legend
    legend = create_legend()

    # Return the original image, the segmentations, and mIoU
    return (
        image,
        original_segmentation_image,
        lora_segmentation_image,
        legend,
        f"Original Model mIoU: {original_miou:.2f}",
        f"LoRA Model mIoU: {lora_miou:.2f}"
    )

# Create a list of image options for the user to select from
image_options = [(f"Image {i}", i) for i in range(len(sampled_dataset))]
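# With (label, value) tuples, Gradio shows "Image i" in the dropdown and passes the
# integer index i to the inference function.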

# Create the Gradio interface
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Dropdown(label="Select Image", choices=image_options),
        gr.Image(type="pil", label="Legend", value=create_legend)
    ],
    outputs=[
        gr.Image(type="pil", label="Selected Image"),
        gr.Image(type="pil", label="Original Model Output"),
        gr.Image(type="pil", label="LoRA Model Output"),
        gr.Image(type="pil", label="Legend"),
        gr.Textbox(label="Original Model mIoU"),
        gr.Textbox(label="LoRA Model mIoU")
    ],
    title="Segformer Cityscapes Inference",
    description="Select an image from the Cityscapes dataset to see the segmentation results from both the original and fine-tuned Segformer models.",
)

# Launch the interface
iface.launch()
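
# Tip: pass share=True to iface.launch() to expose a temporary public link when running locally.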