# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils

class CLIPLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        image_embed = outputs['image_embed']
        text_embed = outputs['text_embed']
        logit_scale = outputs['logit_scale']
        local_batch_size = image_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            # each local sample's positive sits at its global index within the
            # gathered batch, hence the rank offset; labels are cached until
            # the local batch size changes
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=image_embed.device
            )
            self.last_local_batch_size = local_batch_size

        # normalized features
        image_embed = F.normalize(image_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)

        # gather features from all GPUs
        image_embed_all, text_embed_all = \
            utils.all_gather_batch([image_embed, text_embed])

        # cosine similarity as logits
        logits_per_image = logit_scale * image_embed @ text_embed_all.t()
        logits_per_text = logit_scale * text_embed @ image_embed_all.t()
        loss = (F.cross_entropy(logits_per_image, self.labels) +
                F.cross_entropy(logits_per_text, self.labels)) / 2

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(logits_per_image, dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}
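
# A minimal usage sketch for CLIPLoss (illustrative, not part of the original
# file; shapes are arbitrary placeholders). In a single-process run,
# utils.get_rank() is expected to return 0 and the all-gather helpers to
# return their inputs unchanged, so the loss reduces to the standard
# in-batch InfoNCE objective:
#
#   criterion = CLIPLoss()
#   outputs = {
#       'image_embed': torch.randn(8, 512),
#       'text_embed': torch.randn(8, 512),
#       'logit_scale': torch.tensor(100.0),
#   }
#   metrics = criterion(outputs)  # dict with 'loss', 'clip_loss', 'clip_acc'
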
class SIMCLRLoss(nn.Module):
    """
    This is the SimCLR loss from https://arxiv.org/abs/2002.05709

    `forward` expects the embeddings of the two augmented views under the
    keys 'aug1_embed' and 'aug2_embed', each of shape
    (batch_size, embedding_dim). Stacked together, they correspond to the
    (2 x batch_size, embedding_dim) memory layout produced by the SimCLR
    collator in
    https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py

    Config params:
        temperature (float): the temperature applied to the logits
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.tau = temperature
        self.labels = None
        self.masks = None
        self.last_local_batch_size = None
    def forward(self, outputs):
        q_a = outputs['aug1_embed']
        q_b = outputs['aug2_embed']

        q_a = F.normalize(q_a, dim=-1, p=2)
        q_b = F.normalize(q_b, dim=-1, p=2)

        local_batch_size = q_a.size(0)

        k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b])

        if local_batch_size != self.last_local_batch_size:
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=q_a.device
            )
            total_batch_size = local_batch_size * utils.get_world_size()
            # one-hot mask that excludes each sample's similarity with its own
            # view from the negatives by subtracting a large constant from
            # that logit
            self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
            self.last_local_batch_size = local_batch_size

        logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau
        logits_aa = logits_aa - self.masks
        logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau
        logits_bb = logits_bb - self.masks
        logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau
        logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau

        loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels)
        loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels)
        loss = (loss_a + loss_b) / 2  # divide by 2 to average over all samples

        # compute accuracy
        with torch.no_grad():
            pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1)
            correct = pred.eq(self.labels).sum()
            acc = 100 * correct / local_batch_size

        return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc}
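
# A minimal usage sketch for SIMCLRLoss (illustrative only; shapes are
# assumptions). Each augmented view contributes a (batch_size, embedding_dim)
# tensor; for view a, the positive is the matching row of view b, and the
# candidate set is the concatenated 2N columns of [logits_ab, logits_aa]
# with the self-similarity column masked out:
#
#   criterion = SIMCLRLoss(temperature=0.1)
#   outputs = {
#       'aug1_embed': torch.randn(8, 256),
#       'aug2_embed': torch.randn(8, 256),
#   }
#   metrics = criterion(outputs)  # dict with 'loss', 'ssl_loss', 'ssl_acc'
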
class SLIPLoss(nn.Module):
    def __init__(self, ssl_loss, ssl_scale):
        super().__init__()
        self.clip_loss = CLIPLoss()
        self.ssl_loss = ssl_loss
        self.ssl_scale = ssl_scale

    def forward(self, outputs):
        clip_loss_dict = self.clip_loss(outputs)
        clip_loss = clip_loss_dict['clip_loss']
        clip_acc = clip_loss_dict['clip_acc']

        ssl_loss_dict = self.ssl_loss(outputs)
        ssl_loss = ssl_loss_dict['ssl_loss']
        ssl_acc = ssl_loss_dict['ssl_acc']

        return {'loss': clip_loss + self.ssl_scale * ssl_loss,
                'clip_loss': clip_loss,
                'clip_acc': clip_acc,
                'ssl_loss': ssl_loss,
                'ssl_acc': ssl_acc}
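

if __name__ == '__main__':
    # Hedged smoke test, not part of the original file. It assumes a
    # single-process (non-distributed) run in which the utils helpers
    # degrade to rank 0 / identity all-gather; embedding sizes and the
    # ssl_scale value are arbitrary placeholders.
    outputs = {
        'image_embed': torch.randn(4, 512),
        'text_embed': torch.randn(4, 512),
        'logit_scale': torch.tensor(100.0),
        'aug1_embed': torch.randn(4, 256),
        'aug2_embed': torch.randn(4, 256),
    }
    criterion = SLIPLoss(ssl_loss=SIMCLRLoss(temperature=0.1), ssl_scale=1.0)
    metrics = criterion(outputs)
    print({k: float(v) for k, v in metrics.items()})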