from utils import *
from modules import *
from data import *
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datetime import datetime
import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
import torch.multiprocessing
import seaborn as sns
from pytorch_lightning.callbacks import ModelCheckpoint
import sys
import pdb
import matplotlib as mpl
from skimage import measure
from scipy.stats import mode as statsmode
from collections import OrderedDict
import unet

torch.multiprocessing.set_sharing_strategy("file_system")

# Color map for visualizing the seven segmentation classes.
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = (
    "Buildings",
    "Cultivation",
    "Natural green",
    "Wetland",
    "Water",
    "Infrastructure",
    "Background",
)
bounds = list(np.arange(len(class_names) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)


def retouch_label(pred_label, true_label):
    # Copy so the caller's prediction is not modified in place.
    retouched_label = pred_label + 0
    blobs = measure.label(retouched_label)
    for idx in np.unique(blobs):
        # Assign each connected blob the most frequent true-label class inside it.
        # Note: the [0][0] indexing assumes the pre-SciPy-1.11 ModeResult layout,
        # in which the mode is returned as an array.
        retouched_label[blobs == idx] = statsmode(true_label[blobs == idx])[0][0]
    return retouched_label


def get_class_labels(dataset_name):
    if dataset_name.startswith("cityscapes"):
        return [
            "road", "sidewalk", "parking", "rail track", "building", "wall",
            "fence", "guard rail", "bridge", "tunnel", "pole", "polegroup",
            "traffic light", "traffic sign", "vegetation", "terrain", "sky",
            "person", "rider", "car", "truck", "bus", "caravan", "trailer",
            "train", "motorcycle", "bicycle",
        ]
    elif dataset_name == "cocostuff27":
        return [
            "electronic", "appliance", "food", "furniture", "indoor", "kitchen",
            "accessory", "animal", "outdoor", "person", "sports", "vehicle",
            "ceiling", "floor", "food", "furniture", "rawmaterial", "textile",
            "wall", "window", "building", "ground", "plant", "sky", "solid",
            "structural", "water",
        ]
    elif dataset_name == "voc":
        return [
            "background", "aeroplane", "bicycle", "bird", "boat", "bottle",
            "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse",
            "motorbike", "person", "pottedplant", "sheep", "sofa", "train",
            "tvmonitor",
        ]
    elif dataset_name == "potsdam":
        return ["roads and cars", "buildings and clutter", "trees and vegetation"]
    else:
        raise ValueError("Unknown Dataset {}".format(dataset_name))


@hydra.main(config_path="configs", config_name="train_config.yml")
def my_app(cfg: DictConfig) -> None:
    OmegaConf.set_struct(cfg, False)
    print(OmegaConf.to_yaml(cfg))

    pytorch_data_dir = cfg.pytorch_data_dir
    data_dir = join(cfg.output_root, "data")
    log_dir = join(cfg.output_root, "logs")
    checkpoint_dir = join(cfg.output_root, "checkpoints")

    prefix = "{}/{}_{}".format(cfg.log_dir, cfg.dataset_name, cfg.experiment_name)
    name = "{}_date_{}".format(prefix, datetime.now().strftime("%b%d_%H-%M-%S"))
    cfg.full_name = prefix

    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    seed_everything(seed=0)

    print(data_dir)
    print(cfg.output_root)

    # Augmentations used by the contrastive training dataset.
    geometric_transforms = T.Compose(
        [T.RandomHorizontalFlip(), T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0))]
    )
    photometric_transforms = T.Compose(
        [
            T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            T.RandomGrayscale(0.2),
            T.RandomApply([T.GaussianBlur((5, 5))]),
        ]
    )

    sys.stdout.flush()

    train_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=cfg.crop_type,
        image_set="train",
        transform=get_transform(cfg.res, False, cfg.loader_crop_type),
        target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
        cfg=cfg,
        aug_geometric_transform=geometric_transforms,
        aug_photometric_transform=photometric_transforms,
        num_neighbors=cfg.num_neighbors,
        mask=True,
        pos_images=True,
        pos_labels=True,
    )

    # Validation uses full images for VOC and a 320px center crop otherwise.
    if cfg.dataset_name == "voc":
        val_loader_crop = None
    else:
        val_loader_crop = "center"

    val_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=None,
        image_set="val",
        transform=get_transform(320, False, val_loader_crop),
        target_transform=get_transform(320, True, val_loader_crop),
        mask=True,
        cfg=cfg,
    )
    # val_dataset = MaterializedDataset(val_dataset)

    train_loader = DataLoader(
        train_dataset,
        cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    if cfg.submitting_to_aml:
        val_batch_size = 16
    else:
        val_batch_size = cfg.batch_size

    val_loader = DataLoader(
        val_dataset,
        val_batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)

    tb_logger = TensorBoardLogger(join(log_dir, name), default_hp_metric=False)

    if cfg.submitting_to_aml:
        gpu_args = dict(gpus=1, val_check_interval=250)
        if gpu_args["val_check_interval"] > len(train_loader):
            gpu_args.pop("val_check_interval")
    else:
        gpu_args = dict(gpus=-1, accelerator="ddp", val_check_interval=cfg.val_freq)
        # gpu_args = dict(gpus=1, accelerator='ddp', val_check_interval=cfg.val_freq)
        # Drop the explicit interval when it exceeds the available training batches,
        # falling back to Lightning's default validation schedule.
        if gpu_args["val_check_interval"] > len(train_loader) // 4:
            gpu_args.pop("val_check_interval")

    trainer = Trainer(
        log_every_n_steps=cfg.scalar_log_freq,
        logger=tb_logger,
        max_steps=cfg.max_steps,
        callbacks=[
            ModelCheckpoint(
                dirpath=join(checkpoint_dir, name),
                every_n_train_steps=400,
                save_top_k=2,
                monitor="test/cluster/mIoU",
                mode="max",
            )
        ],
        **gpu_args
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    prep_args()
    my_app()
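
# Example invocation sketch. The script filename and the override values below are
# illustrative assumptions; any key defined in configs/train_config.yml (e.g.
# dataset_name, batch_size, num_workers, res) can be overridden on the Hydra CLI:
#
#   python train_segmentation.py dataset_name=potsdam batch_size=16 num_workers=8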