import torch
import torch.nn as nn
import torch.nn.functional as F

import utils


class CLIPLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        image_embed = outputs['image_embed']
        text_embed = outputs['text_embed']
        logit_scale = outputs['logit_scale']
        local_batch_size = image_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            # each local sample's positive is its own index within the
            # globally gathered batch, offset by this process's rank
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=image_embed.device
            )
            self.last_local_batch_size = local_batch_size

        # normalized features
        image_embed = F.normalize(image_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)

        # gather features from all GPUs
        image_embed_all, text_embed_all = \
            utils.all_gather_batch([image_embed, text_embed])

        # cosine similarity as logits
        logits_per_image = logit_scale * image_embed @ text_embed_all.t()
        logits_per_text = logit_scale * text_embed @ image_embed_all.t()

        # symmetric InfoNCE loss over image->text and text->image directions
        loss = (F.cross_entropy(logits_per_image, self.labels) +
                F.cross_entropy(logits_per_text, self.labels)) / 2

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(logits_per_image, dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}
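

# --- Illustrative usage sketch (not part of the original training code). ---
# Shows the dict layout CLIPLoss expects from the model. It assumes a
# single-process run, i.e. utils.get_rank() returns 0 and
# utils.all_gather_batch returns its inputs unchanged; logit_scale below is
# an arbitrary stand-in for the model's learned (already exponentiated)
# temperature scale.
def _clip_loss_example(batch_size=4, embed_dim=512):
    loss_fn = CLIPLoss()
    outputs = {
        'image_embed': torch.randn(batch_size, embed_dim),
        'text_embed': torch.randn(batch_size, embed_dim),
        'logit_scale': torch.tensor(100.0),
    }
    return loss_fn(outputs)  # dict with 'loss', 'clip_loss', 'clip_acc'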


class SIMCLRLoss(nn.Module):
    """
    This is the SimCLR loss in https://arxiv.org/abs/2002.05709

    The embedding vectors are assumed to have size (2 x batch_size, embedding_dim) and
    a memory layout that can be reshaped into shape (2, batch_size, embedding_dim).
    This memory layout is consistent with the SimCLR collator in
    https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py

    Config params:
        temperature (float): the temperature to be applied on the logits
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.tau = temperature
        self.labels = None
        self.masks = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        q_a = outputs['aug1_embed']
        q_b = outputs['aug2_embed']

        q_a = F.normalize(q_a, dim=-1, p=2)
        q_b = F.normalize(q_b, dim=-1, p=2)

        local_batch_size = q_a.size(0)

        # gather embeddings from all GPUs while keeping gradients intact
        k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b])

        if local_batch_size != self.last_local_batch_size:
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=q_a.device
            )
            total_batch_size = local_batch_size * utils.get_world_size()
            # large one-hot mask used to suppress each sample's self-similarity logit
            self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
            self.last_local_batch_size = local_batch_size

        logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau
        logits_aa = logits_aa - self.masks
        logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau
        logits_bb = logits_bb - self.masks
        logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau
        logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau

        # symmetric loss over the two augmented views
        loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels)
        loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels)
        loss = (loss_a + loss_b) / 2

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc}
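

# --- Illustrative usage sketch (not part of the original training code). ---
# Shows how a (2 x batch_size, embedding_dim) tensor with the memory layout
# described in the docstring maps onto the 'aug1_embed' / 'aug2_embed' inputs.
# Assumes a single-process run, i.e. utils.get_rank() returns 0,
# utils.get_world_size() returns 1, and utils.all_gather_batch_with_grad
# returns its inputs unchanged.
def _simclr_loss_example(batch_size=4, embed_dim=256):
    loss_fn = SIMCLRLoss(temperature=0.1)
    embeddings = torch.randn(2 * batch_size, embed_dim)
    aug1_embed, aug2_embed = embeddings.reshape(2, batch_size, embed_dim)
    outputs = {'aug1_embed': aug1_embed, 'aug2_embed': aug2_embed}
    return loss_fn(outputs)  # dict with 'loss', 'ssl_loss', 'ssl_acc'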


class SLIPLoss(nn.Module):
    def __init__(self, ssl_loss, ssl_scale):
        super().__init__()
        self.clip_loss = CLIPLoss()
        self.ssl_loss = ssl_loss
        self.ssl_scale = ssl_scale

    def forward(self, outputs):
        clip_loss_dict = self.clip_loss(outputs)
        clip_loss = clip_loss_dict['clip_loss']
        clip_acc = clip_loss_dict['clip_acc']

        ssl_loss_dict = self.ssl_loss(outputs)
        ssl_loss = ssl_loss_dict['ssl_loss']
        ssl_acc = ssl_loss_dict['ssl_acc']

        # total objective: CLIP loss plus the scaled self-supervised loss
        return {'loss': clip_loss + self.ssl_scale * ssl_loss,
                'clip_loss': clip_loss,
                'clip_acc': clip_acc,
                'ssl_loss': ssl_loss,
                'ssl_acc': ssl_acc}
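

# --- Illustrative usage sketch (not part of the original training code). ---
# SLIPLoss expects a single outputs dict carrying both the CLIP fields
# ('image_embed', 'text_embed', 'logit_scale') and the SimCLR fields
# ('aug1_embed', 'aug2_embed'). The batch size, dimensions, and ssl_scale
# below are arbitrary, and the sketch again assumes a single-process
# (non-distributed) run of the utils helpers.
if __name__ == '__main__':
    batch_size, embed_dim, ssl_dim = 4, 512, 256
    outputs = {
        'image_embed': torch.randn(batch_size, embed_dim),
        'text_embed': torch.randn(batch_size, embed_dim),
        'logit_scale': torch.tensor(100.0),
        'aug1_embed': torch.randn(batch_size, ssl_dim),
        'aug2_embed': torch.randn(batch_size, ssl_dim),
    }
    slip_loss = SLIPLoss(ssl_loss=SIMCLRLoss(temperature=0.1), ssl_scale=1.0)
    print({k: v.item() for k, v in slip_loss(outputs).items()})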