Spaces:
Runtime error
Runtime error
from PIL import Image | |
import matplotlib as mpl | |
from utils import prep_for_plot | |
import torch.multiprocessing | |
import torchvision.transforms as T | |
from utils_gee import extract_img, transform_ee_img | |
import plotly.graph_objects as go | |
import plotly.express as px | |
import numpy as np | |
from plotly.subplots import make_subplots | |
import os | |
# NOTE(review): lets the process keep running when two copies of the Intel
# OpenMP runtime get loaded (a known clash between libraries that each bundle
# it); this masks rather than fixes the clash — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Segmentation classes; colors[i] is the display colour of class_names[i].
colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
cmap = mpl.colors.ListedColormap(colors)

# Per-class weight used to compute the biodiversity score, aligned with
# class_names (e.g. 'Natural green' weighs 4, 'Background' weighs 0).
scores_init = [1, 2, 4, 3, 4, 1, 0]
# Look for an image on Earth Engine and segment it.
# Three ways to limit cloud cover: monthly / bi-monthly / yearly mean image.
def segment_loc(model, location, month, year, how="month", month_end='12', year_end=None):
    """Fetch one Earth Engine image for `location` and segment it with `model`.

    Args:
        model: segmentation model exposing ``net``, ``linear_probe`` and
            ``cfg.n_images`` (as used below).
        location: passed unchanged to ``extract_img``.
        month, year: start of the period as strings ('MM', 'YYYY').
        how: 'month' -> image from day 01 to day 28 of `month`;
            'year' -> image from `year`-`month`-01 to `year_end`-`month_end`-28
            (`year_end` defaults to `year`), fetched with width/len 0.04.
            Day 28 is presumably used so the end date exists in every month.
        month_end, year_end: end of the period, only read when how='year'.

    Returns:
        dict with 'img' (the preprocessed input batch) and 'linear_preds'
        (argmax over the linear probe output), both truncated to
        ``model.cfg.n_images`` and moved to CPU.

    Raises:
        ValueError: if `how` is neither 'month' nor 'year'.
    """
    start = year + '-' + month + '-01'
    if how == 'month':
        img = extract_img(location, start, year + '-' + month + '-28')
    elif how == 'year':
        end_year = year if year_end is None else year_end
        img = extract_img(location, start, end_year + '-' + month_end + '-28', width=0.04, len=0.04)
    else:
        # Previously fell through and failed later with UnboundLocalError on `img`.
        raise ValueError(f"how must be 'month' or 'year', got {how!r}")

    img_test = transform_ee_img(img, max=0.25)
    # Preprocess the image and add a batch dimension; inference runs on CPU.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()
    with torch.no_grad():
        _feats, code = model.net(x)
        linear_preds = model.linear_probe(x, code).argmax(1)
    n_images = model.cfg.n_images
    return {
        'img': x[:n_images].detach().cpu(),
        'linear_preds': linear_preds[:n_images].detach().cpu(),
    }
# Look for every image on Earth Engine in a date range and segment each one;
# each result is paired with its date label.
def segment_group(location, start_date, end_date, how='month', model=None):
    """Segment every period between `start_date` and `end_date` (inclusive).

    Only the year (chars 0-4) and month (chars 5-7) of the dates are read,
    so 'YYYY-MM' and 'YYYY-MM-DD' strings both work.

    Args:
        location: passed unchanged to segment_loc.
        start_date, end_date: 'YYYY-MM...' strings.
        how: 'month' -> one image per month;
            '2months' -> one image per month, each spanning that month and
            the next one;
            'year' -> one image per calendar year covering its selected months.
        model: segmentation model forwarded to segment_loc (required).

    Returns:
        List of ``('YYYY-MM', outputs)`` tuples in chronological order, where
        `outputs` is the dict returned by segment_loc.

    Raises:
        ValueError: if `how` is not one of the supported modes.
        TypeError: if no `model` is given.
    """
    if how not in ('month', '2months', 'year'):
        raise ValueError(f"how must be 'month', '2months' or 'year', got {how!r}")
    if model is None:
        # segment_loc cannot run without a model; fail before any EE request.
        raise TypeError("segment_group() requires a `model` to segment the images")

    st_year, st_month = int(start_date[0:4]), int(start_date[5:7])
    end_year, end_month = int(end_date[0:4]), int(end_date[5:7])

    outputs = []
    for year in range(st_year, end_year + 1):
        # The first and last years may only be partially covered.
        start = st_month if year == st_year else 1
        last = end_month if year == end_year else 12
        year_str = str(year)

        if how == 'year':
            outputs.append((f"{year_str}-{start:02d}",
                            segment_loc(model, location, f"{start:02d}", year_str,
                                        how='year', month_end=f"{last:02d}")))
            continue

        for month in range(start, last + 1):
            month_str = f"{month:02d}"
            if how == 'month':
                result = segment_loc(model, location, month_str, year_str)
            else:  # '2months': this month through the next (December rolls over)
                next_month = month % 12 + 1
                year_end = str(year + 1 if next_month < month else year)
                result = segment_loc(model, location, month_str, year_str, how='year',
                                     month_end=f"{next_month:02d}", year_end=year_end)
            outputs.append((f"{year_str}-{month_str}", result))
    return outputs
