import collections
import datetime
import io
import os
from os.path import join

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch  # explicit import; previously only pulled in implicitly via torch.multiprocessing
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import wget
from dateutil.relativedelta import relativedelta
from PIL import Image
from plotly.subplots import make_subplots
from scipy.optimize import linear_sum_assignment
# NOTE: torch._six was removed in recent PyTorch releases; on those versions,
# replace `string_classes` with plain `str`.
from torch._six import string_classes
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torch.utils.tensorboard.summary import hparams
from torchmetrics import Metric
from torchvision import models
from torchvision import transforms as T

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Color palette, class names, and class-index mapping shared by the plotting helpers below.
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
mapping_class = {
    "Buildings": 1,
    "Cultivation": 2,
    "Natural green": 3,
    "Wetland": 4,
    "Water": 5,
    "Infrastructure": 6,
    "Background": 0,
}
# Per-class contribution to the biodiversity score (0 = no contribution, 1 = maximal).
score_attribution = {
    "Buildings": 0.,
    "Cultivation": 0.3,
    "Natural green": 1.,
    "Wetland": 0.9,
    "Water": 0.9,
    "Infrastructure": 0.,
    "Background": 0.
}
bounds = list(np.arange(len(mapping_class.keys()) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
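
# Illustrative sketch (not part of the original module): rendering a mask of
# 1-based class indices with the palette above. `bounds` starts at 1, matching
# the 1-based predictions that transform_to_pil below expects; the random mask
# here is a stand-in for a real segmentation output.
def _example_render_mask():
    mask = np.random.randint(1, len(colors) + 1, size=(64, 64))  # hypothetical 1-based class mask
    plt.imshow(mask, cmap=cmap, norm=norm)
    plt.axis("off")
    plt.show()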

def compute_biodiv_score(class_image):
    """Compute the biodiversity score of a segmented image.

    Args:
        class_image (np.ndarray): 2D array of class indices (see `mapping_class`).

    Returns:
        score: mean per-pixel biodiversity score of the landscape in the image.
        score_details: per-class pixel counts, excluding "Background".
    """
    score_matrice = class_image.copy().astype(int)
    # In-place substitution of class indices by their scores. This is order-sensitive
    # (a replacement score equal to a later class index would be clobbered), but with
    # the current mapping and score values the result is correct.
    for key in mapping_class.keys():
        score_matrice = np.where(score_matrice == mapping_class[key], score_attribution[key], score_matrice)
    number_of_pixel = np.prod(list(score_matrice.shape))
    score = np.sum(score_matrice) / number_of_pixel
    score_details = {
        key: np.sum(np.where(class_image == mapping_class[key], 1, 0))
        for key in mapping_class.keys()
        if key != "Background"
    }
    return score, score_details
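
# Minimal sanity check for compute_biodiv_score (illustrative, not part of the
# original module): a synthetic 2x2 mask that is half "Natural green" (score 1.0)
# and half "Buildings" (score 0.0) should yield an overall score of 0.5.
def _example_compute_biodiv_score():
    class_image = np.array([[3, 3], [1, 1]])  # 3 = Natural green, 1 = Buildings
    score, details = compute_biodiv_score(class_image)
    assert abs(score - 0.5) < 1e-9
    assert details["Natural green"] == 2 and details["Buildings"] == 2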

def plot_image(months, imgs, imgs_label, nb_values, scores):
    """Static (first-frame) dashboard: image, labeled image, class pie chart, and
    score indicator. `months` is unused here; the signature mirrors plot_imgs_labels."""
    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "image"}, {"type": "image"}, {"type": "pie"}, {"type": "indicator"}]],
        subplot_titles=("Localization visualization", "Labeled visualization", "Segment distribution", "Biodiversity scores")
    )
    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(go.Pie(labels=class_names,
                         values=[nb_values[0][key] for key in mapping_class.keys()],
                         marker_colors=colors,
                         name="Segment distribution",
                         textposition='inside',
                         texttemplate="%{percent:.0%}",
                         textfont_size=14),
                  row=1, col=3)
    fig.add_trace(go.Indicator(value=scores[0]), row=1, col=4)
    fig.update_layout(
        legend=dict(
            xanchor="center",
            yanchor="top",
            y=-0.1,
            x=0.5,
            orientation="h")
    )
    # Hide the image axes: no grid lines, no zero line, no tick numbers
    fig.update(
        layout={
            "xaxis": {
                "range": [0, imgs[0].shape[1] + 1 / 100000],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis": {
                "range": [imgs[0].shape[0] + 1 / 100000, 0],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "xaxis1": {
                "range": [0, imgs[0].shape[1] + 1 / 100000],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis1": {
                "range": [imgs[0].shape[0] + 1 / 100000, 0],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
        })
    fig.update_xaxes(row=1, col=2, visible=False)
    fig.update_yaxes(row=1, col=2, visible=False)
    return fig

def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Animated dashboard: image, labeled image, class pie chart, and the
    biodiversity score plotted over time."""
    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
    # One cumulative score trace per frame, so the curve grows as the animation advances
    scatters = [
        go.Scatter(
            x=months[:i + 1],
            y=scores[:i + 1],
            mode="lines+markers+text",
            marker_color="black",
            text=[f"{score:.2f}" for score in scores[:i + 1]],
            textposition="top center"
        ) for i in range(len(scores))
    ]
    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "image"}, {"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
        subplot_titles=("Localization visualization", "Labeled visualization", "Segment distribution", "Biodiversity scores")
    )
    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(go.Pie(labels=class_names,
                         values=[nb_values[0][key] for key in mapping_class.keys()],
                         marker_colors=colors,
                         name="Segment distribution",
                         textposition='inside',
                         texttemplate="%{percent:.0%}",
                         textfont_size=14),
                  row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)
    number_frames = len(imgs)
    frames = [dict(
        name=k,
        data=[fig2["frames"][k]["data"][0],
              fig3["frames"][k]["data"][0],
              go.Pie(labels=class_names,
                     values=[nb_values[k][key] for key in mapping_class.keys()],
                     marker_colors=colors,
                     name="Segment distribution",
                     textposition='inside',
                     texttemplate="%{percent:.0%}",
                     textfont_size=14),
              scatters[k]],
        traces=[0, 1, 2, 3]  # indices into fig.data updated by the four trace objects above
    ) for k in range(number_frames)]
    updatemenus = [dict(type='buttons',
                        buttons=[dict(label='Play',
                                      method='animate',
                                      args=[[f'{k}' for k in range(number_frames)],
                                            dict(frame=dict(duration=500, redraw=False),
                                                 transition=dict(duration=0),
                                                 easing='linear',
                                                 fromcurrent=True,
                                                 mode='immediate')])],
                        direction='left',
                        pad=dict(t=85),
                        showactive=True, x=0.1, y=0.13, xanchor='right', yanchor='top')
                   ]
    sliders = [{'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
                'transition': {'duration': 500.0, 'easing': 'linear'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9, 'x': 0.1, 'y': 0,
                'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                                          'transition': {'duration': 0, 'easing': 'linear'}}],
                           'label': months[k], 'method': 'animate'} for k in range(number_frames)
                          ]}]
    fig.update(frames=frames)
    # Hide the image axes in every frame; the i/100000 offset keeps each frame's
    # range unique so the per-frame layout update is not collapsed away
    for i, fr in enumerate(fig["frames"]):
        fr.update(
            layout={
                "xaxis": {
                    "range": [0, imgs[0].shape[1] + i / 100000],
                    'showgrid': False,
                    'zeroline': False,
                    'visible': False,
                },
                "yaxis": {
                    "range": [imgs[0].shape[0] + i / 100000, 0],
                    'showgrid': False,
                    'zeroline': False,
                    'visible': False,
                },
                "xaxis1": {
                    "range": [0, imgs[0].shape[1] + i / 100000],
                    'showgrid': False,
                    'zeroline': False,
                    'visible': False,
                },
                "yaxis1": {
                    "range": [imgs[0].shape[0] + i / 100000, 0],
                    'showgrid': False,
                    'zeroline': False,
                    'visible': False,
                },
            })
    # Pad the score axis by one month on each side
    start_date = datetime.datetime.strptime(months[0], "%Y-%m-%d") - relativedelta(months=1)
    end_date = datetime.datetime.strptime(months[-1], "%Y-%m-%d") + relativedelta(months=1)
    interval = [start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")]
    # Note: `i` here is the last frame index left over from the loop above
    fig.update(
        layout={
            "xaxis": {
                "range": [0, imgs[0].shape[1] + i / 100000],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis": {
                "range": [imgs[0].shape[0] + i / 100000, 0],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "xaxis2": {
                "range": [0, imgs[0].shape[1] + i / 100000],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis2": {
                "range": [imgs[0].shape[0] + i / 100000, 0],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "xaxis3": {
                "dtick": "M3",
                "range": interval
            },
            "yaxis3": {
                'range': [min(scores) * 0.9, max(scores) * 1.1],
                'showgrid': False,
                'zeroline': False,
                'visible': True
            }
        }
    )
    fig.update_layout(updatemenus=updatemenus,
                      sliders=sliders,
                      legend=dict(
                          xanchor="center",
                          yanchor="top",
                          y=-0.1,
                          x=0.5,
                          orientation="h")
                      )
    fig.update_layout(margin=dict(b=0, r=0))
    return fig
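
# Illustrative usage sketch (not part of the original module): drive the animated
# dashboard with synthetic data. All shapes, dates, and counts here are assumptions
# chosen only to make the demo self-contained.
def _example_plot_imgs_labels():
    months = ["2021-01-01", "2021-02-01", "2021-03-01"]
    imgs = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in months]
    imgs_label = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in months]
    nb_values = [{key: np.random.randint(1, 100) for key in mapping_class} for _ in months]
    scores = [0.4, 0.5, 0.45]
    fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
    fig.show()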

def transform_to_pil(output, alpha=0.3):
    """Convert a model output dict into (image, label, blended overlay) PIL images."""
    # Recover a plottable image tensor and convert it to PIL
    img = torch.moveaxis(prep_for_plot(output['img']), -1, 0)
    img = T.ToPILImage()(img)
    # Map 1-based predicted class indices to palette colors
    cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
    labels = np.array(output['linear_preds']) - 1
    label = T.ToPILImage()((cmaplist[labels] * 255).astype(np.uint8))
    # Overlay the labels on the image with the given alpha
    background = img.convert("RGBA")
    overlay = label.convert("RGBA")
    labeled_img = Image.blend(background, overlay, alpha)
    return img, label, labeled_img
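
# Hypothetical usage sketch: `output` mirrors the dict this helper expects, with a
# normalized CHW image tensor under 'img' and 1-based class predictions under
# 'linear_preds'. The random values are stand-ins for real model outputs.
def _example_transform_to_pil():
    output = {
        'img': normalize(torch.rand(3, 64, 64)),             # normalized CHW image
        'linear_preds': np.random.randint(1, 8, (64, 64)),   # 1-based class indices
    }
    img, label, labeled_img = transform_to_pil(output)
    return labeled_img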

def prep_for_plot(img, rescale=True, resize=None):
    """Un-normalize a CHW image tensor and return an HWC tensor suitable for plotting."""
    if resize is not None:
        img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
    else:
        img = img.unsqueeze(0)
    plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)
    if rescale:
        # Min-max rescale to [0, 1]
        plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
    return plot_img

def add_plot(writer, name, step):
    """Save the current matplotlib figure to a TensorBoard writer, then clear it."""
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg', dpi=100)
    buf.seek(0)
    image = Image.open(buf)
    image = T.ToTensor()(image)
    writer.add_image(name, image, step)
    plt.clf()
    plt.close()

def shuffle(x):
    """Randomly permute a tensor along its first dimension."""
    return x[torch.randperm(x.shape[0])]

def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    """Log hyperparameters and metrics into the same run (works around
    SummaryWriter.add_hparams creating a separate run directory)."""
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k, v in metric_dict.items():
        writer.add_scalar(k, v, global_step)

def resize(classes: torch.Tensor, size: int):
    """Bilinearly resize a (B, C, H, W) tensor to (B, C, size, size)."""
    return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)

def one_hot_feats(labels, n_classes):
    """One-hot encode (B, H, W) integer labels into (B, n_classes, H, W) float features."""
    return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)
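
# Quick shape check for one_hot_feats (illustrative): a batch of 2 label maps with
# 7 classes becomes a (2, 7, 8, 8) float tensor.
def _example_one_hot_feats():
    labels = torch.randint(0, 7, (2, 8, 8))
    feats = one_hot_feats(labels, 7)
    assert feats.shape == (2, 7, 8, 8) and feats.dtype == torch.float32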

def load_model(model_type, data_dir):
    """Load a pretrained backbone (downloading weights if needed) with the
    classifier head stripped off."""
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'imagenet_l2_3_0.pt')
        if not os.path.exists(model_file):
            wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
                          model_file)
        model_weights = torch.load(model_file)
        # Keep only the backbone weights and strip the 'model.' prefix
        model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items()
                                  if 'model' in name}
        model.load_state_dict(model_weights_modified)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
        if not os.path.exists(model_file):
            wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
                          model_file)
        model_weights = torch.load(model_file)
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "resnet50":
        model = models.resnet50(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
        if not os.path.exists(model_file):
            wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
                          "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
        checkpoint = torch.load(model_file)
        # Rename MoCo pre-trained keys
        state_dict = checkpoint['state_dict']
        for k in list(state_dict.keys()):
            # Retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # Remove the prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # Delete the renamed or unused key
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densenet121":
        model = models.densenet121(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    elif model_type == "vgg11":
        model = models.vgg11(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    else:
        raise ValueError("Unknown model type: {}".format(model_type))
    model.eval()
    model.cuda()
    return model
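
# Illustrative usage (assumes a CUDA device is available, since load_model calls
# model.cuda(), and uses a hypothetical "./data" cache directory): fetch an
# ImageNet ResNet-50 backbone and embed a batch of images.
def _example_load_model():
    model = load_model("resnet50", "./data")
    with torch.no_grad():
        feats = model(torch.rand(4, 3, 224, 224).cuda())
    assert feats.shape == (4, 2048, 1, 1)  # global-average-pooled features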

class UnNormalize(object):
    """Invert T.Normalize (x * std + mean) on a copy of the input.

    Note: iterates over the first dimension, so it expects an unbatched
    (C, H, W) tensor whose channels line up with `mean`/`std`."""
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        return image2

# ImageNet channel statistics
normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

class ToTargetTensor(object):
    """Convert a PIL label image to a (1, H, W) int64 tensor."""
    def __call__(self, target):
        return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)

def prep_args():
    """Normalize sys.argv so that both `--key value` and `key=value` argument
    styles end up as `key=value`."""
    import sys

    old_args = sys.argv
    new_args = [old_args.pop(0)]
    while len(old_args) > 0:
        arg = old_args.pop(0)
        if len(arg.split("=")) == 2:
            new_args.append(arg)
        elif arg.startswith("--"):
            new_args.append(arg[2:] + "=" + old_args.pop(0))
        else:
            raise ValueError("Unexpected arg style {}".format(arg))
    sys.argv = new_args
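
# Illustrative sketch of what prep_args does to sys.argv (hypothetical CLI values):
def _example_prep_args():
    import sys
    sys.argv = ["train.py", "--lr", "0.1", "epochs=10"]
    prep_args()
    assert sys.argv == ["train.py", "lr=0.1", "epochs=10"]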

def get_transform(res, is_label, crop_type):
    """Build a torchvision transform pipeline for images or label maps."""
    if crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    elif crop_type is None:
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    else:
        raise ValueError("Unknown crop_type {}".format(crop_type))
    if is_label:
        # Nearest-neighbor resizing keeps label values intact
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          ToTargetTensor()])
    else:
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          T.ToTensor(),
                          normalize])
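
# Illustrative usage: paired image/label transforms at 224px with center cropping.
# The random PIL inputs are stand-ins for real dataset samples.
def _example_get_transform():
    img_transform = get_transform(224, is_label=False, crop_type="center")
    label_transform = get_transform(224, is_label=True, crop_type="center")
    img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
    lbl = Image.fromarray(np.random.randint(0, 7, (256, 256), dtype=np.uint8))
    assert img_transform(img).shape == (3, 224, 224)
    assert label_transform(lbl).shape == (1, 224, 224)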

def _remove_axes(ax):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])

def remove_axes(axes):
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)

class UnsupervisedMetrics(Metric):
    """Torchmetrics Metric for unsupervised segmentation: accumulates a confusion
    matrix and, optionally, Hungarian-matches cluster ids to ground-truth classes."""

    def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
                 dist_sync_on_step=True):
        # Call `self.add_state` for every internal state needed for the metric computation;
        # dist_reduce_fx indicates the function used to reduce state from multiple processes.
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        self.prefix = prefix
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        with torch.no_grad():
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
            actual = actual[mask]
            preds = preds[mask]
            # Accumulate the (clusters x classes) confusion matrix
            self.stats += torch.bincount(
                (self.n_classes + self.extra_clusters) * actual + preds,
                minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
                .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        else:
            # Clusters never assigned by the Hungarian matching map to -1
            missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
            cluster_to_class = self.assignments[1]
            for missing_entry in missing:
                if missing_entry == cluster_to_class.shape[0]:
                    cluster_to_class = np.append(cluster_to_class, -1)
                else:
                    cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
            cluster_to_class = torch.tensor(cluster_to_class)
            return cluster_to_class[clusters]

    def compute(self):
        if self.compute_hungarian:
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
            if self.extra_clusters == 0:
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            if self.extra_clusters > 0:
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                histogram = self.stats[self.assignments_t[1], :]
                missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], axis=0)
                new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], axis=1)
        else:
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats
        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp
        iou = tp / (tp + fp + fn)
        prc = tp / (tp + fn)  # per-class recall (computed but not returned)
        opc = torch.sum(tp) / torch.sum(self.histogram)
        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}
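
# Illustrative usage: accumulate random predictions/targets and read out the
# Hungarian-matched scores. Shapes and class count are assumptions for the demo.
def _example_unsupervised_metrics():
    metric = UnsupervisedMetrics("demo/", n_classes=7, extra_clusters=0,
                                 compute_hungarian=True, dist_sync_on_step=False)
    preds = torch.randint(0, 7, (2, 64, 64))
    target = torch.randint(0, 7, (2, 64, 64))
    metric.update(preds, target)
    print(metric.compute())  # e.g. {'demo/mIoU': ..., 'demo/Accuracy': ...}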

def flexible_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Behaves like torch's default_collate, except tensors that cannot be stacked
    (e.g. mismatched shapes) are returned as a plain list instead of raising."""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # Reject arrays of strings and objects
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # Check that the elements in the batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
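
# Quick sanity check of flexible_collate (illustrative): same-shape tensors are
# stacked; ragged ones come back as a plain list. In practice it can be passed
# to a DataLoader as `collate_fn=flexible_collate`.
def _example_flexible_collate():
    stacked = flexible_collate([torch.rand(3, 4), torch.rand(3, 4)])
    assert stacked.shape == (2, 3, 4)
    ragged = flexible_collate([torch.rand(3, 4), torch.rand(5, 4)])
    assert isinstance(ragged, list) and len(ragged) == 2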