import torch.multiprocessing
import torchvision.transforms as T
import numpy as np
from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels, plot_image
from utils_gee import get_image
from dateutil.relativedelta import relativedelta
from model import LitUnsupervisedSegmenter
import datetime
import matplotlib as mpl
from joblib import Parallel, cpu_count, delayed
import logging
from inference import inference

import streamlit as st
import cv2

# Length, in months, of each image window for the supported ``how`` values.
# "year" spans 12 months so successive windows start on January 1st and the
# last window ends exactly on January 1st of ``end_date``.
_INTERVAL_MONTHS = {"month": 1, "2months": 2, "year": 12}


def _build_date_range(start_year, end_year, delta_month):
    """Return "%Y-%m-%d" window boundaries starting on January 1st of ``start_year``.

    Boundaries are spaced ``delta_month`` months apart. The last boundary is
    the first one that is at or after January 1st of ``end_year``.
    """
    end = datetime.datetime(end_year, 1, 1, 0, 0, 0)
    boundaries = [datetime.datetime(start_year, 1, 1, 0, 0, 0)]
    while boundaries[-1] < end:
        boundaries.append(boundaries[-1] + relativedelta(months=delta_month))
    return [d.strftime("%Y-%m-%d") for d in boundaries]


@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location(model, longitude=2.98, latitude=48.81, start_date=2020, end_date=2022, how="year"):
    """Perform an inference on a location for successive time windows between two years.

    One image is fetched with ``get_image`` for every window between January 1st
    of ``start_date`` and January 1st of ``end_date``. The images are segmented
    with ``inference`` and a biodiversity score is computed for each of them.

    Args:
        model (LitUnsupervisedSegmenter): the model passed to ``inference``.
        longitude (float): the longitude of the landscape.
        latitude (float): the latitude of the landscape.
        start_date (int | str): first year of the analysis (anything ``int()`` accepts).
        end_date (int | str): final year of the analysis; must be strictly greater
            than ``start_date``.
        how (str): window length for each image: "month", "2months" or "year".

    Returns:
        The figure built by ``plot_imgs_labels`` from the window dates, the original
        images, the labeled images and the biodiversity scores.

    Raises:
        ValueError: if ``how`` is not a supported interval.
        AssertionError: if ``end_date`` is not strictly greater than ``start_date``.
    """
    logging.info("Running Inference on location")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    logging.info(f"start date : {start_date} & end_date : {end_date}")
    logging.info(f"Prediction on intervale : {how}")

    if how not in _INTERVAL_MONTHS:
        raise ValueError("Wrong interval")
    delta_month = _INTERVAL_MONTHS[how]

    # Convert both years up front: the original only converted end_date when
    # building the dates, so a string start_date crashed datetime().
    start_year, end_year = int(start_date), int(end_date)
    assert end_year > start_year, "end date must be stricly higher than start date"

    location = [float(latitude), float(longitude)]
    dates = _build_date_range(start_year, end_year, delta_month)

    # One get_image call per consecutive [d1, d2] window, run in a thread pool.
    # NOTE(review): threads assume get_image is I/O-bound (remote Earth Engine
    # requests); confirm against utils_gee.
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    outputs = inference(np.array(all_image), model)

    logging.info("Calculating Biodiversity Scores...")
    scores, scores_details = map(
        list,
        zip(*[compute_biodiv_score(output["linear_preds"].detach().numpy()) for output in outputs]),
    )
    logging.info(f"Calculated Biodiversity Score : {scores}")

    imgs, _labels, labeled_imgs = map(list, zip(*[transform_to_pil(output) for output in outputs]))
    images = [np.asarray(img) for img in imgs]
    labeled_imgs = [np.asarray(img) for img in labeled_imgs]

    return plot_imgs_labels(dates, images, labeled_imgs, scores_details, scores)


@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location_and_month(model, longitude=2.98, latitude=48.81, start_date='2020-03-20'):
    """Perform an inference on a location for the one-month window starting at ``start_date``.

    Args:
        model (LitUnsupervisedSegmenter): the model passed to ``inference``.
        longitude (float): the longitude of the landscape.
        latitude (float): the latitude of the landscape.
        start_date (str): first day of the window, formatted "%Y-%m-%d".

    Returns:
        The figure built by ``plot_image`` from the original image, the labeled image
        and the biodiversity score.
    """
    logging.info("Running Inference on location and month")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    location = [float(latitude), float(longitude)]

    # The window ends exactly one calendar month after start_date.
    end = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
    end_date = end.strftime("%Y-%m-%d")

    img_test = get_image(location, start_date, end_date)
    outputs = inference(np.array([img_test]), model)

    logging.info("Calculating Biodiversity Score...")
    score, score_details = compute_biodiv_score(outputs[0]["linear_preds"].detach().numpy())
    logging.info(f"Calculated Biodiversity Score : {score}")

    img, _label, labeled_img = transform_to_pil(outputs[0])
    return plot_image([start_date], [np.asarray(img)], [np.asarray(labeled_img)], [score_details], [score])


if __name__ == "__main__":
    import hydra
    import sys

    # Log to both biomap.log and stdout. The UTF-8 encoding must be set on the
    # FileHandler itself: logging.basicConfig only uses ``encoding`` together
    # with ``filename`` and ignores it when ``handlers`` is given.
    file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info(f"config : {cfg}")

    # Build the model, then load its weights on CPU.
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    logging.info("Model Initialiazed")

    model_path = "biomap/checkpoint/model/model.pt"
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    logging.info("Model Loaded")

    inference_on_location(model)