Spaces:
Runtime error
Runtime error
from utils import * | |
from modules import * | |
from data import * | |
from torch.utils.data import DataLoader | |
import torch.nn.functional as F | |
from datetime import datetime | |
import hydra | |
from omegaconf import DictConfig, OmegaConf | |
import pytorch_lightning as pl | |
from pytorch_lightning import Trainer | |
from pytorch_lightning.loggers import TensorBoardLogger | |
from pytorch_lightning.utilities.seed import seed_everything | |
import torch.multiprocessing | |
import seaborn as sns | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
import sys | |
import pdb | |
import matplotlib as mpl | |
from skimage import measure | |
from scipy.stats import mode as statsmode | |
from collections import OrderedDict | |
import unet | |
import pdb | |
# Share tensors between processes through the file system (torch.multiprocessing
# sharing strategy); executed once, when this module is imported.
torch.multiprocessing.set_sharing_strategy("file_system")

# Land-cover legend: each class name paired with the colour used to draw it.
# Position i in ``class_names`` and in ``colors`` always refers to the same class.
_CLASS_LEGEND = (
    ("Buildings", "red"),
    ("Cultivation", "palegreen"),
    ("Natural green", "green"),
    ("Wetland", "steelblue"),
    ("Water", "blue"),
    ("Infrastructure", "yellow"),
    ("Background", "lightgrey"),
)
class_names = tuple(label for label, _ in _CLASS_LEGEND)
colors = tuple(colour for _, colour in _CLASS_LEGEND)

# Boundaries 1, 2, ..., len(class_names) + 1. With BoundaryNorm, a value in
# [k, k + 1) lands in bin k - 1 and is drawn with colors[k - 1].
bounds = list(np.arange(1, len(class_names) + 2))
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
def retouch_label(pred_label, true_label):
    """Relabel each connected region of ``pred_label`` with its majority true class.

    ``pred_label`` is split into connected components of equal value with
    ``skimage.measure.label``. Every pixel of a component is then overwritten
    with the most frequent value of ``true_label`` over that component. Ties
    resolve to the smallest value, as ``scipy.stats.mode`` does.

    Args:
        pred_label: Integer label map (array-like); it is not modified.
        true_label: Array of the same shape holding the reference labels.

    Returns:
        A NumPy copy of ``pred_label`` with every region relabelled.
    """
    retouched_label = np.array(pred_label, copy=True)
    if retouched_label.size == 0:
        return retouched_label

    # skimage treats value 0 as background by default and gives all class-0
    # pixels the single label 0, regardless of connectivity. Choosing a
    # background value below every label makes class-0 regions get their own
    # connected components, like every other class.
    background = int(retouched_label.min()) - 1
    blobs = measure.label(retouched_label, background=background)

    for blob_id in np.unique(blobs):
        in_blob = blobs == blob_id  # compute the mask once per blob
        # Majority vote via np.unique rather than scipy.stats.mode(...)[0][0],
        # which raises IndexError on SciPy >= 1.11 (mode now returns a scalar).
        values, counts = np.unique(true_label[in_blob], return_counts=True)
        retouched_label[in_blob] = values[np.argmax(counts)]
    return retouched_label
