# Spaces: Runtime error
from utils import * | |
from modules import * | |
from data import * | |
import torch.nn.functional as F | |
import pytorch_lightning as pl | |
import torch.multiprocessing | |
import seaborn as sns | |
import unet | |
class LitUnsupervisedSegmenter(pl.LightningModule):
    """Lightning module for unsupervised semantic segmentation.

    A featurizer (``self.net``) maps images to ``(feats, code)``. The code is
    trained with a weighted sum of contrastive-correspondence, reconstruction,
    augmentation-alignment and CRF losses; each term is enabled by a positive
    weight in ``cfg``. Two probes are trained on a *detached* copy of the code,
    so their losses do not back-propagate into the featurizer:

    * ``linear_probe``: despite its name, a ``unet.AuxUNet`` that receives the
      image plus the code and is trained with cross-entropy on the labels;
    * ``cluster_probe``: a ``ClusterLookup`` with
      ``n_classes + cfg.extra_clusters`` outputs.

    Optimization is manual (``automatic_optimization = False``). Three Adam
    optimizers are returned by :meth:`configure_optimizers` and stepped in
    :meth:`training_step`.

    Args:
        n_classes: Number of ground-truth classes.
        cfg: Experiment configuration with attribute access (presumably an
            OmegaConf ``DictConfig`` -- TODO confirm). Every field read by this
            class must be present.
    """

    # Learning rate of both probe optimizers. Shared by configure_optimizers
    # and the probe reset in training_step so the two cannot drift apart.
    _PROBE_LR = 5e-3

    def __init__(self, n_classes, cfg):
        super().__init__()
        self.name = "LitUnsupervisedSegmenter"
        self.cfg = cfg
        self.n_classes = n_classes

        # Code width: one channel per class unless cfg.continuous is set.
        dim = cfg.dim if cfg.continuous else n_classes

        data_dir = join(cfg.output_root, "data")
        if cfg.arch == "feature-pyramid":
            cut_model = load_model(cfg.model_type, data_dir).cuda()
            self.net = FeaturePyramidNet(
                cfg.granularity, cut_model, dim, cfg.continuous
            )
        elif cfg.arch == "dino":
            self.net = DinoFeaturizer(dim, cfg)
        else:
            raise ValueError("Unknown arch {}".format(cfg.arch))

        # NOTE(review): not referenced by any method of this class; presumably
        # kept so existing checkpoints still load -- confirm before removing.
        self.train_cluster_probe = ClusterLookup(dim, n_classes)
        self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters)

        # Supervised probe: a UNet over the 3-channel input image that also
        # receives the code as an auxiliary input.
        # NOTE(review): aux_ch is hard-coded to 70 rather than `dim`; confirm
        # whether unet.AuxUNet requires it to equal the code dimension.
        self.linear_probe = unet.AuxUNet(
            enc_chs=(3, 32, 64, 128, 256),
            dec_chs=(256, 128, 64, 32),
            aux_ch=70,
            num_class=n_classes,
        )
        # Maps the code back to feature width for the reconstruction loss.
        self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1))

        self.cluster_metrics = UnsupervisedMetrics(
            "test/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False)
        self.test_cluster_metrics = UnsupervisedMetrics(
            "final/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.test_linear_metrics = UnsupervisedMetrics(
            "final/linear/", n_classes, 0, False
        )

        self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
        self.crf_loss_fn = ContrastiveCRFLoss(
            cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift
        )
        self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg)
        # The correspondence-loss module is frozen: its parameters never
        # require gradients and are not handed to any optimizer.
        for p in self.contrastive_corr_loss_fn.parameters():
            p.requires_grad = False

        # training_step zeroes, back-propagates and steps the optimizers itself.
        self.automatic_optimization = False

        if self.cfg.dataset_name.startswith("cityscapes"):
            self.label_cmap = create_cityscapes_colormap()
        else:
            self.label_cmap = create_pascal_label_colormap()

        self.val_steps = 0
        self.save_hyperparameters()

    def forward(self, x):
        """Return the code that the featurizer produces for images ``x``."""
        # self.net returns (feats, code); inference only needs the code.
        return self.net(x)[1]

    def training_step(self, batch, batch_idx):
        """Run one manually optimized training step.

        Computes the enabled featurizer losses, the supervised-probe
        cross-entropy and the cluster-probe loss on the detached code. It sums
        them, back-propagates once and steps all three optimizers.

        Args:
            batch: Dict containing at least ``img``, ``img_aug``, ``coord_aug``,
                ``img_pos``, ``label`` and ``label_pos``. It must also contain
                ``mask`` and ``mask_pos`` when ``cfg.use_salience`` is set.
            batch_idx: Unused; required by the Lightning hook signature.

        Returns:
            The total loss tensor.
        """
        net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers()
        net_optim.zero_grad()
        linear_probe_optim.zero_grad()
        cluster_probe_optim.zero_grad()

        img = batch["img"]
        img_aug = batch["img_aug"]
        coord_aug = batch["coord_aug"]
        img_pos = batch["img_pos"]
        label = batch["label"]
        label_pos = batch["label_pos"]

        use_correspondence = self.cfg.correspondence_weight > 0
        feats, code = self.net(img)
        # The positive pair is only needed by the correspondence loss. Bind a
        # default so `signal_pos = feats_pos` below cannot raise NameError
        # when that loss is disabled.
        feats_pos, code_pos = None, None
        if use_correspondence:
            feats_pos, code_pos = self.net(img_pos)
        log_args = dict(sync_dist=False, rank_zero_only=True)

        if self.cfg.use_true_labels:
            # Shift labels by one so -1 becomes a valid one-hot index.
            signal = one_hot_feats(label + 1, self.n_classes + 1)
            signal_pos = one_hot_feats(label_pos + 1, self.n_classes + 1)
        else:
            signal = feats
            signal_pos = feats_pos

        loss = 0
        should_log_hist = (
            self.cfg.hist_freq is not None
            and self.global_step % self.cfg.hist_freq == 0
            and self.global_step > 0
        )
        if self.cfg.use_salience:
            salience = batch["mask"].to(torch.float32).squeeze(1)
            salience_pos = batch["mask_pos"].to(torch.float32).squeeze(1)
        else:
            salience = None
            salience_pos = None

        if use_correspondence:
            (
                pos_intra_loss,
                pos_intra_cd,
                pos_inter_loss,
                pos_inter_cd,
                neg_inter_loss,
                neg_inter_cd,
            ) = self.contrastive_corr_loss_fn(
                signal,
                signal_pos,
                salience,
                salience_pos,
                code,
                code_pos,
            )
            if should_log_hist:
                experiment = self.logger.experiment
                experiment.add_histogram("intra_cd", pos_intra_cd, self.global_step)
                experiment.add_histogram("inter_cd", pos_inter_cd, self.global_step)
                experiment.add_histogram("neg_cd", neg_inter_cd, self.global_step)
            neg_inter_loss = neg_inter_loss.mean()
            pos_intra_loss = pos_intra_loss.mean()
            pos_inter_loss = pos_inter_loss.mean()
            self.log("loss/pos_intra", pos_intra_loss, **log_args)
            self.log("loss/pos_inter", pos_inter_loss, **log_args)
            self.log("loss/neg_inter", neg_inter_loss, **log_args)
            self.log("cd/pos_intra", pos_intra_cd.mean(), **log_args)
            self.log("cd/pos_inter", pos_inter_cd.mean(), **log_args)
            self.log("cd/neg_inter", neg_inter_cd.mean(), **log_args)
            loss += (
                self.cfg.pos_inter_weight * pos_inter_loss
                + self.cfg.pos_intra_weight * pos_intra_loss
                + self.cfg.neg_inter_weight * neg_inter_loss
            ) * self.cfg.correspondence_weight

        if self.cfg.rec_weight > 0:
            # Negative cosine similarity between decoded code and features.
            rec_feats = self.decoder(code)
            rec_loss = -(norm(rec_feats) * norm(feats)).sum(1).mean()
            self.log("loss/rec", rec_loss, **log_args)
            loss += self.cfg.rec_weight * rec_loss

        if self.cfg.aug_alignment_weight > 0:
            orig_feats_aug, orig_code_aug = self.net(img_aug)
            downsampled_coord_aug = resize(
                coord_aug.permute(0, 3, 1, 2), orig_code_aug.shape[2]
            ).permute(0, 2, 3, 1)
            aug_alignment = -torch.einsum(
                "bkhw,bkhw->bhw",
                norm(sample(code, downsampled_coord_aug)),
                norm(orig_code_aug),
            ).mean()
            self.log("loss/aug_alignment", aug_alignment, **log_args)
            loss += self.cfg.aug_alignment_weight * aug_alignment

        if self.cfg.crf_weight > 0:
            crf = self.crf_loss_fn(resize(img, 56), norm(resize(code, 56))).mean()
            self.log("loss/crf", crf, **log_args)
            loss += self.cfg.crf_weight * crf

        flat_label = label.reshape(-1)
        # Only pixels whose label is a valid class id contribute to the probe loss.
        mask = (flat_label >= 0) & (flat_label < self.n_classes)
        # Probes train on a detached copy so they cannot update the featurizer.
        detached_code = torch.clone(code.detach())

        linear_logits = self.linear_probe(img, detached_code)
        linear_logits = F.interpolate(
            linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
        )
        linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
        linear_loss = self.linear_probe_loss_fn(
            linear_logits[mask], flat_label[mask]
        ).mean()
        loss += linear_loss
        self.log("loss/linear", linear_loss, **log_args)

        cluster_loss, cluster_probs = self.cluster_probe(detached_code, None)
        loss += cluster_loss
        self.log("loss/cluster", cluster_loss, **log_args)
        self.log("loss/total", loss, **log_args)

        self.manual_backward(loss)
        net_optim.step()
        cluster_probe_optim.step()
        linear_probe_optim.step()

        if (
            self.cfg.reset_probe_steps is not None
            and self.global_step == self.cfg.reset_probe_steps
        ):
            print("RESETTING PROBES")
            # NOTE(review): assumes unet.AuxUNet implements reset_parameters().
            # Also confirm that replacing trainer.optimizers entries is picked
            # up by self.optimizers() in the installed Lightning version.
            self.linear_probe.reset_parameters()
            self.cluster_probe.reset_parameters()
            self.trainer.optimizers[1] = torch.optim.Adam(
                list(self.linear_probe.parameters()), lr=self._PROBE_LR
            )
            self.trainer.optimizers[2] = torch.optim.Adam(
                list(self.cluster_probe.parameters()), lr=self._PROBE_LR
            )

        if self.global_step % 2000 == 0 and self.global_step > 0:
            print("RESETTING TFEVENT FILE")
            # Make a new tfevent file
            self.logger.experiment.close()
            self.logger.experiment._get_file_writer()

        return loss

    def on_train_start(self):
        """Log the config as hyperparameters together with the current probe metrics."""
        tb_metrics = {**self.linear_metrics.compute(), **self.cluster_metrics.compute()}
        self.logger.log_hyperparams(self.cfg, tb_metrics)

    def validation_step(self, batch, batch_idx):
        """Update both probe metrics on one validation batch.

        Args:
            batch: Dict containing at least ``img`` and ``label``.
            batch_idx: Unused; required by the Lightning hook signature.

        Returns:
            Dict holding the first ``cfg.n_images`` images, labels and
            predictions (detached, on CPU), for plotting in
            :meth:`validation_epoch_end`.
        """
        img = batch["img"]
        label = batch["label"]
        self.net.eval()
        with torch.no_grad():
            feats, code = self.net(img)

            # Resize the UNet logits to the label grid before argmax, as
            # training_step does. This is an identity when sizes already match.
            linear_logits = self.linear_probe(img, code)
            linear_logits = F.interpolate(
                linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
            )
            linear_preds = linear_logits.argmax(1)
            self.linear_metrics.update(linear_preds, label)

            code = F.interpolate(
                code, label.shape[-2:], mode="bilinear", align_corners=False
            )
            cluster_loss, cluster_preds = self.cluster_probe(code, None)
            cluster_preds = cluster_preds.argmax(1)
            self.cluster_metrics.update(cluster_preds, label)

        n_images = self.cfg.n_images
        return {
            "img": img[:n_images].detach().cpu(),
            "linear_preds": linear_preds[:n_images].detach().cpu(),
            "cluster_preds": cluster_preds[:n_images].detach().cpu(),
            "label": label[:n_images].detach().cpu(),
        }

    def validation_epoch_end(self, outputs) -> None:
        """Log validation metrics and diagnostic plots, then reset the metrics.

        Args:
            outputs: List of dicts returned by :meth:`validation_step`.
        """
        super().validation_epoch_end(outputs)
        with torch.no_grad():
            tb_metrics = {
                **self.linear_metrics.compute(),
                **self.cluster_metrics.compute(),
            }
            if self.trainer.is_global_zero and not self.cfg.submitting_to_aml:
                # random.randint raises on an empty range, so only sample a
                # batch to plot when validation produced outputs.
                if outputs:
                    output_num = random.randint(0, len(outputs) - 1)
                    output = {
                        k: v.detach().cpu() for k, v in outputs[output_num].items()
                    }
                    self._plot_predictions(output)
                if self.cfg.has_labels:
                    self._plot_cluster_statistics()

            if self.global_step > 2:
                self.log_dict(tb_metrics)
                if self.trainer.is_global_zero and self.cfg.azureml_logging:
                    from azureml.core.run import Run

                    run_logger = Run.get_context()
                    for metric, value in tb_metrics.items():
                        run_logger.log(metric, value)

            self.linear_metrics.reset()
            self.cluster_metrics.reset()

    def _plot_predictions(self, output):
        """Log a grid of images, labels and probe predictions for one batch.

        Rows, in order: image, ground truth, UNet probe, retouched UNet probe,
        cluster probe, retouched cluster probe. There is one column per image,
        up to ``cfg.n_images``.

        Args:
            output: One dict from :meth:`validation_step`, already on CPU.
        """
        alpha = 0.4
        n_rows = 6
        n_cols = self.cfg.n_images
        # squeeze=False keeps `ax` 2-D, so ax[row, col] also works for one column.
        fig, ax = plt.subplots(
            n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), squeeze=False
        )
        n_available = len(output["img"])
        for i in range(n_cols):
            # A batch smaller than n_images leaves the remaining columns empty.
            if i >= n_available:
                continue
            rgb_img = prep_for_plot(output["img"][i])
            # Select the sample before squeezing, so a batch of one still
            # yields a 2-D label map. Clone so the relabel does not alter `output`.
            true_label = output["label"][i].squeeze().clone()
            # NOTE(review): ignore pixels (-1) are drawn as class 7; confirm this
            # index matches the module-level `cmap`/`norm`.
            true_label[true_label == -1] = 7
            # NOTE(review): `cmap` and `norm` are module-level names from the star
            # imports, and `norm` is also called as a function in training_step.
            # Confirm both resolve to the intended objects.

            ax[0, i].imshow(rgb_img)

            ax[1, i].imshow(rgb_img)
            ax[1, i].imshow(true_label, alpha=alpha, cmap=cmap, norm=norm)

            linear_pred = output["linear_preds"][i]
            ax[2, i].imshow(rgb_img)
            ax[2, i].imshow(linear_pred, alpha=alpha, cmap=cmap, norm=norm)
            ax[3, i].imshow(rgb_img)
            ax[3, i].imshow(
                retouch_label(linear_pred.numpy(), true_label),
                alpha=alpha,
                cmap=cmap,
                norm=norm,
            )

            cluster_pred = self.cluster_metrics.map_clusters(output["cluster_preds"][i])
            ax[4, i].imshow(rgb_img)
            ax[4, i].imshow(cluster_pred, alpha=alpha, cmap=cmap, norm=norm)
            ax[5, i].imshow(rgb_img)
            ax[5, i].imshow(
                retouch_label(cluster_pred.numpy(), true_label),
                alpha=alpha,
                cmap=cmap,
                norm=norm,
            )

        row_titles = (
            "Image",
            "Label",
            "UNet Probe",
            "Retouched UNet Probe",
            "Cluster Probe",
            "Retouched cluster Probe",
        )
        for row, title in enumerate(row_titles):
            ax[row, 0].set_ylabel(title, fontsize=16)
        remove_axes(ax)
        plt.tight_layout()
        add_plot(self.logger.experiment, "plot_labels", self.global_step)

    def _plot_cluster_statistics(self):
        """Log the cluster-probe confusion matrix and label/cluster frequency bars."""
        histogram = self.cluster_metrics.histogram
        names = get_class_labels(self.cfg.dataset_name)
        if self.cfg.extra_clusters:
            names = names + ["Extra"]
        colors = [self.label_cmap[i] / 255.0 for i in range(len(names))]
        ticks = np.arange(0, len(names))

        # Confusion matrix, normalised along dim 0. Sums are clamped to at
        # least 1 so empty slices do not divide by zero.
        fig = plt.figure(figsize=(13, 10))
        ax = fig.gca()
        hist = histogram.detach().cpu().to(torch.float32)
        hist /= torch.clamp_min(hist.sum(dim=0, keepdim=True), 1)
        sns.heatmap(hist.t(), annot=False, fmt="g", ax=ax, cmap="Blues")
        ax.set_xlabel("Predicted labels")
        ax.set_ylabel("True labels")
        ax.set_xticks(ticks + 0.5)
        ax.set_yticks(ticks + 0.5)
        ax.xaxis.tick_top()
        ax.xaxis.set_ticklabels(names, fontsize=14)
        ax.yaxis.set_ticklabels(names, fontsize=14)
        for color, tick in zip(colors, ax.xaxis.get_ticklabels()):
            tick.set_color(color)
        for color, tick in zip(colors, ax.yaxis.get_ticklabels()):
            tick.set_color(color)
        plt.xticks(rotation=90)
        plt.yticks(rotation=0)
        grid_lines = np.arange(0, len(names) + 1)
        ax.vlines(grid_lines, *ax.get_xlim(), color=[0.5, 0.5, 0.5])
        ax.hlines(grid_lines, *ax.get_ylim(), color=[0.5, 0.5, 0.5])
        plt.tight_layout()
        add_plot(self.logger.experiment, "conf_matrix", self.global_step)

        # Histogram sums along dim 0 and dim 1, shown on a shared log scale.
        label_counts = histogram.sum(0).cpu()
        cluster_counts = histogram.sum(1).cpu()
        all_bars = torch.cat([label_counts, cluster_counts], dim=0)
        ymin = max(all_bars.min() * 0.8, 1)
        ymax = all_bars.max() * 1.2
        fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 1 * 4))
        n_bars = self.n_classes + self.cfg.extra_clusters
        for bar_ax, counts, title in (
            (ax[0], label_counts, "Label Frequency"),
            (ax[1], cluster_counts, "Cluster Frequency"),
        ):
            bar_ax.bar(range(n_bars), counts, tick_label=names, color=colors)
            bar_ax.set_ylim(ymin, ymax)
            bar_ax.set_title(title)
            bar_ax.set_yscale("log")
            bar_ax.tick_params(axis="x", labelrotation=90)
        plt.tight_layout()
        add_plot(self.logger.experiment, "label frequency", self.global_step)

    def configure_optimizers(self):
        """Create the featurizer, supervised-probe and cluster-probe optimizers.

        Returns:
            Tuple ``(net_optim, linear_probe_optim, cluster_probe_optim)``, in
            the order that :meth:`training_step` unpacks them.
        """
        main_params = list(self.net.parameters())
        # The decoder is only trained when the reconstruction loss is enabled.
        if self.cfg.rec_weight > 0:
            main_params.extend(self.decoder.parameters())
        net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr)
        linear_probe_optim = torch.optim.Adam(
            list(self.linear_probe.parameters()), lr=self._PROBE_LR
        )
        cluster_probe_optim = torch.optim.Adam(
            list(self.cluster_probe.parameters()), lr=self._PROBE_LR
        )
        return net_optim, linear_probe_optim, cluster_probe_optim