guimcc committed
Commit fe7ff55
1 Parent(s): 0ff9fca
Files changed (1):
  1. app.py +188 -3
app.py CHANGED
@@ -1,4 +1,189 @@
- import streamlit as st
- x = st.slider('select a value')
- st.write(x, 'sq is', x * x)
+ import gradio as gr
+ from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
+ from torchvision.transforms import ColorJitter, functional as F
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from datasets import load_dataset
+ import evaluate

+ # Select the device: use a GPU when available, otherwise fall back to CPU
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the original (GTA-trained) model and the fine-tuned variant
+ original_model_id = "guimCC/segformer-v0-gta"
+ lora_model_id = "guimCC/segformer-v0-gta-cityscapes"
+
+ original_model = SegformerForSemanticSegmentation.from_pretrained(original_model_id).to(device)
+ lora_model = SegformerForSemanticSegmentation.from_pretrained(lora_model_id).to(device)
+
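+ # (Both checkpoints are loaded as full SegformerForSemanticSegmentation
+ # models; the "-cityscapes" id is presumably the LoRA fine-tune with the
+ # adapter weights merged into the base model.)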
+ # Load the Cityscapes validation split and take a small sample
+ dataset = load_dataset("Chris1/cityscapes", split="validation")
+ sampled_dataset = [dataset[i] for i in range(10)]  # Select the first 10 examples
+
+ # Color-jitter transform applied before inference
+ jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
+
+ # Initialize the mean-IoU metric
+ metric = evaluate.load("mean_iou")
+
+ # Class-id-to-label mapping and the image processor
+ id2label = {
+     0: 'road', 1: 'sidewalk', 2: 'building', 3: 'wall', 4: 'fence', 5: 'pole',
+     6: 'traffic light', 7: 'traffic sign', 8: 'vegetation', 9: 'terrain',
+     10: 'sky', 11: 'person', 12: 'rider', 13: 'car', 14: 'truck', 15: 'bus',
+     16: 'train', 17: 'motorcycle', 18: 'bicycle', 19: 'ignore'
+ }
+ processor = SegformerImageProcessor()
+
+ # Cityscapes color palette (one RGB triple per class id)
+ palette = np.array([
+     [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
+     [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
+     [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
+     [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32], [0, 0, 0]
+ ])
+
+ def handle_grayscale_image(image):
+     """Expand a single-channel image to three identical RGB channels."""
+     np_image = np.array(image)
+     if np_image.ndim == 2:  # Grayscale image
+         np_image = np.tile(np.expand_dims(np_image, -1), (1, 1, 3))
+     return Image.fromarray(np_image)
+
+ def preprocess_image(image):
+     image = handle_grayscale_image(image)
+     image = jitter(image)  # Apply color jitter (also applied at inference time)
+     pixel_values = F.to_tensor(image).unsqueeze(0)  # To tensor, add batch dimension
+     return pixel_values.to(device)
+
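+ # (Note: the SegformerImageProcessor defined above is not used here; raw
+ # [0, 1] tensors from F.to_tensor are fed directly. Calling
+ # processor(images=image, return_tensors="pt") would also apply the
+ # model's resizing and normalization.)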
+ def postprocess_predictions(logits):
+     logits = logits.squeeze().detach().cpu().numpy()
+     segmentation = np.argmax(logits, axis=0).astype(np.uint8)  # Per-pixel class ids
+     return segmentation
+
+ def compute_miou(logits, labels):
+     with torch.no_grad():
+         logits_tensor = torch.from_numpy(logits)
+         # Upsample the logits to the size of the label map
+         logits_tensor = nn.functional.interpolate(
+             logits_tensor,
+             size=labels.shape[-2:],
+             mode="bilinear",
+             align_corners=False,
+         ).argmax(dim=1)
+
+         pred_labels = logits_tensor.detach().cpu().numpy()
+
+         # Ensure the shapes of pred_labels and labels match
+         if pred_labels.shape != labels.shape:
+             labels = np.resize(labels, pred_labels.shape)
+
+         pred_labels = [pred_labels]  # Wrap in a list: the metric expects a batch
+         labels = [labels]
+
+         metrics = metric.compute(
+             predictions=pred_labels,
+             references=labels,
+             num_labels=len(id2label),
+             ignore_index=19,
+             reduce_labels=processor.do_reduce_labels,
+         )
+
+     return metrics['mean_iou']
+
+
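+ # Shape note (illustrative): SegFormer emits logits at 1/4 of the input
+ # resolution, e.g. (1, 20, H/4, W/4); the interpolation plus argmax above
+ # yields a (1, H, W) map of class ids for comparison with the ground truth.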
+ def apply_color_palette(segmentation):
+     colored_segmentation = palette[segmentation]
+     return Image.fromarray(colored_segmentation.astype(np.uint8))
+
+ def create_legend():
+     # Pick a font, falling back to PIL's default if Arial is unavailable
+     try:
+         font = ImageFont.truetype("arial.ttf", 15)
+     except IOError:
+         font = ImageFont.load_default()
+
+     # Calculate legend dimensions
+     num_classes = len(id2label)
+     legend_height = 20 * ((num_classes + 1) // 2)  # Two items per row
+     legend_width = 250
+
+     # Create a blank image for the legend
+     legend = Image.new("RGB", (legend_width, legend_height), (255, 255, 255))
+     draw = ImageDraw.Draw(legend)
+
+     # Draw each color swatch and its label
+     for i, (class_id, class_name) in enumerate(id2label.items()):
+         color = tuple(palette[class_id])
+         x = (i % 2) * 120
+         y = (i // 2) * 20
+         draw.rectangle([x, y, x + 20, y + 20], fill=color)
+         draw.text((x + 30, y + 5), class_name, fill=(0, 0, 0), font=font)
+
+     return legend
+
+
+ def inference(index, _legend):
+     """Run inference on the selected image with both models."""
+     image = sampled_dataset[index]['image']  # Fetch image from the sampled dataset
+     pixel_values = preprocess_image(image)
+
+     # Original model inference
+     with torch.no_grad():
+         original_outputs = original_model(pixel_values=pixel_values)
+     original_segmentation = postprocess_predictions(original_outputs.logits)
+
+     # LoRA model inference
+     with torch.no_grad():
+         lora_outputs = lora_model(pixel_values=pixel_values)
+     lora_segmentation = postprocess_predictions(lora_outputs.logits)
+
+     # Compute mIoU against the ground-truth segmentation
+     true_labels = np.array(sampled_dataset[index]['semantic_segmentation'])
+     original_miou = compute_miou(original_outputs.logits.detach().cpu().numpy(), true_labels)
+     lora_miou = compute_miou(lora_outputs.logits.detach().cpu().numpy(), true_labels)
+
+     # Apply the color palette
+     original_segmentation_image = apply_color_palette(original_segmentation)
+     lora_segmentation_image = apply_color_palette(lora_segmentation)
+
+     # Create the legend
+     legend = create_legend()
+
+     # Return the original image, the segmentations, the legend, and mIoU strings
+     return (
+         image,
+         original_segmentation_image,
+         lora_segmentation_image,
+         legend,
+         f"Original Model mIoU: {original_miou:.2f}",
+         f"LoRA Model mIoU: {lora_miou:.2f}"
+     )
+
+ # Image options for the dropdown: (label, value) pairs
+ image_options = [(f"Image {i}", i) for i in range(len(sampled_dataset))]
+
+ # Create the Gradio interface (the outputs match the six values returned
+ # by inference, including the legend image)
+ iface = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Dropdown(label="Select Image", choices=image_options),
+         gr.Image(type="pil", label="Legend", value=create_legend)
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Selected Image"),
+         gr.Image(type="pil", label="Original Model Output"),
+         gr.Image(type="pil", label="LoRA Model Output"),
+         gr.Image(type="pil", label="Legend"),
+         gr.Textbox(label="Original Model mIoU"),
+         gr.Textbox(label="LoRA Model mIoU")
+     ],
+     title="Segformer Cityscapes Inference",
+     description="Select an image from the Cityscapes dataset to see the segmentation results from both the original and fine-tuned Segformer models.",
+ )
+
+ # Launch the interface
+ iface.launch()
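
For reference, a minimal standalone sketch (not part of the commit) of how the mean_iou metric consumed by compute_miou behaves; the toy arrays and shapes are illustrative, while num_labels, ignore_index, and the reduce_labels flag mirror the app's settings:

import numpy as np
import evaluate

metric = evaluate.load("mean_iou")

# Toy (H, W) label maps standing in for one prediction/ground-truth pair
pred = np.zeros((4, 4), dtype=np.int64)  # predicted class ids
ref = np.zeros((4, 4), dtype=np.int64)   # ground-truth class ids
ref[0, 0] = 19                           # one "ignore" pixel

result = metric.compute(
    predictions=[pred],                  # lists: one entry per image
    references=[ref],
    num_labels=20,
    ignore_index=19,
    reduce_labels=False,
)
print(result["mean_iou"])                # 1.0 for this perfect toy match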