guimcc commited on
Commit
df8f91f
1 Parent(s): 58d9df7

updated app

Browse files
Files changed (1) hide show
  1. app.py +19 -31
app.py CHANGED
@@ -1,10 +1,10 @@
 
1
  import gradio as gr
2
  from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
3
  from torchvision.transforms import ColorJitter, functional as F
4
  from PIL import Image, ImageDraw, ImageFont
5
  import numpy as np
6
  import torch
7
- import torch.nn as nn
8
  from datasets import load_dataset
9
  import evaluate
10
 
@@ -18,9 +18,10 @@ lora_model_id = "guimCC/segformer-v0-gta-cityscapes"
18
  original_model = SegformerForSemanticSegmentation.from_pretrained(original_model_id).to(device)
19
  lora_model = SegformerForSemanticSegmentation.from_pretrained(lora_model_id).to(device)
20
 
21
- # Load the dataset and slice it
22
  dataset = load_dataset("Chris1/cityscapes", split="validation")
23
- sampled_dataset = [dataset[i] for i in range(10)] # Select the first 10 examples
 
24
 
25
  # Define your custom image processor
26
  jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
@@ -66,7 +67,7 @@ def compute_miou(logits, labels):
66
  with torch.no_grad():
67
  logits_tensor = torch.from_numpy(logits)
68
  # Scale the logits to the size of the label
69
- logits_tensor = nn.functional.interpolate(
70
  logits_tensor,
71
  size=labels.shape[-2:],
72
  mode="bilinear",
@@ -90,9 +91,13 @@ def compute_miou(logits, labels):
90
  reduce_labels=processor.do_reduce_labels,
91
  )
92
 
93
- return metrics['mean_iou']
94
-
95
-
 
 
 
 
96
  def apply_color_palette(segmentation):
97
  colored_segmentation = palette[segmentation]
98
  return Image.fromarray(colored_segmentation.astype(np.uint8))
@@ -123,9 +128,7 @@ def create_legend():
123
 
124
  return legend
125
 
126
-
127
-
128
- def inference(index, a):
129
  """Run inference on the input image with both models."""
130
  image = sampled_dataset[index]['image'] # Fetch image from the sampled dataset
131
  pixel_values = preprocess_image(image)
@@ -140,28 +143,15 @@ def inference(index, a):
140
  lora_outputs = lora_model(pixel_values=pixel_values)
141
  lora_segmentation = postprocess_predictions(lora_outputs.logits)
142
 
143
- # Compute mIoU
144
- true_labels = np.array(sampled_dataset[index]['semantic_segmentation'])
145
- original_miou = compute_miou(original_outputs.logits.detach().cpu().numpy(), true_labels)
146
- lora_miou = compute_miou(lora_outputs.logits.detach().cpu().numpy(), true_labels)
147
- # original_miou = 0
148
- # lora_miou = 0
149
-
150
  # Apply color palette
151
  original_segmentation_image = apply_color_palette(original_segmentation)
152
  lora_segmentation_image = apply_color_palette(lora_segmentation)
153
 
154
- # Create legend
155
- legend = create_legend()
156
-
157
  # Return the original image, the segmentations, and mIoU
158
  return (
159
  image,
160
  original_segmentation_image,
161
  lora_segmentation_image,
162
- legend,
163
- f"Original Model mIoU: {original_miou:.2f}",
164
- f"LoRA Model mIoU: {lora_miou:.2f}"
165
  )
166
 
167
  # Create a list of image options for the user to select from
@@ -175,15 +165,13 @@ iface = gr.Interface(
175
  gr.Image(type="pil", label="Legend", value=create_legend)
176
  ],
177
  outputs=[
178
- gr.Image(type="pil", label="Selected Image"),
179
- gr.Image(type="pil", label="Original Model Output"),
180
- gr.Image(type="pil", label="LoRA Model Output"),
181
- gr.Textbox(label="Original Model mIoU"),
182
- gr.Textbox(label="LoRA Model mIoU")
183
  ],
184
- title="Segformer Cityscapes Inference",
185
- description="Select an image from the Cityscapes dataset to see the segmentation results from both the original and fine-tuned Segformer models.",
186
  )
187
 
188
  # Launch the interface
189
- iface.launch()
 
1
+ import random
2
  import gradio as gr
3
  from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
4
  from torchvision.transforms import ColorJitter, functional as F
5
  from PIL import Image, ImageDraw, ImageFont
6
  import numpy as np
7
  import torch
 
8
  from datasets import load_dataset
9
  import evaluate
10
 
 
18
  original_model = SegformerForSemanticSegmentation.from_pretrained(original_model_id).to(device)
19
  lora_model = SegformerForSemanticSegmentation.from_pretrained(lora_model_id).to(device)
20
 
21
+ # Load the dataset and select 10 random images
22
  dataset = load_dataset("Chris1/cityscapes", split="validation")
23
+ #sampled_dataset = random.sample(list(dataset), 10) # Select 10 random examples
24
+ sampled_dataset = dataset[:10] # Select the first 10 examples
25
 
26
  # Define your custom image processor
27
  jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
 
67
  with torch.no_grad():
68
  logits_tensor = torch.from_numpy(logits)
69
  # Scale the logits to the size of the label
70
+ logits_tensor = F.interpolate(
71
  logits_tensor,
72
  size=labels.shape[-2:],
73
  mode="bilinear",
 
91
  reduce_labels=processor.do_reduce_labels,
92
  )
93
 
94
+ mean_iou = metrics.get('mean_iou', 0.0)
95
+
96
+ if np.isnan(mean_iou):
97
+ mean_iou = 0.0 # Handle NaN values gracefully
98
+
99
+ return mean_iou
100
+
101
  def apply_color_palette(segmentation):
102
  colored_segmentation = palette[segmentation]
103
  return Image.fromarray(colored_segmentation.astype(np.uint8))
 
128
 
129
  return legend
130
 
131
+ def inference(index, legend):
 
 
132
  """Run inference on the input image with both models."""
133
  image = sampled_dataset[index]['image'] # Fetch image from the sampled dataset
134
  pixel_values = preprocess_image(image)
 
143
  lora_outputs = lora_model(pixel_values=pixel_values)
144
  lora_segmentation = postprocess_predictions(lora_outputs.logits)
145
 
 
 
 
 
 
 
 
146
  # Apply color palette
147
  original_segmentation_image = apply_color_palette(original_segmentation)
148
  lora_segmentation_image = apply_color_palette(lora_segmentation)
149
 
 
 
 
150
  # Return the original image, the segmentations, and mIoU
151
  return (
152
  image,
153
  original_segmentation_image,
154
  lora_segmentation_image,
 
 
 
155
  )
156
 
157
  # Create a list of image options for the user to select from
 
165
  gr.Image(type="pil", label="Legend", value=create_legend)
166
  ],
167
  outputs=[
168
+ gr.Image(type="pil", label="Input Image"),
169
+ gr.Image(type="pil", label="Original Model Prediction"),
170
+ gr.Image(type="pil", label="LoRA Model Prediction"),
171
+
 
172
  ],
173
+ live=True
 
174
  )
175
 
176
  # Launch the interface
177
+ iface.launch()