# Biomap / biomap/model.py
# (hosting-page header: user jeremyLE-Ekimetrics, "streamlit", revision 9fcd62f, 18.7 kB)
from utils import *
from modules import *
from data import *
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.multiprocessing
import seaborn as sns
import unet
class LitUnsupervisedSegmenter(pl.LightningModule):
    """Lightning module for unsupervised semantic segmentation.

    A backbone (``self.net``) maps an image batch to ``(feats, code)``. The
    backbone is trained with a weighted sum of optional losses. Each loss is
    enabled when its ``cfg.*_weight`` is positive:

    * contrastive correspondence;
    * feature reconstruction;
    * augmentation alignment;
    * CRF.

    Two probes are trained at the same time on a detached copy of the code:

    * ``linear_probe``: a ``unet.AuxUNet`` called as ``linear_probe(img, code)``
      and trained with cross-entropy against the ground-truth labels;
    * ``cluster_probe``: a ``ClusterLookup`` trained with its own loss.

    Optimization is manual (``automatic_optimization = False``). Three
    optimizers are used: backbone (+ decoder), linear probe and cluster probe.
    """

    def __init__(self, n_classes, cfg):
        """Build the backbone, probes, metrics and losses.

        Args:
            n_classes: number of ground-truth classes.
            cfg: experiment configuration. Attributes read here include
                ``continuous``, ``dim``, ``output_root``, ``arch``,
                ``model_type``, ``granularity``, ``extra_clusters``, the CRF
                parameters (``crf_samples``, ``alpha``, ``beta``, ``gamma``,
                ``w1``, ``w2``, ``shift``) and ``dataset_name``.

        Raises:
            ValueError: if ``cfg.arch`` is neither ``"feature-pyramid"`` nor
                ``"dino"``.
        """
        super().__init__()
        self.name = "LitUnsupervisedSegmenter"
        self.cfg = cfg
        self.n_classes = n_classes

        # Channel count of the code produced by the backbone.
        if not cfg.continuous:
            dim = n_classes
        else:
            dim = cfg.dim

        data_dir = join(cfg.output_root, "data")
        if cfg.arch == "feature-pyramid":
            cut_model = load_model(cfg.model_type, data_dir).cuda()
            self.net = FeaturePyramidNet(
                cfg.granularity, cut_model, dim, cfg.continuous
            )
        elif cfg.arch == "dino":
            self.net = DinoFeaturizer(dim, cfg)
        else:
            raise ValueError("Unknown arch {}".format(cfg.arch))

        self.train_cluster_probe = ClusterLookup(dim, n_classes)
        self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters)

        # The "linear" probe is a small UNet that receives both the image and
        # the backbone code (see the calls in training_step/validation_step).
        # NOTE(review): aux_ch is hard-coded to 70. It presumably has to match
        # `dim` (the code channel count); confirm against unet.AuxUNet.
        self.linear_probe = unet.AuxUNet(
            enc_chs=(3, 32, 64, 128, 256),
            dec_chs=(256, 128, 64, 32),
            aux_ch=70,
            num_class=n_classes,
        )
        self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1))

        self.cluster_metrics = UnsupervisedMetrics(
            "test/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False)
        self.test_cluster_metrics = UnsupervisedMetrics(
            "final/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.test_linear_metrics = UnsupervisedMetrics(
            "final/linear/", n_classes, 0, False
        )

        self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
        self.crf_loss_fn = ContrastiveCRFLoss(
            cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift
        )
        self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg)
        # Freeze any parameters of the correlation loss so they are never optimized.
        for p in self.contrastive_corr_loss_fn.parameters():
            p.requires_grad = False

        # The three optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        if self.cfg.dataset_name.startswith("cityscapes"):
            self.label_cmap = create_cityscapes_colormap()
        else:
            self.label_cmap = create_pascal_label_colormap()

        self.val_steps = 0
        self.save_hyperparameters()

    def forward(self, x):
        """Return the code (second output of ``self.net``) for batch ``x``."""
        return self.net(x)[1]

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step and return the total loss.

        ``batch`` is a dict. The keys read are ``img``, ``img_aug``,
        ``coord_aug``, ``img_pos``, ``label``, ``label_pos`` and, when
        ``cfg.use_salience`` is set, ``mask`` and ``mask_pos``.
        """
        net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers()

        net_optim.zero_grad()
        linear_probe_optim.zero_grad()
        cluster_probe_optim.zero_grad()

        img = batch["img"]
        img_aug = batch["img_aug"]
        coord_aug = batch["coord_aug"]
        img_pos = batch["img_pos"]
        label = batch["label"]
        label_pos = batch["label_pos"]

        feats, code = self.net(img)
        # Positive-pair outputs are only needed by the correspondence loss.
        # Default them to None so the code below never reads an unbound name
        # when that loss is disabled.
        feats_pos, code_pos = None, None
        if self.cfg.correspondence_weight > 0:
            feats_pos, code_pos = self.net(img_pos)
        log_args = dict(sync_dist=False, rank_zero_only=True)

        if self.cfg.use_true_labels:
            # Shift by one so a label of -1 becomes index 0 of the
            # (n_classes + 1)-channel one-hot encoding.
            signal = one_hot_feats(label + 1, self.n_classes + 1)
            signal_pos = one_hot_feats(label_pos + 1, self.n_classes + 1)
        else:
            signal = feats
            signal_pos = feats_pos

        loss = 0

        should_log_hist = (
            (self.cfg.hist_freq is not None)
            and (self.global_step % self.cfg.hist_freq == 0)
            and (self.global_step > 0)
        )
        if self.cfg.use_salience:
            salience = batch["mask"].to(torch.float32).squeeze(1)
            salience_pos = batch["mask_pos"].to(torch.float32).squeeze(1)
        else:
            salience = None
            salience_pos = None

        if self.cfg.correspondence_weight > 0:
            (
                pos_intra_loss,
                pos_intra_cd,
                pos_inter_loss,
                pos_inter_cd,
                neg_inter_loss,
                neg_inter_cd,
            ) = self.contrastive_corr_loss_fn(
                signal,
                signal_pos,
                salience,
                salience_pos,
                code,
                code_pos,
            )
            if should_log_hist:
                self.logger.experiment.add_histogram(
                    "intra_cd", pos_intra_cd, self.global_step
                )
                self.logger.experiment.add_histogram(
                    "inter_cd", pos_inter_cd, self.global_step
                )
                self.logger.experiment.add_histogram(
                    "neg_cd", neg_inter_cd, self.global_step
                )
            neg_inter_loss = neg_inter_loss.mean()
            pos_intra_loss = pos_intra_loss.mean()
            pos_inter_loss = pos_inter_loss.mean()
            self.log("loss/pos_intra", pos_intra_loss, **log_args)
            self.log("loss/pos_inter", pos_inter_loss, **log_args)
            self.log("loss/neg_inter", neg_inter_loss, **log_args)
            self.log("cd/pos_intra", pos_intra_cd.mean(), **log_args)
            self.log("cd/pos_inter", pos_inter_cd.mean(), **log_args)
            self.log("cd/neg_inter", neg_inter_cd.mean(), **log_args)
            loss += (
                self.cfg.pos_inter_weight * pos_inter_loss
                + self.cfg.pos_intra_weight * pos_intra_loss
                + self.cfg.neg_inter_weight * neg_inter_loss
            ) * self.cfg.correspondence_weight

        if self.cfg.rec_weight > 0:
            # Reconstruct the backbone features from the code.
            # The loss is minus the summed elementwise product of the two
            # normalised tensors.
            rec_feats = self.decoder(code)
            rec_loss = -(norm(rec_feats) * norm(feats)).sum(1).mean()
            self.log("loss/rec", rec_loss, **log_args)
            loss += self.cfg.rec_weight * rec_loss

        if self.cfg.aug_alignment_weight > 0:
            _, orig_code_aug = self.net(img_aug)
            downsampled_coord_aug = resize(
                coord_aug.permute(0, 3, 1, 2), orig_code_aug.shape[2]
            ).permute(0, 2, 3, 1)
            aug_alignment = -torch.einsum(
                "bkhw,bkhw->bhw",
                norm(sample(code, downsampled_coord_aug)),
                norm(orig_code_aug),
            ).mean()
            self.log("loss/aug_alignment", aug_alignment, **log_args)
            loss += self.cfg.aug_alignment_weight * aug_alignment

        if self.cfg.crf_weight > 0:
            crf = self.crf_loss_fn(resize(img, 56), norm(resize(code, 56))).mean()
            self.log("loss/crf", crf, **log_args)
            loss += self.cfg.crf_weight * crf

        flat_label = label.reshape(-1)
        # Only pixels with a valid class index contribute to the probe loss.
        mask = (flat_label >= 0) & (flat_label < self.n_classes)

        # The probes see a detached copy of the code, so their losses do not
        # back-propagate into the backbone.
        detached_code = torch.clone(code.detach())

        linear_logits = self.linear_probe(img, detached_code)
        linear_logits = F.interpolate(
            linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
        )
        linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
        linear_loss = self.linear_probe_loss_fn(
            linear_logits[mask], flat_label[mask]
        ).mean()
        loss += linear_loss
        self.log("loss/linear", linear_loss, **log_args)

        cluster_loss, _ = self.cluster_probe(detached_code, None)
        loss += cluster_loss
        self.log("loss/cluster", cluster_loss, **log_args)
        self.log("loss/total", loss, **log_args)

        self.manual_backward(loss)
        net_optim.step()
        cluster_probe_optim.step()
        linear_probe_optim.step()

        if (
            self.cfg.reset_probe_steps is not None
            and self.global_step == self.cfg.reset_probe_steps
        ):
            print("RESETTING PROBES")
            # NOTE(review): requires unet.AuxUNet to implement
            # reset_parameters(); confirm, since nn.Module does not provide it.
            self.linear_probe.reset_parameters()
            self.cluster_probe.reset_parameters()
            self.trainer.optimizers[1] = torch.optim.Adam(
                list(self.linear_probe.parameters()), lr=5e-3
            )
            self.trainer.optimizers[2] = torch.optim.Adam(
                list(self.cluster_probe.parameters()), lr=5e-3
            )

        if self.global_step % 2000 == 0 and self.global_step > 0:
            print("RESETTING TFEVENT FILE")
            # Make a new tfevent file
            self.logger.experiment.close()
            self.logger.experiment._get_file_writer()

        return loss

    def on_train_start(self):
        """Log the hyper-parameters together with the initial metric values."""
        tb_metrics = {**self.linear_metrics.compute(), **self.cluster_metrics.compute()}
        self.logger.log_hyperparams(self.cfg, tb_metrics)

    def validation_step(self, batch, batch_idx):
        """Update the linear/cluster metrics on one batch and return a preview.

        Returns a dict of CPU tensors (``img``, ``linear_preds``,
        ``cluster_preds``, ``label``), each truncated to ``cfg.n_images``.
        """
        img = batch["img"]
        label = batch["label"]
        self.net.eval()

        with torch.no_grad():
            _, code = self.net(img)

            # Resize the probe logits to the label resolution, exactly as in
            # training_step. This is a no-op when the sizes already match.
            linear_logits = self.linear_probe(img, code)
            linear_logits = F.interpolate(
                linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
            )
            linear_preds = linear_logits.argmax(1)
            self.linear_metrics.update(linear_preds, label)

            code = F.interpolate(
                code, label.shape[-2:], mode="bilinear", align_corners=False
            )
            _, cluster_preds = self.cluster_probe(code, None)
            cluster_preds = cluster_preds.argmax(1)
            self.cluster_metrics.update(cluster_preds, label)

            return {
                "img": img[: self.cfg.n_images].detach().cpu(),
                "linear_preds": linear_preds[: self.cfg.n_images].detach().cpu(),
                "cluster_preds": cluster_preds[: self.cfg.n_images].detach().cpu(),
                "label": label[: self.cfg.n_images].detach().cpu(),
            }

    def validation_epoch_end(self, outputs) -> None:
        """Log validation metrics and diagnostic plots, then reset the metrics.

        Plots are drawn only on the global-zero rank and only when
        ``cfg.submitting_to_aml`` is false:

        * a grid of predictions for one randomly chosen validation batch;
        * when ``cfg.has_labels`` is set, a confusion matrix and
          label/cluster frequency charts.

        Metrics are logged once ``global_step > 2``.
        """
        super().validation_epoch_end(outputs)
        with torch.no_grad():
            tb_metrics = {
                **self.linear_metrics.compute(),
                **self.cluster_metrics.compute(),
            }

            if self.trainer.is_global_zero and not self.cfg.submitting_to_aml:
                # With no validation outputs there is no batch to draw.
                if outputs:
                    self._log_prediction_grid(outputs)
                if self.cfg.has_labels:
                    self._log_cluster_histograms()

            if self.global_step > 2:
                self.log_dict(tb_metrics)

                if self.trainer.is_global_zero and self.cfg.azureml_logging:
                    from azureml.core.run import Run

                    run_logger = Run.get_context()
                    for metric, value in tb_metrics.items():
                        run_logger.log(metric, value)

            self.linear_metrics.reset()
            self.cluster_metrics.reset()

    def _log_prediction_grid(self, outputs):
        """Plot image, label and probe predictions for one random validation batch."""
        output_num = random.randint(0, len(outputs) - 1)
        output = {k: v.detach().cpu() for k, v in outputs[output_num].items()}

        alpha = 0.4
        n_rows = 6
        # squeeze=False keeps `ax` 2-D even when cfg.n_images == 1.
        _, ax = plt.subplots(
            n_rows,
            self.cfg.n_images,
            figsize=(self.cfg.n_images * 3, n_rows * 3),
            squeeze=False,
        )
        # NOTE(review): `cmap` and `norm` come from the module-level star
        # imports. `norm` is also called as a tensor function in training_step;
        # confirm that imshow receives a matplotlib Normalize here.
        for i in range(self.cfg.n_images):
            try:
                rgb_img = prep_for_plot(output["img"][i])
                # Index the sample before squeezing. Squeezing the whole batch
                # first would also drop the batch axis when it has size 1.
                # Clone so the overlay edit below does not write into `outputs`.
                true_label = output["label"][i].squeeze().clone()
                # NOTE(review): 7 is presumably the colour slot for unlabelled
                # (-1) pixels; confirm against the colormap.
                true_label[true_label == -1] = 7
            except Exception:
                # Best-effort plotting: skip samples that cannot be prepared,
                # e.g. when the batch holds fewer than cfg.n_images images.
                continue

            ax[0, i].imshow(rgb_img)

            ax[1, i].imshow(rgb_img)
            ax[1, i].imshow(true_label, alpha=alpha, cmap=cmap, norm=norm)

            linear_label = output["linear_preds"][i]
            ax[2, i].imshow(rgb_img)
            ax[2, i].imshow(linear_label, alpha=alpha, cmap=cmap, norm=norm)
            ax[3, i].imshow(rgb_img)
            ax[3, i].imshow(
                retouch_label(linear_label.numpy(), true_label),
                alpha=alpha,
                cmap=cmap,
                norm=norm,
            )

            cluster_label = self.cluster_metrics.map_clusters(
                output["cluster_preds"][i]
            )
            ax[4, i].imshow(rgb_img)
            ax[4, i].imshow(cluster_label, alpha=alpha, cmap=cmap, norm=norm)
            ax[5, i].imshow(rgb_img)
            ax[5, i].imshow(
                retouch_label(cluster_label.numpy(), true_label),
                alpha=alpha,
                cmap=cmap,
                norm=norm,
            )

        ax[0, 0].set_ylabel("Image", fontsize=16)
        ax[1, 0].set_ylabel("Label", fontsize=16)
        ax[2, 0].set_ylabel("UNet Probe", fontsize=16)
        ax[3, 0].set_ylabel("Retouched UNet Probe", fontsize=16)
        ax[4, 0].set_ylabel("Cluster Probe", fontsize=16)
        ax[5, 0].set_ylabel("Retouched cluster Probe", fontsize=16)
        remove_axes(ax)
        plt.tight_layout()
        add_plot(self.logger.experiment, "plot_labels", self.global_step)

    def _log_cluster_histograms(self):
        """Log the cluster-metric confusion matrix and frequency bar charts."""
        fig = plt.figure(figsize=(13, 10))
        ax = fig.gca()
        hist = self.cluster_metrics.histogram.detach().cpu().to(torch.float32)
        # Divide each column by its sum (clamped to >= 1 to avoid division by zero).
        hist /= torch.clamp_min(hist.sum(dim=0, keepdim=True), 1)
        sns.heatmap(hist.t(), annot=False, fmt="g", ax=ax, cmap="Blues")
        ax.set_xlabel("Predicted labels")
        ax.set_ylabel("True labels")

        names = get_class_labels(self.cfg.dataset_name)
        if self.cfg.extra_clusters:
            names = names + ["Extra"]
        ax.set_xticks(np.arange(0, len(names)) + 0.5)
        ax.set_yticks(np.arange(0, len(names)) + 0.5)
        ax.xaxis.tick_top()
        ax.xaxis.set_ticklabels(names, fontsize=14)
        ax.yaxis.set_ticklabels(names, fontsize=14)
        colors = [self.label_cmap[i] / 255.0 for i in range(len(names))]
        for i, tick in enumerate(ax.xaxis.get_ticklabels()):
            tick.set_color(colors[i])
        for i, tick in enumerate(ax.yaxis.get_ticklabels()):
            tick.set_color(colors[i])
        plt.xticks(rotation=90)
        plt.yticks(rotation=0)
        ax.vlines(np.arange(0, len(names) + 1), *ax.get_xlim(), color=[0.5, 0.5, 0.5])
        ax.hlines(np.arange(0, len(names) + 1), *ax.get_ylim(), color=[0.5, 0.5, 0.5])
        plt.tight_layout()
        add_plot(self.logger.experiment, "conf_matrix", self.global_step)

        label_counts = self.cluster_metrics.histogram.sum(0).cpu()
        cluster_counts = self.cluster_metrics.histogram.sum(1).cpu()
        all_bars = torch.cat([label_counts, cluster_counts], axis=0)
        ymin = max(all_bars.min() * 0.8, 1)
        ymax = all_bars.max() * 1.2

        _, ax = plt.subplots(1, 2, figsize=(2 * 5, 1 * 4))
        n_bars = self.n_classes + self.cfg.extra_clusters
        for sub_ax, counts, title in (
            (ax[0], label_counts, "Label Frequency"),
            (ax[1], cluster_counts, "Cluster Frequency"),
        ):
            sub_ax.bar(range(n_bars), counts, tick_label=names, color=colors)
            sub_ax.set_ylim(ymin, ymax)
            sub_ax.set_title(title)
            sub_ax.set_yscale("log")
            sub_ax.tick_params(axis="x", labelrotation=90)
        plt.tight_layout()
        add_plot(self.logger.experiment, "label frequency", self.global_step)

    def configure_optimizers(self):
        """Return the (backbone, linear probe, cluster probe) Adam optimizers.

        The decoder is optimized with the backbone only when
        ``cfg.rec_weight > 0``.
        """
        main_params = list(self.net.parameters())
        if self.cfg.rec_weight > 0:
            main_params.extend(self.decoder.parameters())
        net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr)
        linear_probe_optim = torch.optim.Adam(
            list(self.linear_probe.parameters()), lr=5e-3
        )
        cluster_probe_optim = torch.optim.Adam(
            list(self.cluster_probe.parameters()), lr=5e-3
        )
        return net_optim, linear_probe_optim, cluster_probe_optim