def values_from_output(output):
    """Extract plot-ready values from one segmentation output.

    Args:
        output: dict as returned by segment_loc ('img', 'linear_preds').

    Returns:
        img: the input image as an RGB array.
        labeled_img: the image blended (alpha 0.3) with its colour-coded
            labels, as an RGB array.
        nb_values: pixel count of each predicted class 1..len(scores_init),
            aligned with scores_init / class_names.
        score: pixel-weighted mean of scores_init divided by max(scores_init);
            0.0 when no pixel falls in any counted class.
    """
    img, _label, labeled_img = transform_to_pil(output, alpha=0.3)
    img = np.array(img.convert('RGB'))
    labeled_img = np.array(labeled_img.convert('RGB'))

    preds = output['linear_preds'][0]
    # Class c is counted in slot c-1, mirroring the "-1" shift transform_to_pil
    # applies before colouring.
    # NOTE(review): predictions equal to 0 are never counted here (they are
    # drawn with the last colour); confirm the model's class range.
    nb_values = [np.count_nonzero(preds == cls) for cls in range(1, len(scores_init) + 1)]

    total = sum(nb_values)
    if total == 0:
        # Avoid ZeroDivisionError when no pixel belongs to a counted class.
        score = 0.0
    else:
        weighted = sum(w * n for w, n in zip(scores_init, nb_values))
        score = weighted / total / max(scores_init)
    return img, labeled_img, nb_values, score
def values_from_outputs(outputs):
    """Unpack the ``(date_label, segmentation_output)`` pairs from segment_group.

    Returns a tuple of five parallel lists, in the order of `outputs`:
    date labels, RGB images, overlay images, per-class pixel counts, scores.
    """
    columns = ([], [], [], [], [])
    for entry in outputs:
        img, labeled_img, nb_value, score = values_from_output(entry[1])
        row = (entry[0], img, labeled_img, nb_value, score)
        for column, value in zip(columns, row):
            column.append(value)
    return columns
def _segment_pie(values):
    """Return the pie trace for one frame's per-class pixel counts (aligned with class_names)."""
    return go.Pie(
        labels=class_names,
        values=values,
        marker_colors=colors,
        name="Segment repartition",
        textposition='inside',
        texttemplate="%{percent:.0%}",
        textfont_size=14,
    )


def _score_scatters(scores):
    """Return one score-history line per frame.

    Frame k plots scores[0..k] at x = 1..k+1 and labels only its newest
    point with the score rounded to two decimals.
    """
    scatters = []
    for k in range(len(scores)):
        history = scores[:k + 1]
        text = [""] * k + [str(round(history[-1], 2))]
        scatters.append(go.Scatter(
            x=list(range(1, k + 2)),
            y=history,
            mode="lines+markers+text",
            marker_color="black",
            text=text,
            textposition="top center",
        ))
    return scatters


