# T-MoENet / model / evl.py
# Uploaded by yixin1121 ("Upload folder using huggingface_hub"), commit 513e1fb (verified).
from typing import Dict, Iterable, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalises in float32 and casts back to the input dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.to(input_dtype)
class Attention(nn.Module):
    '''
    A generalized multi-head attention module with more flexibility.

    Queries, keys and values may have different input widths. Queries and keys
    are projected to ``qk_proj_dim``, values to ``v_proj_dim`` (both split
    evenly across ``num_heads`` heads), and the attended values are projected
    to ``out_dim``.

    Args:
        q_in_dim, k_in_dim, v_in_dim: last-dimension sizes of the q/k/v inputs.
        qk_proj_dim: projected width of queries and keys (divisible by num_heads).
        v_proj_dim: projected width of values (divisible by num_heads).
        num_heads: number of attention heads.
        out_dim: width of the output projection.
        return_all_features: if True, ``forward`` returns a dict of intermediates
            (projected per-head q/k/v, attention weights and output).
        add_mask: if True, ``mask`` selects which keys may be attended to;
            otherwise the mask is ignored.
        dropout: accepted for interface compatibility; currently unused.
    '''
    def __init__(
        self, q_in_dim: int, k_in_dim: int, v_in_dim: int,
        qk_proj_dim: int, v_proj_dim: int, num_heads: int, out_dim: int,
        return_all_features: bool = False, add_mask: bool = False, dropout: float = 0.0
    ):
        super().__init__()
        self.q_proj = nn.Linear(q_in_dim, qk_proj_dim)
        self.k_proj = nn.Linear(k_in_dim, qk_proj_dim)
        self.v_proj = nn.Linear(v_in_dim, v_proj_dim)
        self.out_proj = nn.Linear(v_proj_dim, out_dim)
        self.num_heads = num_heads
        self.return_all_features = return_all_features
        assert qk_proj_dim % num_heads == 0 and v_proj_dim % num_heads == 0
        self.add_mask = add_mask
        self._initialize_weights()

    def _initialize_weights(self):
        # Xavier-uniform weights and zero biases for all four projections.
        for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` (N, Lq, q_in_dim) to ``k``/``v`` (N, Lkv, *_in_dim).

        ``mask`` is only consulted when ``add_mask`` is True and may then be
        ``(N, Lkv)`` (or broadcastable to it) with non-zero entries marking keys
        that may be attended to; it may be ``None`` to disable masking. A sample
        whose keys are all masked gets all-zero attention weights instead of
        NaNs, so its output reduces to ``out_proj``'s bias.

        Returns the ``(N, Lq, out_dim)`` output, or a dict of intermediates when
        ``return_all_features`` is True.
        """
        assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
        N = q.size(0); assert k.size(0) == N and v.size(0) == N
        Lq, Lkv = q.size(1), k.size(1); assert v.size(1) == Lkv

        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        H = self.num_heads
        Cqk, Cv = q.size(-1) // H, v.size(-1) // H
        q = q.view(N, Lq, H, Cqk)
        k = k.view(N, Lkv, H, Cqk)
        v = v.view(N, Lkv, H, Cv)

        # Scaled dot-product affinities laid out as (N, Lq, Lkv, H).
        aff = torch.einsum('nqhc,nkhc->nqkh', q / (Cqk ** 0.5), k)

        fully_masked = None
        if self.add_mask and mask is not None:
            rmask = ~(mask.bool().to(aff.device))  # True = key must be ignored
            aff = aff.masked_fill(rmask.unsqueeze(1).unsqueeze(-1), float("-inf"))
            # Samples with no valid key would yield softmax over all -inf (NaN).
            fully_masked = rmask.all(dim=-1).view(-1, 1, 1, 1)

        aff = aff.softmax(dim=-2)  # normalise over the key axis
        if fully_masked is not None:
            aff = aff.masked_fill(fully_masked, 0.0)

        mix = torch.einsum('nqlh,nlhc->nqhc', aff, v)
        out = self.out_proj(mix.flatten(-2))
        if self.return_all_features:
            return dict(q=q, k=k, v=v, aff=aff, out=out)
        else:
            return out
class TransformerDecoderLayer(nn.Module):
    """Pre-norm cross-attention block: queries ``x`` attend to context ``y``, then an MLP.

    Both sub-layers are residual:
        x <- x + attn(norm1(x), norm3(y), norm3(y), mask)
        x <- x + mlp(norm2(x))

    The MLP is ``fc1 -> act -> dropout -> fc2`` with hidden width
    ``round(mlp_factor * in_feature_dim)``; its output width equals the input
    width, so the returned tensor has the same shape as ``x``.
    """

    def __init__(
        self,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        mlp_dropout: float = 0.0,
        act: nn.Module = QuickGELU,
        add_mask: bool = False
    ):
        super().__init__()
        dim = in_feature_dim
        self.attn = Attention(
            q_in_dim=dim, k_in_dim=dim, v_in_dim=dim,
            qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim,
            num_heads=num_heads, out_dim=dim, add_mask=add_mask,
        )
        hidden = round(mlp_factor * dim)
        mlp_layers = [
            ('fc1', nn.Linear(dim, hidden)),
            ('act', act()),
            ('dropout', nn.Dropout(mlp_dropout)),
            ('fc2', nn.Linear(hidden, dim)),
        ]
        self.mlp = nn.Sequential(OrderedDict(mlp_layers))
        self.norm1, self.norm2, self.norm3 = [LayerNorm(dim) for _ in range(3)]
        self._initialize_weights()

    def _initialize_weights(self):
        # Xavier-uniform weights and near-zero normal biases for both MLP linears.
        for linear in (self.mlp.fc1, self.mlp.fc2):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.normal_(linear.bias, std=1e-6)

    def forward(self, x, y, mask):
        context = self.norm3(y)
        attended = self.attn(self.norm1(x), context, context, mask)
        x = x + attended
        return x + self.mlp(self.norm2(x))
class EVLDecoder(nn.Module):
    """Temporal decoder over a sequence of per-frame features.

    For every layer ``i`` the input frame features are enriched with a
    depth-wise temporal convolution (``temporal_conv[i]``) and a learned
    temporal positional embedding (``temporal_pos_embed[i]``); either can be
    disabled at construction. When ``add_vid_feat`` is True, a learned
    ``cls_token`` query additionally cross-attends to those enriched features
    through ``decoder_layers[i]`` and the final query is returned; otherwise
    the enriched frame features of the last layer are returned.

    Note: each layer starts again from the raw ``in_features`` (enrichments
    are not stacked across layers), so when ``add_vid_feat`` is False only the
    last layer's convolution / embedding affect the output.

    Args:
        num_frames: length of each temporal positional embedding; the input
            must have exactly this many frames when embeddings are enabled.
        spatial_size: accepted for interface compatibility; unused.
        num_layers: number of decoder layers.
        in_feature_dim: per-frame feature width C.
        qkv_dim, num_heads, mlp_factor, mlp_dropout, add_mask: forwarded to each
            TransformerDecoderLayer (only built when ``add_vid_feat`` is True).
        enable_temporal_conv: build and apply the temporal convolutions.
        enable_temporal_pos_embed: build and apply the temporal embeddings.
        add_vid_feat: use the cls-token cross-attention path.
    """
    def __init__(
        self,
        num_frames: int = 8,
        spatial_size: Tuple[int, int] = (14, 14),
        num_layers: int = 4,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        enable_temporal_conv: bool = True,
        enable_temporal_pos_embed: bool = True,
        mlp_dropout: float = 0.5,
        add_vid_feat: bool = False,
        add_mask: bool = False,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.add_vid_feat = add_vid_feat
        if add_vid_feat:
            self.decoder_layers = nn.ModuleList(
                [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, mlp_factor, mlp_dropout, add_mask=add_mask) for _ in range(num_layers)]
            )
            self.cls_token = nn.Parameter(torch.zeros([in_feature_dim]))
            self._initialize_weights()
        if enable_temporal_conv:
            # Depth-wise (groups == channels) conv along time; padding=1 keeps length T.
            self.temporal_conv = nn.ModuleList(
                [nn.Conv1d(in_feature_dim, in_feature_dim, kernel_size=3, stride=1, padding=1, groups=in_feature_dim) for _ in range(num_layers)]
            )
        if enable_temporal_pos_embed:
            self.temporal_pos_embed = nn.ParameterList(
                [nn.Parameter(torch.zeros([num_frames, in_feature_dim])) for _ in range(num_layers)]
            )

    def _initialize_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)

    def forward(self, in_features, video_mask):
        """Decode ``in_features`` of shape (N, T, C).

        ``video_mask`` is only passed on to the decoder layers (cls-token path).
        Returns an (N, 1, C) tensor when ``add_vid_feat`` is True, otherwise
        the (N, T, C) enriched frame features.
        """
        N, T, C = in_features.size()
        # A module is absent exactly when its enable_* flag was False at construction.
        temporal_conv = getattr(self, "temporal_conv", None)
        temporal_pos_embed = getattr(self, "temporal_pos_embed", None)
        if self.add_vid_feat:
            x = self.cls_token.view(1, 1, -1).repeat(N, 1, 1)  # N, 1, C
        frame_features = in_features
        for i in range(self.num_layers):
            frame_features = in_features
            if temporal_conv is not None:
                feat = in_features.permute(0, 2, 1).contiguous()  # N, C, T
                feat = temporal_conv[i](feat)  # N, C, T
                frame_features = frame_features + feat.permute(0, 2, 1).contiguous()  # N, T, C
            if temporal_pos_embed is not None:
                frame_features = frame_features + temporal_pos_embed[i].view(1, T, C)
            if self.add_vid_feat:
                x = self.decoder_layers[i](x, frame_features, video_mask)
        if self.add_vid_feat:
            return x
        return frame_features
class EVLTransformer(nn.Module):
    """Temporal head: an EVLDecoder over 768-wide per-frame features, then a linear projection.

    When ``add_video_feat`` is True the decoder returns a cls-token summary,
    which is layer-normalised before projection; otherwise the decoder's frame
    features are projected directly.

    Args:
        num_frames: number of frames expected by the decoder's temporal embeddings.
        decoder_num_layers, decoder_qkv_dim, decoder_num_heads, decoder_mlp_factor,
        enable_temporal_conv, enable_temporal_pos_embed, decoder_mlp_dropout,
        add_mask: forwarded to EVLDecoder.
        enable_temporal_cross_attention: accepted for config compatibility; unused.
        add_video_feat: enable the decoder's cls-token path (and the extra norm).
        output_dim: width of the projected output.
    """
    def __init__(
        self,
        num_frames: int = 8,
        decoder_num_layers: int = 2,
        decoder_qkv_dim: int = 768,
        decoder_num_heads: int = 16,
        decoder_mlp_factor: float = 4.0,
        enable_temporal_conv: bool = True,
        enable_temporal_pos_embed: bool = True,
        enable_temporal_cross_attention: bool = False,
        decoder_mlp_dropout: float = 0.5,
        add_video_feat: bool = False,
        output_dim: int = 1536,
        add_mask: bool = False
    ):
        super().__init__()
        self.decoder_num_layers = decoder_num_layers
        backbone_feature_dim = 768
        backbone_spatial_size = (16, 16)
        self.decoder = EVLDecoder(
            num_frames=num_frames,
            spatial_size=backbone_spatial_size,
            num_layers=decoder_num_layers,
            in_feature_dim=backbone_feature_dim,
            qkv_dim=decoder_qkv_dim,
            num_heads=decoder_num_heads,
            mlp_factor=decoder_mlp_factor,
            enable_temporal_conv=enable_temporal_conv,
            enable_temporal_pos_embed=enable_temporal_pos_embed,
            mlp_dropout=decoder_mlp_dropout,
            add_vid_feat=add_video_feat,
            add_mask=add_mask
        )
        self.add_vid_feat = add_video_feat
        if self.add_vid_feat:
            self.norm = nn.LayerNorm(backbone_feature_dim)
        # The decoder's output width is in_feature_dim (= backbone_feature_dim),
        # not decoder_qkv_dim (an internal attention width), so project from it.
        self.proj = nn.Linear(backbone_feature_dim, output_dim)

    def forward(self, x, video_mask):
        x = self.decoder(x, video_mask)
        if self.add_vid_feat:
            x = self.norm(x)
        x = self.proj(x)
        return x
class TemporalAttention(nn.Module):
    """Temporally pool frame features, add a learned position embedding, then self-attend.

    ``forward`` calls ``avg_1d_pool`` with this module's ``kernel_size`` and
    ``stride`` (also obtaining a pooled mask) and adds ``temporal_pos_embed``.
    It then runs the result through two TransformerDecoderLayers, using the
    pooled sequence as both query and context.
    """

    def __init__(
        self,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 8,
        max_frames: int = 40,
        stride: int = 4,
        kernel_size: int = 4,
        add_mask: bool = True,
    ):
        super().__init__()
        self.num_layers = 2
        self.kernel_size = kernel_size
        self.stride = stride
        # Number of pooling windows for a max_frames-long input; presumably
        # matches what avg_1d_pool produces (no padding) -- confirm there.
        pooled_len = (max_frames - kernel_size) // stride + 1
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, 2.0, 0.5, add_mask=add_mask)
            for _ in range(self.num_layers)
        ])
        self.temporal_pos_embed = nn.Parameter(torch.zeros([pooled_len, in_feature_dim]))
        # Registered (and saved in the state_dict) but not used by forward().
        self.norm = nn.LayerNorm(in_feature_dim)

    def forward(self, x, video_mask):
        pooled, pooled_mask = avg_1d_pool(x, self.kernel_size, self.stride, video_mask, return_mask=True)
        hidden = pooled + self.temporal_pos_embed.unsqueeze(0)
        for layer in self.decoder_layers:
            hidden = layer(hidden, hidden, pooled_mask)
        return hidden
def recursive_gumbel_softmax(sim, x, video_mask, topk):
    """Select ``topk`` frames per sample by repeated hard Gumbel-softmax sampling.

    Each round draws a hard (straight-through) one-hot choice over frames from
    the scaled similarities. It gathers the chosen frame feature from ``x``
    and the matching ``video_mask`` entry. It then lowers the chosen frame's
    similarity by 1 to discourage, but not forbid, picking it again.

    Frames with ``video_mask == 0`` get ``-inf`` similarity, so they are not
    chosen while a sample has any valid frame. The picks are finally
    re-ordered by frame index, so the output follows temporal order rather
    than selection order.

    Args:
        sim: similarity scores of shape ``(bs, 1, T)``; a ``(bs, T)`` tensor is
            also accepted and treated as ``(bs, 1, T)``.
        x: frame features of shape ``(bs, T, dim)``.
        video_mask: ``(bs, T)``; non-zero marks valid frames.
        topk: number of frames to select.

    Returns:
        ``(feats, mask)`` of shapes ``(bs, topk, dim)`` and ``(bs, topk)``.
    """
    if sim.dim() == 2:
        sim = sim.unsqueeze(1)  # bs, 1, T
    bs = x.shape[0]
    batch_idx = torch.arange(bs, device=video_mask.device)

    rmask = ~(video_mask.bool())
    sim = sim.masked_fill(rmask.unsqueeze(1).to(sim.device), float("-inf"))

    feats, v_masks, idxs = [], [], []
    for _ in range(topk):
        # sim is divided by 0.01 before Gumbel noise is added, so similarity
        # gaps dominate the noise; the hard sample is a straight-through one-hot.
        choice = F.gumbel_softmax(sim / 0.01, hard=True, dim=-1, tau=0.1).squeeze(1)  # bs, T
        idx = choice.argmax(dim=-1)  # bs
        idxs.append(idx)
        feats.append(torch.sum(choice.unsqueeze(-1) * x, dim=1, keepdim=True))  # bs, 1, dim
        v_masks.append(video_mask[batch_idx, idx.to(video_mask.device)])
        sim = sim - choice.unsqueeze(1)

    # Re-order the picks by frame index (temporal order).
    rank = torch.argsort(torch.stack(idxs, dim=1), dim=1)  # bs, topk
    feats = torch.cat(feats, dim=1)  # bs, topk, dim
    res = torch.gather(feats, 1, rank.unsqueeze(-1).expand(-1, -1, feats.size(-1)))
    video_mask = torch.stack(v_masks, dim=1)  # bs, topk
    video_mask = torch.gather(video_mask, 1, rank.to(video_mask.device))
    return res, video_mask