import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from torchvision.transforms import ColorJitter, functional as F
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
import evaluate
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the models
original_model_id = "guimCC/segformer-v0-gta"
lora_model_id = "guimCC/segformer-v0-gta-cityscapes"
original_model = SegformerForSemanticSegmentation.from_pretrained(original_model_id).to(device)
lora_model = SegformerForSemanticSegmentation.from_pretrained(lora_model_id).to(device)
# Load the dataset and slice it
dataset = load_dataset("Chris1/cityscapes", split="validation")
sampled_dataset = [dataset[i] for i in range(10)] # Select the first 10 examples
# Color jitter transform applied to images before inference
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
# Initialize mIoU metric
metric = evaluate.load("mean_iou")
# Class id -> name mapping: the 19 Cityscapes classes plus an ignore class
id2label = {
0: 'road', 1: 'sidewalk', 2: 'building', 3: 'wall', 4: 'fence', 5: 'pole',
6: 'traffic light', 7: 'traffic sign', 8: 'vegetation', 9: 'terrain',
10: 'sky', 11: 'person', 12: 'rider', 13: 'car', 14: 'truck', 15: 'bus',
16: 'train', 17: 'motorcycle', 18: 'bicycle', 19: 'ignore'
}
processor = SegformerImageProcessor()
# Cityscapes color palette
palette = np.array([
[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
[153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
[70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32], [0, 0, 0]
])
def handle_grayscale_image(image):
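    """Convert a grayscale image to 3-channel RGB by tiling the single channel."""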
np_image = np.array(image)
if np_image.ndim == 2: # Grayscale image
np_image = np.tile(np.expand_dims(np_image, -1), (1, 1, 3))
return Image.fromarray(np_image)
def preprocess_image(image):
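    """Jitter and tensorize a PIL image, returning a batched tensor on the target device."""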
image = handle_grayscale_image(image)
image = jitter(image) # Apply color jitter
pixel_values = F.to_tensor(image).unsqueeze(0) # Convert to tensor and add batch dimension
return pixel_values.to(device)
def postprocess_predictions(logits):
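    """Argmax over the class dimension of the logits to get a per-pixel label map."""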
logits = logits.squeeze().detach().cpu().numpy()
segmentation = np.argmax(logits, axis=0).astype(np.uint8) # Convert to 8-bit integer
return segmentation
def compute_miou(logits, labels):
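    """Upsample logits to the label resolution and compute the mean IoU score."""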
with torch.no_grad():
logits_tensor = torch.from_numpy(logits)
# Scale the logits to the size of the label
logits_tensor = nn.functional.interpolate(
logits_tensor,
size=labels.shape[-2:],
mode="bilinear",
align_corners=False,
).argmax(dim=1)
pred_labels = logits_tensor.detach().cpu().numpy()
        # Add a batch dimension to the labels so they match pred_labels (1, H, W)
        if pred_labels.shape != labels.shape:
            labels = labels.reshape(pred_labels.shape)
        pred_labels = [pred_labels]  # metric.compute expects a sequence of predictions
        labels = [labels]  # and a matching sequence of references
metrics = metric.compute(
predictions=pred_labels,
references=labels,
num_labels=len(id2label),
ignore_index=19,
reduce_labels=processor.do_reduce_labels,
)
return metrics['mean_iou']
def apply_color_palette(segmentation):
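    """Map each class id in the label map to its Cityscapes RGB color."""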
colored_segmentation = palette[segmentation]
return Image.fromarray(colored_segmentation.astype(np.uint8))
def create_legend():
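    """Render a two-column legend image pairing each class color with its name."""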
# Define font and its size
try:
font = ImageFont.truetype("arial.ttf", 15)
except IOError:
font = ImageFont.load_default()
# Calculate legend dimensions
num_classes = len(id2label)
legend_height = 20 * ((num_classes + 1) // 2) # Two items per row
legend_width = 250
# Create a blank image for the legend
legend = Image.new("RGB", (legend_width, legend_height), (255, 255, 255))
draw = ImageDraw.Draw(legend)
# Draw each color and its label
for i, (class_id, class_name) in enumerate(id2label.items()):
        color = tuple(int(c) for c in palette[class_id])  # Cast from numpy ints for PIL
x = (i % 2) * 120
y = (i // 2) * 20
draw.rectangle([x, y, x + 20, y + 20], fill=color)
draw.text((x + 30, y + 5), class_name, fill=(0, 0, 0), font=font)
return legend
def inference(index, _legend):
    """Run inference on the selected image with both models and report mIoU.

    The second argument receives the legend image shown in the input panel;
    it is display-only and unused here.
    """
    image = sampled_dataset[index]['image']  # Fetch image from the sampled dataset
    pixel_values = preprocess_image(image)
# Original model inference
with torch.no_grad():
original_outputs = original_model(pixel_values=pixel_values)
original_segmentation = postprocess_predictions(original_outputs.logits)
# LoRA model inference
with torch.no_grad():
lora_outputs = lora_model(pixel_values=pixel_values)
lora_segmentation = postprocess_predictions(lora_outputs.logits)
    # Compute mIoU against the ground-truth segmentation
    true_labels = np.array(sampled_dataset[index]['semantic_segmentation'])
    original_miou = compute_miou(original_outputs.logits.detach().cpu().numpy(), true_labels)
    lora_miou = compute_miou(lora_outputs.logits.detach().cpu().numpy(), true_labels)
# Apply color palette
original_segmentation_image = apply_color_palette(original_segmentation)
lora_segmentation_image = apply_color_palette(lora_segmentation)
    # The legend is already displayed in the input panel, so return only the
    # original image, the two segmentations, and the mIoU scores, matching
    # the five output components declared below
    return (
        image,
        original_segmentation_image,
        lora_segmentation_image,
        f"Original Model mIoU: {original_miou:.2f}",
        f"LoRA Model mIoU: {lora_miou:.2f}"
    )
# Create a list of image options for the user to select from
image_options = [(f"Image {i}", i) for i in range(len(sampled_dataset))]
# Create the Gradio interface
iface = gr.Interface(
fn=inference,
inputs=[
gr.Dropdown(label="Select Image", choices=image_options),
gr.Image(type="pil", label="Legend", value=create_legend)
],
outputs=[
gr.Image(type="pil", label="Selected Image"),
gr.Image(type="pil", label="Original Model Output"),
gr.Image(type="pil", label="LoRA Model Output"),
gr.Textbox(label="Original Model mIoU"),
gr.Textbox(label="LoRA Model mIoU")
],
title="Segformer Cityscapes Inference",
description="Select an image from the Cityscapes dataset to see the segmentation results from both the original and fine-tuned Segformer models.",
)
# Launch the interface
iface.launch()