"""Plotting, model-loading and metric helpers for land-cover segmentation.

The module groups three kinds of utilities:

* biodiversity scoring and Plotly visualisations of labelled images,
* backbone loading / image (un)normalisation helpers built on torch,
* an unsupervised segmentation metric and a tolerant collate function.
"""
import collections
import collections.abc
import os
from os.path import join
import io
import datetime

from dateutil.relativedelta import relativedelta
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import wget
from PIL import Image
from scipy.optimize import linear_sum_assignment
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torch.utils.tensorboard.summary import hparams
from torchmetrics import Metric
from torchvision import models
from torchvision import transforms as T
import torchvision.transforms as T

try:
    from torch._six import string_classes
except ImportError:
    # torch >= 2.0 no longer ships torch._six; fall back to the value it
    # used to export.
    string_classes = (str, bytes)

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Display colour, display name, integer label and biodiversity weight of each
# land-cover class. `colors` and `class_names` are listed in the same order as
# the keys of `mapping_class`.
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
mapping_class = {
    "Buildings": 1,
    "Cultivation": 2,
    "Natural green": 3,
    "Wetland": 4,
    "Water": 5,
    "Infrastructure": 6,
    "Background": 0,
}

score_attribution = {
    "Buildings": 0.,
    "Cultivation": 0.3,
    "Natural green": 1.,
    "Wetland": 0.9,
    "Water": 0.9,
    "Infrastructure": 0.,
    "Background": 0.
}
bounds = list(np.arange(len(mapping_class.keys()) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)


def compute_biodiv_score(class_image):
    """Compute the biodiversity score of a labelled image.

    Every pixel whose label appears in ``mapping_class`` contributes the
    matching weight from ``score_attribution``; the score is the mean
    contribution over all pixels.

    Args:
        class_image: array of integer class labels (values of
            ``mapping_class``).

    Returns:
        A ``(score, score_details)`` tuple where ``score`` is the mean
        per-pixel weight and ``score_details`` maps every class name
        (``"Background"`` included) to its pixel count in ``class_image``.
    """
    image = np.asarray(class_image)
    labels = image.astype(int)
    # Labels that are not in `mapping_class` keep their integer value, as in
    # the previous implementation.
    score_matrix = labels.astype(float)
    for class_name, class_id in mapping_class.items():
        # Mask against the untouched labels: replacing values in place would
        # let a freshly written score (e.g. 1.0) be mistaken for a class id
        # by a later iteration.
        score_matrix[labels == class_id] = score_attribution[class_name]
    score = np.sum(score_matrix) / score_matrix.size

    # The plotting helpers index these counts with every key of
    # `mapping_class`, so "Background" must stay in the dictionary.
    score_details = {
        class_name: np.sum(image == class_id)
        for class_name, class_id in mapping_class.items()
    }
    return score, score_details


def _hidden_axis(axis_range):
    """Return layout settings for an invisible axis spanning ``axis_range``."""
    return {
        "range": axis_range,
        'showgrid': False,  # thin lines in the background
        'zeroline': False,  # thick line at 0
        'visible': False,  # tick labels
    }


def _image_axes_layout(image_shape, offset, second_suffix):
    """Build the layout dict hiding the axes of the two image subplots.

    ``offset`` is added to the far end of each range and ``second_suffix``
    selects the key of the second axis pair (``"xaxis" + second_suffix``).
    """
    x_range = [0, image_shape[1] + offset]
    y_range = [image_shape[0] + offset, 0]
    return {
        "xaxis": _hidden_axis(list(x_range)),
        "yaxis": _hidden_axis(list(y_range)),
        "xaxis" + second_suffix: _hidden_axis(list(x_range)),
        "yaxis" + second_suffix: _hidden_axis(list(y_range)),
    }


def _segment_pie(pixel_counts):
    """Return the pie trace showing the share of each class in one image."""
    return go.Pie(
        labels=class_names,
        values=[pixel_counts[key] for key in mapping_class.keys()],
        marker_colors=colors,
        name="Segment repartition",
        textposition='inside',
        texttemplate="%{percent:.0%}",
        textfont_size=14,
    )


def plot_image(months, imgs, imgs_label, nb_values, scores):
    """Build a static four-panel figure for the first image of a series.

    Args:
        months: unused here; kept so the signature matches
            ``plot_imgs_labels``.
        imgs: sequence of images; only the first one is drawn.
        imgs_label: sequence of labelled images; only the first one is drawn.
        nb_values: sequence of per-class pixel-count dicts (see
            ``compute_biodiv_score``); only the first one is used.
        scores: sequence of scores; only the first one is shown.

    Returns:
        The Plotly figure (image, labelled image, class pie, score indicator).
    """
    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)

    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "image"}, {"type": "image"}, {"type": "pie"}, {"type": "indicator"}]],
        subplot_titles=("Localisation visualization", "Labeled visualisation",
                        "Segments repartition", "Biodiversity scores")
    )

    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(_segment_pie(nb_values[0]), row=1, col=3)
    fig.add_trace(go.Indicator(value=scores[0]), row=1, col=4)

    fig.update_layout(
        legend=dict(
            xanchor="center",
            yanchor="top",
            y=-0.1,
            x=0.5,
            orientation="h")
    )
    # NOTE(review): the "xaxis1"/"yaxis1" keys are kept from the original
    # code; confirm whether "xaxis2"/"yaxis2" was intended. The second image's
    # axes are hidden explicitly just below either way.
    fig.update(layout=_image_axes_layout(imgs[0].shape, 1 / 100000, "1"))
    fig.update_xaxes(row=1, col=2, visible=False)
    fig.update_yaxes(row=1, col=2, visible=False)

    return fig