def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Build the animated 4-panel figure for a series of segmented images.

    Panels: raw image, labelled overlay, pie of class repartition, and the
    score history. There is one animation frame per image.

    Args:
        months: frame labels (e.g. 'YYYY-MM'), one per image.
        imgs: RGB image arrays, all of the same shape.
        imgs_label: overlay image arrays, parallel to `imgs`.
        nb_values: per-frame per-class pixel counts.
        scores: per-frame scores.

    Returns:
        A plotly Figure with a Play button and a date slider.

    Raises:
        ValueError: if `imgs` is empty.
    """
    number_frames = len(imgs)
    if number_frames == 0:
        raise ValueError("plot_imgs_labels() needs at least one image")

    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
    scatters = _score_scatters(scores)

    # A Pie trace cannot be placed in an 'xy' cell, so column 3 is a 'domain'
    # cell. Domain cells take no axes, which makes the score plot (column 4)
    # use xaxis3 / yaxis3, as configured below.
    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "xy"}, {"type": "xy"}, {"type": "domain"}, {"type": "xy"}]],
        column_widths=[0.6, 0.6, 0.3, 0.3],
        subplot_titles=("Localisation visualization", "labeled visualisation",
                        "Segments repartition", "Biodiversity scores"),
    )
    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(_segment_pie(nb_values[0]), row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)

    # Each frame replaces traces 0-3 of fig.data: image, overlay, pie, score line.
    frames = [
        dict(
            name=str(k),
            data=[fig2["frames"][k]["data"][0],
                  fig3["frames"][k]["data"][0],
                  _segment_pie(nb_values[k]),
                  scatters[k]],
            traces=[0, 1, 2, 3],
        )
        for k in range(number_frames)
    ]

    updatemenus = [dict(
        type='buttons',
        buttons=[dict(
            label='Play',
            method='animate',
            args=[[str(k) for k in range(number_frames)],
                  dict(frame=dict(duration=500, redraw=False),
                       transition=dict(duration=0),
                       easing='linear',
                       fromcurrent=True,
                       mode='immediate')],
        )],
        direction='left',
        pad=dict(r=10, t=85),
        showactive=True, x=0.1, y=0.13, xanchor='right', yanchor='top',
    )]
    sliders = [{
        'yanchor': 'top',
        'xanchor': 'left',
        'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
        'transition': {'duration': 500.0, 'easing': 'linear'},
        'pad': {'b': 10, 't': 50},
        'len': 0.9, 'x': 0.1, 'y': 0,
        'steps': [{'args': [[str(k)], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                                       'transition': {'duration': 0, 'easing': 'linear'}}],
                   'label': months[k], 'method': 'animate'} for k in range(number_frames)],
    }]

    fig.update(frames=frames)
    height, width = imgs[0].shape[0], imgs[0].shape[1]
    # NOTE(review): the tiny frame-dependent offset makes every frame's axis
    # range differ, which looks like a workaround to force the image traces to
    # refresh while redraw=False — confirm before removing.
    for k, fr in enumerate(fig["frames"]):
        fr.update(layout={
            "xaxis": {"range": [0, width + k / 100000]},
            "yaxis": {"range": [height + k / 100000, 0]},
        })
        fr.update(layout_title_text=months[k])
    last = number_frames - 1

    fig.update(layout_title_text='tot')
    fig.update(
        layout={
            "xaxis": {
                "range": [0, width + last / 100000],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis": {
                "range": [height + last / 100000, 0],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            # Score plot axes (column 4).
            "xaxis3": {
                "range": [0, len(scores) + 1],
                'autorange': False,
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis3": {
                "range": [0, 1.5],
                'autorange': False,
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
        },
        legend=dict(
            yanchor="bottom",
            y=0.99,
            xanchor="center",
            x=0.01,
        ),
    )
    fig.update_layout(updatemenus=updatemenus, sliders=sliders)
    fig.update_layout(margin=dict(b=0, r=0))
    return fig
# Entry point: segment a location over a date range and build the figure.
# how = 'month', '2months' or 'year'.
def segment_region(location, start_date, end_date, how='month', model=None):
    """Segment `location` between `start_date` and `end_date` and plot the result.

    Args:
        location: passed unchanged to segment_group.
        start_date, end_date: 'YYYY-MM...' date strings.
        how: 'month', '2months' or 'year'.
        model: segmentation model used for every image.

    Returns:
        The plotly figure built by plot_imgs_labels.
    """
    # Pass `model` through only when supplied; otherwise segment_group's own
    # default applies.
    extra = {} if model is None else {'model': model}
    outputs = segment_group(location, start_date, end_date, how=how, **extra)
    months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
    return plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
# Input pipeline for the segmentation model: input image -> PIL -> 320x320
# resize -> tensor -> per-channel normalisation.
# NOTE(review): mean/std are the usual ImageNet channel statistics; presumably
# they match what the backbone was trained with — confirm.
preprocess = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((320, 320)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Look for an image on Earth Engine and segment it.
# Three ways to limit cloud cover: monthly / bi-monthly / yearly mean image.
def segment_loc(model, location, month, year, how="month", month_end='12', year_end=None):
    """Fetch one Earth Engine image for `location` and segment it with `model`.

    Args:
        model: segmentation model exposing ``net``, ``linear_probe`` and
            ``cfg.n_images`` (as used below).
        location: passed unchanged to ``extract_img``.
        month, year: start of the period as strings ('MM', 'YYYY').
        how: 'month' -> image from day 01 to day 28 of `month`;
            'year' -> image from `year`-`month`-01 to `year_end`-`month_end`-28
            (`year_end` defaults to `year`), fetched with width/len 0.04.
            Day 28 is presumably used so the end date exists in every month.
        month_end, year_end: end of the period, only read when how='year'.

    Returns:
        dict with 'img' (the preprocessed input batch) and 'linear_preds'
        (argmax over the linear probe output), both truncated to
        ``model.cfg.n_images`` and moved to CPU.

    Raises:
        ValueError: if `how` is neither 'month' nor 'year'.
    """
    start = year + '-' + month + '-01'
    if how == 'month':
        img = extract_img(location, start, year + '-' + month + '-28')
    elif how == 'year':
        end_year = year if year_end is None else year_end
        img = extract_img(location, start, end_year + '-' + month_end + '-28', width=0.04, len=0.04)
    else:
        # Previously fell through and failed later with UnboundLocalError on `img`.
        raise ValueError(f"how must be 'month' or 'year', got {how!r}")

    img_test = transform_ee_img(img, max=0.25)
    # Preprocess the image and add a batch dimension; inference runs on CPU.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()
    with torch.no_grad():
        _feats, code = model.net(x)
        linear_preds = model.linear_probe(x, code).argmax(1)
    n_images = model.cfg.n_images
    return {
        'img': x[:n_images].detach().cpu(),
        'linear_preds': linear_preds[:n_images].detach().cpu(),
    }
# Look for every image on Earth Engine in a date range and segment each one;
# each result is paired with its date label.
def segment_group(location, start_date, end_date, how='month', model=None):
    """Segment every period between `start_date` and `end_date` (inclusive).

    Only the year (chars 0-4) and month (chars 5-7) of the dates are read,
    so 'YYYY-MM' and 'YYYY-MM-DD' strings both work.

    Args:
        location: passed unchanged to segment_loc.
        start_date, end_date: 'YYYY-MM...' strings.
        how: 'month' -> one image per month;
            '2months' -> one image per month, each spanning that month and
            the next one;
            'year' -> one image per calendar year covering its selected months.
        model: segmentation model forwarded to segment_loc (required).

    Returns:
        List of ``('YYYY-MM', outputs)`` tuples in chronological order, where
        `outputs` is the dict returned by segment_loc.

    Raises:
        ValueError: if `how` is not one of the supported modes.
        TypeError: if no `model` is given.
    """
    if how not in ('month', '2months', 'year'):
        raise ValueError(f"how must be 'month', '2months' or 'year', got {how!r}")
    if model is None:
        # segment_loc cannot run without a model; fail before any EE request.
        raise TypeError("segment_group() requires a `model` to segment the images")

    st_year, st_month = int(start_date[0:4]), int(start_date[5:7])
    end_year, end_month = int(end_date[0:4]), int(end_date[5:7])

    outputs = []
    for year in range(st_year, end_year + 1):
        # The first and last years may only be partially covered.
        start = st_month if year == st_year else 1
        last = end_month if year == end_year else 12
        year_str = str(year)

        if how == 'year':
            outputs.append((f"{year_str}-{start:02d}",
                            segment_loc(model, location, f"{start:02d}", year_str,
                                        how='year', month_end=f"{last:02d}")))
            continue

        for month in range(start, last + 1):
            month_str = f"{month:02d}"
            if how == 'month':
                result = segment_loc(model, location, month_str, year_str)
            else:  # '2months': this month through the next (December rolls over)
                next_month = month % 12 + 1
                year_end = str(year + 1 if next_month < month else year)
                result = segment_loc(model, location, month_str, year_str, how='year',
                                     month_end=f"{next_month:02d}", year_end=year_end)
            outputs.append((f"{year_str}-{month_str}", result))
    return outputs
def transform_to_pil(outputs, alpha=0.3):
    """Convert a segmentation output into PIL images for display.

    Args:
        outputs: dict with 'img' (image batch) and 'linear_preds' (predicted
            class-map batch); only element 0 of each is used.
        alpha: blend weight of the colour-coded label layer over the image.

    Returns:
        (img, label, labeled_img): the image, the colour-coded label map and
        their alpha blend (converted to RGBA before blending).
    """
    # prep_for_plot presumably yields a channels-last image; move the last
    # axis to the front before ToPILImage — confirm.
    channels_first = torch.moveaxis(prep_for_plot(outputs['img'][0]), -1, 0)
    img = T.ToPILImage()(channels_first)

    # Lookup table: row i is the RGBA colour of cmap entry i.
    palette = np.stack([np.array(cmap(idx)) for idx in range(cmap.N)])
    # Predictions are shifted down by one before indexing; a prediction of 0
    # becomes -1 and so (numpy negative indexing) picks the LAST palette colour.
    class_idx = np.array(outputs['linear_preds'][0]) - 1
    label = T.ToPILImage()((palette[class_idx] * 255).astype(np.uint8))

    labeled_img = Image.blend(img.convert("RGBA"), label.convert("RGBA"), alpha)
    return img, label, labeled_img
def values_from_output(output):
    """Extract plot-ready values from one segmentation output.

    Args:
        output: dict as returned by segment_loc ('img', 'linear_preds').

    Returns:
        img: the input image as an RGB array.
        labeled_img: the image blended (alpha 0.3) with its colour-coded
            labels, as an RGB array.
        nb_values: pixel count of each predicted class 1..len(scores_init),
            aligned with scores_init / class_names.
        score: pixel-weighted mean of scores_init divided by max(scores_init);
            0.0 when no pixel falls in any counted class.
    """
    img, _label, labeled_img = transform_to_pil(output, alpha=0.3)
    img = np.array(img.convert('RGB'))
    labeled_img = np.array(labeled_img.convert('RGB'))

    preds = output['linear_preds'][0]
    # Class c is counted in slot c-1, mirroring the "-1" shift transform_to_pil
    # applies before colouring.
    # NOTE(review): predictions equal to 0 are never counted here (they are
    # drawn with the last colour); confirm the model's class range.
    nb_values = [np.count_nonzero(preds == cls) for cls in range(1, len(scores_init) + 1)]

    total = sum(nb_values)
    if total == 0:
        # Avoid ZeroDivisionError when no pixel belongs to a counted class.
        score = 0.0
    else:
        weighted = sum(w * n for w, n in zip(scores_init, nb_values))
        score = weighted / total / max(scores_init)
    return img, labeled_img, nb_values, score
def values_from_outputs(outputs):
    """Unpack the ``(date_label, segmentation_output)`` pairs from segment_group.

    Each output is turned into an image, an overlay image, per-class pixel
    counts and a score by values_from_output. Returns a tuple of five
    parallel lists, in the order of `outputs`: date labels, RGB images,
    overlay images, per-class pixel counts, scores.
    """
    columns = ([], [], [], [], [])
    for entry in outputs:
        img, labeled_img, nb_value, score = values_from_output(entry[1])
        row = (entry[0], img, labeled_img, nb_value, score)
        for column, value in zip(columns, row):
            column.append(value)
    return columns
# Entry point: segment a location over a date range and build the figure.
# how = 'month', '2months' or 'year'.
def segment_region(latitude, longitude, start_date, end_date, how='month', model=None):
    """Segment the point (`latitude`, `longitude`) over a date range and plot it.

    Args:
        latitude, longitude: anything float() accepts.
        start_date, end_date: 'YYYY-MM...' date strings.
        how: 'month', '2months' or 'year'. A list/tuple of choices is also
            accepted; its first item is used.
        model: segmentation model used for every image.

    Returns:
        The plotly figure built by plot_imgs_labels.
    """
    location = [float(latitude), float(longitude)]
    # Only unwrap sequences: indexing a plain string would keep just its first
    # character ('month' -> 'm'), which matches no mode.
    if isinstance(how, (list, tuple)):
        how = how[0]
    # Pass `model` through only when supplied; otherwise segment_group's own
    # default applies.
    extra = {} if model is None else {'model': model}
    outputs = segment_group(location, start_date, end_date, how=how, **extra)
    months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
    return plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)