from abc import abstractmethod

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import v2 as T

from .common import ensure_tuple
from .vit import VisionTransformer, vit_base_dreamsim


class DreamsimBackbone(ModelMixin, ConfigMixin):
    @abstractmethod
    def forward_features(self, x: Tensor) -> Tensor:
        raise NotImplementedError("abstract base class was called ;_;")

    def forward(self, x: Tensor) -> Tensor:
        """Dreamsim forward pass for pairwise distance computation.

        Args:
            x (Tensor): Input tensor of shape [2, B, 3, H, W] holding the two
                batches of images to compare.

        Returns:
            dist (Tensor): DreamSim distance (1 - cosine similarity) of shape [B].
        """
        # Fold the pair dimension into the batch, embed, then restore [2, B, D].
        inputs = x.view(-1, 3, *x.shape[-2:])
        x = self.forward_features(inputs).view(*x.shape[:2], -1)

        return 1 - F.cosine_similarity(x[0], x[1], dim=1)

    def compile(self, *args, **kwargs):
        """Compile the model with Inductor. This is a no-op unless overridden by a subclass."""
        return self
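# Usage sketch (illustrative, not part of the module's API): forward() above
# expects the two image batches to compare stacked along a leading pair
# dimension. Here `img_a` and `img_b` are assumed to be [B, 3, H, W] float
# tensors and `model` is any concrete DreamsimBackbone, e.g. DreamsimModel:
#
#     pair = torch.stack([img_a, img_b])  # [2, B, 3, H, W]
#     dist = model(pair)                  # [B] distances, 0 = identical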
class DreamsimModel(DreamsimBackbone):
    @register_to_config
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        layer_norm_eps: float = 1e-6,
        pre_norm: bool = False,
        act_layer: str = "gelu",
        img_mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
        img_std: tuple[float, float, float] = (0.229, 0.224, 0.225),
        do_resize: bool = False,
    ) -> None:
        super().__init__()

        self.image_size = ensure_tuple(image_size, 2)
        self.patch_size = ensure_tuple(patch_size, 2)
        self.layer_norm_eps = layer_norm_eps
        self.pre_norm = pre_norm
        self.do_resize = do_resize
        self.img_mean = img_mean
        self.img_std = img_std

        num_classes = 512 if self.pre_norm else 0
        self.extractor: VisionTransformer = vit_base_dreamsim(
            image_size=image_size,
            patch_size=patch_size,
            layer_norm_eps=layer_norm_eps,
            num_classes=num_classes,
            pre_norm=pre_norm,
            act_layer=act_layer,
        )

        self.resize = T.Resize(
            self.image_size,
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        )
        self.img_norm = T.Normalize(mean=self.img_mean, std=self.img_std)

        self._compiled = False

    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
        if (not self._compiled) or force:
            self.extractor = torch.compile(self.extractor, mode=mode, **kwargs)
            self._compiled = True
        return self

    def transforms(self, x: Tensor) -> Tensor:
        if self.do_resize:
            x = self.resize(x)
        return self.img_norm(x)

    def forward_features(self, x: Tensor) -> Tensor:
        if x.ndim == 3:
            x = x.unsqueeze(0)
        x = self.transforms(x)
        x = self.extractor.forward(x, norm=self.pre_norm)

        # L2-normalize along the feature dim, then center the embeddings.
        x = x.div(x.norm(dim=1, keepdim=True))
        x = x.sub(x.mean(dim=1, keepdim=True))
        return x


class DreamsimEnsemble(DreamsimBackbone):
    @register_to_config
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        layer_norm_eps: float | tuple[float, ...] = (1e-6, 1e-5, 1e-5),
        num_classes: int | tuple[int, ...] = (0, 512, 512),
        do_resize: bool = False,
    ) -> None:
        super().__init__()
        if isinstance(layer_norm_eps, float):
            layer_norm_eps = (layer_norm_eps,) * 3
        if isinstance(num_classes, int):
            num_classes = (num_classes,) * 3

        self.image_size = ensure_tuple(image_size, 2)
        self.patch_size = ensure_tuple(patch_size, 2)
        self.do_resize = do_resize

        self.dino: VisionTransformer = vit_base_dreamsim(
            image_size=self.image_size,
            patch_size=self.patch_size,
            layer_norm_eps=layer_norm_eps[0],
            num_classes=num_classes[0],
            pre_norm=False,
            act_layer="gelu",
        )
        self.clip1: VisionTransformer = vit_base_dreamsim(
            image_size=self.image_size,
            patch_size=self.patch_size,
            layer_norm_eps=layer_norm_eps[1],
            num_classes=num_classes[1],
            pre_norm=True,
            act_layer="quick_gelu",
        )
        self.clip2: VisionTransformer = vit_base_dreamsim(
            image_size=self.image_size,
            patch_size=self.patch_size,
            layer_norm_eps=layer_norm_eps[2],
            num_classes=num_classes[2],
            pre_norm=True,
            act_layer="gelu",
        )

        self.resize = T.Resize(
            self.image_size,
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        )
        self.dino_norm = T.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        )
        self.clip_norm = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )

        self._compiled = False

    def compile(self, *, mode: str = "reduce-overhead", force: bool = False, **kwargs):
        if (not self._compiled) or force:
            self.dino = torch.compile(self.dino, mode=mode, **kwargs)
            self.clip1 = torch.compile(self.clip1, mode=mode, **kwargs)
            self.clip2 = torch.compile(self.clip2, mode=mode, **kwargs)
            self._compiled = True
        return self

    def transforms(self, x: Tensor, resize: bool = False) -> tuple[Tensor, Tensor, Tensor]:
        if resize:
            x = self.resize(x)
        return self.dino_norm(x), self.clip_norm(x), self.clip_norm(x)

    def forward_features(self, x: Tensor) -> Tensor:
        if x.ndim == 3:
            x = x.unsqueeze(0)
        x_dino, x_clip1, x_clip2 = self.transforms(x, self.do_resize)

        x_dino = self.dino.forward(x_dino, norm=False)
        x_clip1 = self.clip1.forward(x_clip1, norm=True)
        x_clip2 = self.clip2.forward(x_clip2, norm=True)

        # Concatenate the three backbone embeddings, then L2-normalize and center.
        z: Tensor = torch.cat([x_dino, x_clip1, x_clip2], dim=1)
        z = z.div(z.norm(dim=1, keepdim=True))
        z = z.sub(z.mean(dim=1, keepdim=True))
        return z
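if __name__ == "__main__":
    # Smoke-test sketch (illustrative only): runs both backbones with randomly
    # initialised weights on random inputs. Real use requires loading pretrained
    # DreamSim weights, which is outside the scope of this module.
    model = DreamsimModel(do_resize=True).eval()
    pair = torch.rand(2, 2, 3, 256, 256)  # [2, B, 3, H, W] with B = 2
    with torch.inference_mode():
        dist = model(pair)
    print("single-backbone distances:", dist.shape)  # torch.Size([2])

    ensemble = DreamsimEnsemble(do_resize=True).eval()
    with torch.inference_mode():
        dist = ensemble(pair)
    print("ensemble distances:", dist.shape)  # torch.Size([2])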