def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Build an animated four-panel figure over a series of images.

    Each animation frame shows one image, its labelled version, the class pie
    for that image and the score curve up to that date.

    Args:
        months: sequence of ``"%Y-%m-%d"`` date strings, one per frame.
        imgs: sequence of images, one per frame.
        imgs_label: sequence of labelled images, one per frame.
        nb_values: sequence of per-class pixel-count dicts, one per frame.
        scores: sequence of scores, one per frame.

    Returns:
        The Plotly figure with frames, a play button and a date slider.
    """
    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)

    # One cumulative score curve per frame: frame i shows scores[0..i].
    scatters = [
        go.Scatter(
            x=months[:i + 1],
            y=scores[:i + 1],
            mode="lines+markers+text",
            marker_color="black",
            text=[f"{score:.2f}" for score in scores[:i + 1]],
            textposition="top center"
        ) for i in range(len(scores))
    ]

    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "image"}, {"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
        subplot_titles=("Localisation visualization", "Labeled visualisation",
                        "Segments repartition", "Biodiversity scores")
    )

    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(_segment_pie(nb_values[0]), row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)
    fig.update_traces(selector=dict(type='scatter'))

    number_frames = len(imgs)
    frames = [dict(
        name=k,
        data=[
            fig2["frames"][k]["data"][0],
            fig3["frames"][k]["data"][0],
            _segment_pie(nb_values[k]),
            scatters[k]
        ],
        # Indices of the traces in fig.data that the four entries of `data`
        # replace, in order.
        traces=[0, 1, 2, 3]
    ) for k in range(number_frames)]

    updatemenus = [dict(
        type='buttons',
        buttons=[dict(
            label='Play',
            method='animate',
            args=[
                [f'{k}' for k in range(number_frames)],
                dict(frame=dict(duration=500, redraw=False),
                     transition=dict(duration=0),
                     easing='linear',
                     fromcurrent=True,
                     mode='immediate')
            ])],
        direction='left',
        pad=dict(t=85),
        showactive=True, x=0.1, y=0.13, xanchor='right', yanchor='top')
    ]

    sliders = [{
        'yanchor': 'top',
        'xanchor': 'left',
        'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
        'transition': {'duration': 500.0, 'easing': 'linear'},
        'pad': {'b': 10, 't': 50},
        'len': 0.9, 'x': 0.1, 'y': 0,
        'steps': [
            {
                'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                               'transition': {'duration': 0, 'easing': 'linear'}}],
                'label': months[k],
                'method': 'animate',
            } for k in range(number_frames)
        ]}]

    fig.update(frames=frames)

    # Each frame gets an axis range that differs by a negligible, frame
    # dependent amount (i / 100000).
    # NOTE(review): presumably this forces the image axes to refresh on every
    # frame although the frames animate with redraw=False -- confirm.
    for i, fr in enumerate(fig["frames"]):
        fr.update(layout=_image_axes_layout(imgs[0].shape, i / 100000, "1"))

    # Pad the date axis of the score curve by one month on each side.
    start_date = datetime.datetime.strptime(months[0], "%Y-%m-%d") - relativedelta(months=1)
    end_date = datetime.datetime.strptime(months[-1], "%Y-%m-%d") + relativedelta(months=1)
    interval = [start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")]

    # The original code reused the loop variable `i` here, i.e. the index of
    # the last frame; spell that value out instead of relying on the leak.
    last_offset = (number_frames - 1) / 100000
    layout = _image_axes_layout(imgs[0].shape, last_offset, "2")
    layout["xaxis3"] = {
        "dtick": "M3",
        "range": interval
    }
    layout["yaxis3"] = {
        'range': [min(scores) * 0.9, max(scores) * 1.1],
        'showgrid': False,
        'zeroline': False,
        'visible': True
    }
    fig.update(layout=layout)

    fig.update_layout(
        updatemenus=updatemenus,
        sliders=sliders,
        legend=dict(
            xanchor="center",
            yanchor="top",
            y=-0.1,
            x=0.5,
            orientation="h")
    )
    fig.update_layout(margin=dict(b=0, r=0))
    return fig


def transform_to_pil(output, alpha=0.3):
    """Convert a model output into PIL images.

    Args:
        output: mapping with an ``'img'`` tensor (passed to
            ``prep_for_plot``) and ``'linear_preds'`` integer class labels.
        alpha: blending factor passed to ``Image.blend`` for the overlay.

    Returns:
        ``(img, label, labeled_img)``: the de-normalised image, the colour
        coded label map and the blend of both.
    """
    # prep_for_plot returns channels-last; ToPILImage wants channels-first.
    img = torch.moveaxis(prep_for_plot(output['img']), -1, 0)
    img = T.ToPILImage()(img)

    cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
    # Shift labels by one so class 1 picks colour 0; class 0 becomes index -1
    # and therefore picks the last colour of the colormap.
    labels = np.array(output['linear_preds']) - 1
    label = T.ToPILImage()((cmaplist[labels] * 255).astype(np.uint8))

    # Overlay the label map on the image with the requested transparency.
    background = img.convert("RGBA")
    overlay = label.convert("RGBA")
    labeled_img = Image.blend(background, overlay, alpha)

    return img, label, labeled_img


def prep_for_plot(img, rescale=True, resize=None):
    """Undo normalisation of an image tensor and move channels last.

    Args:
        img: image tensor; ``permute(1, 2, 0)`` below requires it to be
            three-dimensional, channels first.
        rescale: when true, min-max rescale the result to ``[0, 1]``.
        resize: optional size passed to ``F.interpolate`` (bilinear).

    Returns:
        The CPU tensor with channels moved to the last dimension.
    """
    if resize is not None:
        img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
    else:
        img = img.unsqueeze(0)

    plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)

    if rescale:
        plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
    return plot_img


def add_plot(writer, name, step):
    """Log the current matplotlib figure to ``writer`` as an image.

    The figure is rendered to an in-memory JPEG, converted to a tensor,
    passed to ``writer.add_image(name, image, step)`` and then closed.
    """
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg', dpi=100)
    buf.seek(0)
    image = Image.open(buf)
    image = T.ToTensor()(image)
    writer.add_image(name, image, step)
    plt.clf()
    plt.close()


# Return `x` with its first dimension randomly permuted.
@torch.jit.script
def shuffle(x):
    return x[torch.randperm(x.shape[0])]


def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    """Write hyper-parameters and their metrics to ``writer``.

    The three summaries produced by ``hparams`` go to ``writer.file_writer``
    and each metric is additionally logged as a scalar at ``global_step``.
    """
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k, v in metric_dict.items():
        writer.add_scalar(k, v, global_step)


# Bilinearly resize `classes` to a square of side `size`.
@torch.jit.script
def resize(classes: torch.Tensor, size: int):
    return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)


def one_hot_feats(labels, n_classes):
    """One-hot encode ``labels`` and move the class axis to position 1.

    The ``permute(0, 3, 1, 2)`` requires ``labels`` to be three-dimensional;
    the result is a float32 tensor.
    """
    return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)


def _download_if_missing(url, model_file):
    """Download ``url`` to ``model_file`` unless that file already exists."""
    if not os.path.exists(model_file):
        wget.download(url, model_file)


def load_model(model_type, data_dir):
    """Load a backbone in eval mode.

    Args:
        model_type: one of ``"robust_resnet50"``, ``"densecl"``,
            ``"resnet50"``, ``"mocov2"``, ``"densenet121"`` or ``"vgg11"``.
        data_dir: directory where downloaded checkpoints are cached.

    Returns:
        An ``nn.Sequential`` made of the model's children without the last
        one (plus an adaptive average pool for densenet/vgg), in eval mode and
        moved to the GPU when one is available.

    Raises:
        ValueError: if ``model_type`` is not recognised.
    """
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'imagenet_l2_3_0.pt')
        _download_if_missing("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
                             model_file)
        # Load on CPU so checkpoints saved from a GPU also open on CPU-only
        # machines; the model is moved to the GPU at the end if possible.
        model_weights = torch.load(model_file, map_location="cpu")
        model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
                                  'model' in name}
        model.load_state_dict(model_weights_modified)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
        _download_if_missing("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
                             model_file)
        model_weights = torch.load(model_file, map_location="cpu")
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "resnet50":
        model = models.resnet50(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
        _download_if_missing("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
                             "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
        checkpoint = torch.load(model_file, map_location="cpu")
        # rename moco pre-trained keys
        state_dict = checkpoint['state_dict']
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densenet121":
        model = models.densenet121(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    elif model_type == "vgg11":
        model = models.vgg11(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    else:
        raise ValueError("No model: {} found".format(model_type))

    model.eval()
    # Previously `.cuda()` was unconditional and failed on CPU-only hosts.
    if torch.cuda.is_available():
        model.cuda()
    return model


class UnNormalize(object):
    """Inverse of a per-channel normalisation: ``x * std + mean``."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return an un-normalised copy of ``image``.

        Args:
            image: floating-point tensor shaped ``(..., C, H, W)``. Inputs
                with fewer than three dimensions are processed along their
                first dimension, as the original implementation did.

        Returns:
            A new tensor; ``image`` itself is left untouched.
        """
        image2 = torch.clone(image)
        if image2.dim() < 3:
            for t, m, s in zip(image2, self.mean, self.std):
                t.mul_(s).add_(m)
            return image2

        # Work along the channel axis (third from the end). The previous loop
        # iterated over dim 0, which for a batched (B, C, H, W) input applied
        # the first channel's statistics to whole images.
        n_channels = min(image2.shape[-3], len(self.mean), len(self.std))
        mean = torch.as_tensor(list(self.mean[:n_channels]), dtype=image2.dtype,
                               device=image2.device).view(n_channels, 1, 1)
        std = torch.as_tensor(list(self.std[:n_channels]), dtype=image2.dtype,
                              device=image2.device).view(n_channels, 1, 1)
        # Channels beyond the provided statistics are left unchanged.
        image2[..., :n_channels, :, :].mul_(std).add_(mean)
        return image2


normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


class ToTargetTensor(object):
    """Transform turning a label image into an int64 tensor with a leading axis."""

    def __call__(self, target):
        return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)


def prep_args():
    """Rewrite ``sys.argv`` so every option has the ``key=value`` form.

    Arguments that already contain exactly one ``=`` are kept verbatim;
    ``--key value`` pairs become ``key=value``.

    Raises:
        ValueError: if an argument has neither form, or if a ``--key`` is
            not followed by a value. ``sys.argv`` is left untouched then.
    """
    import sys

    old_args = list(sys.argv)
    new_args = old_args[:1]
    remaining = iter(old_args[1:])
    for arg in remaining:
        if len(arg.split("=")) == 2:
            new_args.append(arg)
        elif arg.startswith("--"):
            try:
                value = next(remaining)
            except StopIteration:
                # A bare IndexError from pop() used to surface here.
                raise ValueError("Missing value for arg {}".format(arg)) from None
            new_args.append(arg[2:] + "=" + value)
        else:
            raise ValueError("Unexpected arg style {}".format(arg))
    sys.argv = new_args


def get_transform(res, is_label, crop_type):
    """Build the resize/crop/tensor pipeline for images or label maps.

    Args:
        res: target resolution. With ``crop_type=None`` the input is resized
            to exactly ``(res, res)`` and not cropped.
        is_label: when true the pipeline ends with ``ToTargetTensor``;
            otherwise with ``T.ToTensor`` followed by ``normalize``.
        crop_type: ``"center"``, ``"random"`` or ``None``.

    Raises:
        ValueError: if ``crop_type`` is not recognised.
    """
    if crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    elif crop_type is None:
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    else:
        raise ValueError("Unknown Cropper {}".format(crop_type))
    if is_label:
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          ToTargetTensor()])
    else:
        return T.Compose([T.Resize(res, Image.NEAREST),
                          cropper,
                          T.ToTensor(),
                          normalize])


def _remove_axes(ax):
    """Strip tick labels and ticks from a single matplotlib axis."""
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    """Strip ticks from every axis of a 1-D or 2-D array of axes."""
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)


class UnsupervisedMetrics(Metric):
    """Accumulates a cluster/class count matrix and reports mIoU and accuracy.

    ``stats`` has shape ``(n_classes + extra_clusters, n_classes)``; entry
    ``[p, a]`` counts pixels predicted as ``p`` whose target is ``a``.
    """

    def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
                 dist_sync_on_step=True):
        # call `self.add_state`for every internal state that is needed for the metrics computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        self.prefix = prefix
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Add the (prediction, target) pairs of one batch to ``stats``."""
        with torch.no_grad():
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            # NOTE(review): predictions >= n_classes are discarded here even
            # when extra_clusters > 0 -- confirm this is intended.
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
            actual = actual[mask]
            preds = preds[mask]
            self.stats += torch.bincount(
                (self.n_classes + self.extra_clusters) * actual + preds,
                minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
                .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        """Map cluster ids to class ids using the assignment from ``compute``.

        Clusters that were not assigned to any class map to ``-1``.
        ``compute`` must have been called first.
        """
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        else:
            missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
            cluster_to_class = self.assignments[1]
            for missing_entry in missing:
                if missing_entry == cluster_to_class.shape[0]:
                    cluster_to_class = np.append(cluster_to_class, -1)
                else:
                    cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
            cluster_to_class = torch.tensor(cluster_to_class)
            return cluster_to_class[clusters]

    def compute(self):
        """Return ``{prefix + "mIoU", prefix + "Accuracy"}`` as percentages.

        With ``compute_hungarian`` the clusters are first matched to classes
        with ``linear_sum_assignment`` (maximising the matched counts);
        otherwise cluster ``i`` is taken to be class ``i``. The assignment is
        stored in ``self.assignments`` and the re-ordered count matrix in
        ``self.histogram``.
        """
        if self.compute_hungarian:
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)

            if self.extra_clusters == 0:
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            if self.extra_clusters > 0:
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                histogram = self.stats[self.assignments_t[1], :]
                # Unmatched clusters are pooled into one extra row, and an
                # all-zero column is appended to keep the matrix square.
                missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], axis=0)
                new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], axis=1)
        else:
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats

        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp

        iou = tp / (tp + fp + fn)
        opc = torch.sum(tp) / torch.sum(self.histogram)

        # Classes with an undefined IoU (0 / 0) are left out of the mean.
        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}


def flexible_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the type of the first element like a default collate
    function, except that tensors which cannot be stacked (``torch.stack``
    raises ``RuntimeError``) are returned as the plain list ``batch``.

    Raises:
        TypeError: for unsupported element types.
        RuntimeError: if sequence elements have inconsistent lengths.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))