def get_class_labels(dataset_name):
    """Return the ordered list of class names for a supported dataset.

    Any name starting with ``"cityscapes"`` selects the 27 Cityscapes classes.
    ``"cocostuff27"``, ``"voc"`` and ``"potsdam"`` must match exactly.

    Args:
        dataset_name: Dataset identifier string.

    Returns:
        A newly built list of class names, so callers may mutate it freely.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name.startswith("cityscapes"):
        labels = (
            "road", "sidewalk", "parking", "rail track", "building", "wall",
            "fence", "guard rail", "bridge", "tunnel", "pole", "polegroup",
            "traffic light", "traffic sign", "vegetation", "terrain", "sky",
            "person", "rider", "car", "truck", "bus", "caravan", "trailer",
            "train", "motorcycle", "bicycle",
        )
    else:
        # NOTE(review): cocostuff27 lists "food" and "furniture" twice,
        # presumably the "things" and "stuff" variants -- confirm upstream.
        exact_match = {
            "cocostuff27": (
                "electronic", "appliance", "food", "furniture", "indoor",
                "kitchen", "accessory", "animal", "outdoor", "person",
                "sports", "vehicle", "ceiling", "floor", "food", "furniture",
                "rawmaterial", "textile", "wall", "window", "building",
                "ground", "plant", "sky", "solid", "structural", "water",
            ),
            "voc": (
                "background", "aeroplane", "bicycle", "bird", "boat",
                "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
                "dog", "horse", "motorbike", "person", "pottedplant", "sheep",
                "sofa", "train", "tvmonitor",
            ),
            "potsdam": (
                "roads and cars", "buildings and clutter", "trees and vegetation",
            ),
        }
        labels = exact_match.get(dataset_name)
        if labels is None:
            raise ValueError("Unknown Dataset {}".format(dataset_name))
    return list(labels)
def my_app(cfg: DictConfig) -> None:
    """Train a ``LitUnsupervisedSegmenter`` from an OmegaConf configuration.

    Prepares the output directories, builds the augmented training set and
    the validation set, wraps both in DataLoaders, and runs ``Trainer.fit``.
    Checkpoints are written every 400 training steps, keeping the two best
    by ``test/cluster/mIoU``.

    Args:
        cfg: Experiment configuration. Struct mode is disabled so that
            ``cfg.full_name`` can be added at runtime.
    """
    OmegaConf.set_struct(cfg, False)
    print(OmegaConf.to_yaml(cfg))

    pytorch_data_dir = cfg.pytorch_data_dir
    data_dir, log_dir, checkpoint_dir = [
        join(cfg.output_root, subdir) for subdir in ("data", "logs", "checkpoints")
    ]

    prefix = "{}/{}_{}".format(cfg.log_dir, cfg.dataset_name, cfg.experiment_name)
    timestamp = datetime.now().strftime("%b%d_%H-%M-%S")
    name = "{}_date_{}".format(prefix, timestamp)
    cfg.full_name = prefix

    for directory in (data_dir, log_dir, checkpoint_dir):
        os.makedirs(directory, exist_ok=True)

    seed_everything(seed=0)
    print(data_dir)
    print(cfg.output_root)

    # Augmentations handed to the training dataset: geometric (flip + crop)
    # and photometric (colour jitter, grayscale, blur).
    geometric_transforms = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0)),
    ])
    photometric_transforms = T.Compose([
        T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        T.RandomGrayscale(0.2),
        T.RandomApply([T.GaussianBlur((5, 5))]),
    ])

    sys.stdout.flush()

    train_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=cfg.crop_type,
        image_set="train",
        transform=get_transform(cfg.res, False, cfg.loader_crop_type),
        target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
        cfg=cfg,
        aug_geometric_transform=geometric_transforms,
        aug_photometric_transform=photometric_transforms,
        num_neighbors=cfg.num_neighbors,
        mask=True,
        pos_images=True,
        pos_labels=True,
    )

    # VOC validation uses no loader crop (None); every other dataset uses "center".
    val_loader_crop = None if cfg.dataset_name == "voc" else "center"
    val_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=None,
        image_set="val",
        transform=get_transform(320, False, val_loader_crop),
        target_transform=get_transform(320, True, val_loader_crop),
        mask=True,
        cfg=cfg,
    )

    train_loader = DataLoader(
        train_dataset,
        cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    val_batch_size = 16 if cfg.submitting_to_aml else cfg.batch_size
    val_loader = DataLoader(
        val_dataset,
        val_batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)
    tb_logger = TensorBoardLogger(join(log_dir, name), default_hp_metric=False)

    if cfg.submitting_to_aml:
        gpu_args = {"gpus": 1, "val_check_interval": 250}
        interval_limit = len(train_loader)
    else:
        gpu_args = {"gpus": -1, "accelerator": "ddp", "val_check_interval": cfg.val_freq}
        interval_limit = len(train_loader) // 4
    # An interval larger than the limit is dropped, leaving Lightning's
    # default validation schedule in place.
    if gpu_args["val_check_interval"] > interval_limit:
        del gpu_args["val_check_interval"]

    checkpoint_callback = ModelCheckpoint(
        dirpath=join(checkpoint_dir, name),
        every_n_train_steps=400,
        save_top_k=2,
        monitor="test/cluster/mIoU",
        mode="max",
    )
    trainer = Trainer(
        log_every_n_steps=cfg.scalar_log_freq,
        logger=tb_logger,
        max_steps=cfg.max_steps,
        callbacks=[checkpoint_callback],
        **gpu_args
    )
    trainer.fit(model, train_loader, val_loader)
if __name__ == "__main__":
    # Script entry point. prep_args comes from a star import (presumably
    # utils) and likely rewrites sys.argv for Hydra -- TODO confirm.
    prep_args()
    # NOTE(review): my_app declares a required ``cfg`` parameter and no
    # @hydra.main decorator is visible on it in this file, so this bare call
    # raises TypeError unless something else supplies cfg. Confirm whether a
    # decorator such as @hydra.main(config_path=..., config_name=...) was lost.
    my_app()