Transformers
Inference Endpoints
neggles committed
Commit: f664757
1 parent: 2a50b55
Files changed (5)
  1. __init__.py +9 -0
  2. common.py +38 -0
  3. model.py +173 -0
  4. utils.py +160 -0
  5. vit.py +375 -0
__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .model import DreamsimEnsemble, DreamsimModel
+ from .vit import VisionTransformer, vit_base_dreamsim
+
+ __all__ = [
+     "DreamsimModel",
+     "DreamsimEnsemble",
+     "VisionTransformer",
+     "vit_base_dreamsim",
+ ]
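The package surface re-exports the two diffusers-style wrappers and the ViT builder. A minimal import sketch, assuming the repo is vendored or installed as a local package (the package name `dreamsim_vit` below is a placeholder, not something defined in this commit):

# `dreamsim_vit` is a hypothetical package name; substitute the actual directory name
from dreamsim_vit import DreamsimEnsemble, DreamsimModel, vit_base_dreamsim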
common.py ADDED
@@ -0,0 +1,38 @@
+ from typing import Callable
+
+ import torch
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+
+
+ def ensure_tuple(val: int | tuple[int, ...], n: int = 2) -> tuple[int, ...]:
+     if isinstance(val, int):
+         return (val,) * n
+     elif len(val) != n:
+         raise ValueError(f"Expected a tuple of {n} values, but got {len(val)}: {val}")
+     return val
+
+
+ def use_fused_attn():
+     if hasattr(F, "scaled_dot_product_attention"):
+         return True
+     return False
+
+
+ class QuickGELU(nn.Module):
+     """
+     Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+     """
+
+     def forward(self, input: Tensor) -> Tensor:
+         return input * torch.sigmoid(1.702 * input)
+
+
+ def get_act_layer(name: str) -> Callable[[], nn.Module]:
+     match name:
+         case "gelu":
+             return nn.GELU
+         case "quick_gelu":
+             return QuickGELU
+         case _:
+             raise ValueError(f"Activation layer {name} not supported.")
model.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+ from torch import Tensor
+ from torch.nn import functional as F
+ from torchvision.transforms import v2 as T
+
+ from .common import ensure_tuple
+ from .vit import VisionTransformer, vit_base_dreamsim
+
+
+ class DreamsimModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         image_size: int = 224,
+         patch_size: int = 16,
+         layer_norm_eps: float = 1e-6,
+         pre_norm: bool = False,
+         act_layer: str = "gelu",
+         img_mean: tuple[float, float, float] = (0.485, 0.456, 0.406),
+         img_std: tuple[float, float, float] = (0.229, 0.224, 0.225),
+         do_resize: bool = False,
+     ) -> None:
+         super().__init__()
+
+         self.image_size = ensure_tuple(image_size, 2)
+         self.patch_size = patch_size
+         self.layer_norm_eps = layer_norm_eps
+         self.pre_norm = pre_norm
+         self.do_resize = do_resize
+         self.img_mean = img_mean
+         self.img_std = img_std
+
+         num_classes = 512 if self.pre_norm else 0
+         self.extractor: VisionTransformer = vit_base_dreamsim(
+             image_size=image_size,
+             patch_size=patch_size,
+             layer_norm_eps=layer_norm_eps,
+             num_classes=num_classes,
+             pre_norm=pre_norm,
+             act_layer=act_layer,
+         )
+
+         self.resize = T.Resize(
+             self.image_size,
+             interpolation=T.InterpolationMode.BICUBIC,
+             antialias=True,
+         )
+         self.img_norm = T.Normalize(mean=self.img_mean, std=self.img_std)
+
+     def transforms(self, x: Tensor) -> Tensor:
+         if self.do_resize:
+             x = self.resize(x)
+         return self.img_norm(x)
+
+     def forward_features(self, x: Tensor) -> Tensor:
+         if x.ndim == 3:
+             x = x.unsqueeze(0)
+         x = self.transforms(x)
+         x = self.extractor.forward(x, norm=self.pre_norm)
+
+         x.div_(x.norm(dim=1, keepdim=True))
+         x.sub_(x.mean(dim=1, keepdim=True))
+         return x
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Dreamsim forward pass for similarity computation.
+
+         Args:
+             x (Tensor): Input tensor of shape [2, B, 3, H, W].
+
+         Returns:
+             sim (torch.Tensor): dreamsim similarity score of shape [B].
+         """
+         all_images = x.view(-1, 3, *x.shape[-2:])
+
+         # keep the features separate from the input so we can reshape them back
+         # to [2, B, D] using the input's leading dimensions
+         feats = self.forward_features(all_images)
+         feats = feats.view(*x.shape[:2], -1)
+
+         return 1 - F.cosine_similarity(feats[0], feats[1], dim=1)
+
+
+ class DreamsimEnsemble(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         image_size: int = 224,
+         patch_size: int = 16,
+         layer_norm_eps: float | tuple[float, ...] = (1e-6, 1e-5, 1e-5),
+         num_classes: tuple[int, int, int] = (0, 512, 512),
+         do_resize: bool = False,
+     ) -> None:
+         super().__init__()
+         if isinstance(layer_norm_eps, float):
+             layer_norm_eps = (layer_norm_eps,) * 3
+
+         self.image_size = ensure_tuple(image_size, 2)
+         self.patch_size = patch_size
+         self.do_resize = do_resize
+
+         self.dino: VisionTransformer = vit_base_dreamsim(
+             image_size=self.image_size,
+             patch_size=self.patch_size,
+             layer_norm_eps=layer_norm_eps[0],
+             num_classes=num_classes[0],
+             pre_norm=False,
+             act_layer="gelu",
+         )
+         self.clip1: VisionTransformer = vit_base_dreamsim(
+             image_size=self.image_size,
+             patch_size=self.patch_size,
+             layer_norm_eps=layer_norm_eps[1],
+             num_classes=num_classes[1],
+             pre_norm=True,
+             act_layer="quick_gelu",
+         )
+         self.clip2: VisionTransformer = vit_base_dreamsim(
+             image_size=self.image_size,
+             patch_size=self.patch_size,
+             layer_norm_eps=layer_norm_eps[2],
+             num_classes=num_classes[2],
+             pre_norm=True,
+             act_layer="gelu",
+         )
+
+         self.resize = T.Resize(
+             self.image_size,
+             interpolation=T.InterpolationMode.BICUBIC,
+             antialias=True,
+         )
+         self.dino_norm = T.Normalize(
+             mean=(0.485, 0.456, 0.406),
+             std=(0.229, 0.224, 0.225),
+         )
+         self.clip_norm = T.Normalize(
+             mean=(0.48145466, 0.4578275, 0.40821073),
+             std=(0.26862954, 0.26130258, 0.27577711),
+         )
+
+     def transforms(self, x: Tensor, resize: bool = False) -> tuple[Tensor, Tensor, Tensor]:
+         if resize:
+             x = self.resize(x)
+         return self.dino_norm(x), self.clip_norm(x), self.clip_norm(x)
+
+     def forward_features(self, x: Tensor) -> Tensor:
+         if x.ndim == 3:
+             x = x.unsqueeze(0)
+         x_dino, x_clip1, x_clip2 = self.transforms(x, self.do_resize)
+
+         # these expect to always receive a batch, and will return a batch
+         x_dino = self.dino.forward(x_dino, norm=False)
+         x_clip1 = self.clip1.forward(x_clip1, norm=True)
+         x_clip2 = self.clip2.forward(x_clip2, norm=True)
+
+         z: Tensor = torch.cat([x_dino, x_clip1, x_clip2], dim=1)
+         z.div_(z.norm(dim=1, keepdim=True))
+         z.sub_(z.mean(dim=1, keepdim=True))
+         return z
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Dreamsim forward pass for similarity computation.
+
+         Args:
+             x (Tensor): Input tensor of shape [2, B, 3, H, W].
+
+         Returns:
+             sim (torch.Tensor): dreamsim similarity score of shape [B].
+         """
+         all_images = x.view(-1, 3, *x.shape[-2:])
+
+         # keep the features separate from the input so we can reshape them back
+         # to [2, B, D] using the input's leading dimensions
+         feats = self.forward_features(all_images)
+         feats = feats.view(*x.shape[:2], -1)
+
+         return 1 - F.cosine_similarity(feats[0], feats[1], dim=1)
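Given the forward contract above, inference amounts to stacking the two image batches on a new leading dimension and calling the model. A minimal sketch with random inputs and randomly initialized weights (the real DreamSim weights are loaded separately; see utils.py):

import torch

from .model import DreamsimModel  # adjust the import to the actual package layout

model = DreamsimModel(do_resize=True).eval()
img_a = torch.rand(4, 3, 256, 256)              # batch of reference images
img_b = torch.rand(4, 3, 256, 256)              # batch of comparison images

with torch.inference_mode():
    dist = model(torch.stack([img_a, img_b]))   # input shape [2, B, 3, H, W]

print(dist.shape)                               # torch.Size([4]); 1 - cosine similarity per pair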
utils.py ADDED
@@ -0,0 +1,160 @@
+ """
+ Functions in this file are courtesy of @ashen-sensored on GitHub - thank you so much! <3
+
+ Used to merge DreamSim LoRA weights into the base ViT models manually, so we don't need
+ to use an ancient version of PEFT that is no longer supported (and kind of broken).
+ """
+ import logging
+ from os import PathLike
+ from pathlib import Path
+
+ import torch
+ from safetensors.torch import load_file
+ from torch import Tensor, nn
+
+ from .model import DreamsimModel
+
+ logger = logging.getLogger(__name__)
+
+
+ @torch.no_grad()
+ def calculate_merged_weight(
+     lora_a: Tensor,
+     lora_b: Tensor,
+     base: Tensor,
+     scale: float,
+     qkv_switches: list[bool],
+ ) -> Tensor:
+     n_switches = len(qkv_switches)
+     n_groups = sum(qkv_switches)
+
+     qkv_mask = torch.tensor(qkv_switches, dtype=torch.bool).reshape(len(qkv_switches), -1)
+     qkv_mask = qkv_mask.broadcast_to((-1, base.shape[0] // n_switches)).reshape(-1)
+
+     lora_b = lora_b.squeeze()
+     delta_w = base.new_zeros(lora_b.shape[0], base.shape[1])
+
+     grp_in_ch = lora_a.shape[0] // n_groups
+     grp_out_ch = lora_b.shape[0] // n_groups
+     for i in range(n_groups):
+         islice = slice(i * grp_in_ch, (i + 1) * grp_in_ch)
+         oslice = slice(i * grp_out_ch, (i + 1) * grp_out_ch)
+         delta_w[oslice, :] = lora_b[oslice, :] @ lora_a[islice, :]
+
+     delta_w_full = base.new_zeros(base.shape)
+     delta_w_full[qkv_mask, :] = delta_w
+
+     merged = base + scale * delta_w_full
+     return merged.to(base)
+
+
+ @torch.no_grad()
+ def merge_dreamsim_lora(
+     base_model: nn.Module,
+     lora_path: PathLike,
+     torch_device: torch.device | str = torch.device("cpu"),
+ ):
+     lora_path = Path(lora_path)
+     # make sure model is on device
+     base_model = base_model.eval().requires_grad_(False).to(torch_device)
+
+     # load the lora
+     if lora_path.suffix.lower() in [".pt", ".pth", ".bin"]:
+         lora_sd = torch.load(lora_path, map_location=torch_device, weights_only=True)
+     elif lora_path.suffix.lower() == ".safetensors":
+         lora_sd = load_file(lora_path)
+     else:
+         raise ValueError(f"Unsupported file extension '{lora_path.suffix}'")
+
+     # these loras were created by a cursed PEFT version, okay? so we have to do some crimes.
+     group_prefix = "base_model.model.base_model.model.model."
+     # get all lora weights for qkv layers, stripping the insane prefix
+     group_weights = {k.replace(group_prefix, ""): v for k, v in lora_sd.items() if k.startswith(group_prefix)}
+     # strip ".lora_X.weight" from keys to match against base model keys
+     group_layers = set([k.rsplit(".", 2)[0] for k in group_weights.keys()])
+
+     base_weights = base_model.state_dict()
+     for key in [x for x in base_weights.keys() if "attn.qkv.weight" in x]:
+         param_name = key.rsplit(".", 1)[0]
+         if param_name not in group_layers:
+             logger.warning(f"QKV param '{param_name}' not found in lora weights")
+             continue
+         new_weight = calculate_merged_weight(
+             group_weights[f"{param_name}.lora_A.weight"],
+             group_weights[f"{param_name}.lora_B.weight"],
+             base_weights[key],
+             0.5 / 16,
+             [True, False, True],
+         )
+         base_weights[key] = new_weight
+
+     base_model.load_state_dict(base_weights)
+     return base_model.requires_grad_(False)
+
+
+ def remap_clip(state_dict: dict[str, Tensor], variant: str) -> dict[str, Tensor]:
+     """Remap keys from the original DreamSim checkpoint to match the new model structure."""
+
+     def prepend_extractor(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
+         if variant.endswith("single"):
+             return {f"extractor.{k}": v for k, v in state_dict.items()}
+         return state_dict
+
+     if "clip" not in variant:
+         return prepend_extractor(state_dict)
+
+     if "patch_embed.proj.bias" in state_dict:
+         _ = state_dict.pop("patch_embed.proj.bias", None)
+     if "pos_drop.weight" in state_dict:
+         state_dict["norm_pre.weight"] = state_dict.pop("pos_drop.weight")
+         state_dict["norm_pre.bias"] = state_dict.pop("pos_drop.bias")
+     if "head.weight" in state_dict and "head.bias" not in state_dict:
+         state_dict["head.bias"] = torch.zeros(state_dict["head.weight"].shape[0])
+
+     return prepend_extractor(state_dict)
+
+
+ def convert_dreamsim_single(
+     ckpt_path: PathLike,
+     variant: str,
+     ensemble: bool = False,
+ ) -> DreamsimModel:
+     ckpt_path = Path(ckpt_path)
+     if ckpt_path.exists():
+         if ckpt_path.is_dir():
+             ckpt_path = ckpt_path.joinpath("ensemble" if ensemble else variant)
+             ckpt_path = ckpt_path.joinpath(f"{variant}_merged.safetensors")
+
+     # defaults are for dino, overridden as needed below
+     patch_size = 16
+     layer_norm_eps = 1e-6
+     pre_norm = False
+     act_layer = "gelu"
+
+     match variant:
+         case "open_clip_vitb16" | "open_clip_vitb32" | "clip_vitb16" | "clip_vitb32":
+             patch_size = 32 if "b32" in variant else 16
+             layer_norm_eps = 1e-5
+             pre_norm = True
+             img_mean = (0.48145466, 0.4578275, 0.40821073)
+             img_std = (0.26862954, 0.26130258, 0.27577711)
+             act_layer = "quick_gelu" if variant.startswith("clip_") else "gelu"
+         case "dino_vitb16":
+             img_mean = (0.485, 0.456, 0.406)
+             img_std = (0.229, 0.224, 0.225)
+         case _:
+             raise NotImplementedError(f"Unsupported model variant '{variant}'")
+
+     model: DreamsimModel = DreamsimModel(
+         image_size=224,
+         patch_size=patch_size,
+         layer_norm_eps=layer_norm_eps,
+         pre_norm=pre_norm,
+         act_layer=act_layer,
+         img_mean=img_mean,
+         img_std=img_std,
+     )
+     state_dict = load_file(ckpt_path, device="cpu")
+     # remap_clip needs the variant name to know which keys to rewrite
+     state_dict = remap_clip(state_dict, variant)
+     model.extractor.load_state_dict(state_dict)
+     return model
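Taken together, the intended flow appears to be: either load an already-merged single-branch checkpoint via convert_dreamsim_single, or merge the original DreamSim LoRA into a base ViT with merge_dreamsim_lora. A hedged sketch with placeholder paths (the on-disk layout `<root>/<variant>/<variant>_merged.safetensors` is what the code above expects, but the actual filenames come from the upstream DreamSim release):

from .utils import convert_dreamsim_single, merge_dreamsim_lora  # adjust the import to the actual package layout

# expects e.g. checkpoints/dino_vitb16/dino_vitb16_merged.safetensors to exist
model = convert_dreamsim_single("checkpoints/", variant="dino_vitb16")

# alternatively, merge DreamSim LoRA weights into the bare ViT (the path here is a placeholder)
# merged_vit = merge_dreamsim_lora(model.extractor, "checkpoints/dreamsim_lora.safetensors")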
vit.py ADDED
@@ -0,0 +1,375 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Mostly copy-paste from timm library.
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ """
+ import math
+ from functools import partial
+ from typing import Callable, Final, Optional, Sequence
+
+ import torch
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+
+ from .common import ensure_tuple, get_act_layer, use_fused_attn
+
+
+ def vit_weights_init(module: nn.Module) -> None:
+     if isinstance(module, nn.Linear):
+         nn.init.trunc_normal_(module.weight, std=0.02)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+     elif isinstance(module, nn.LayerNorm):
+         nn.init.ones_(module.weight)
+         nn.init.zeros_(module.bias)
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+         super().__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x: Tensor) -> Tensor:
+         if self.drop_prob == 0 or not self.training:
+             return x
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+         random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+         if keep_prob > 0.0 and self.scale_by_keep:
+             random_tensor.div_(keep_prob)
+         return x * random_tensor
+
+     def extra_repr(self):
+         return f"drop_prob={self.drop_prob:0.3f}"
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[[], nn.Module] = nn.GELU,
+         drop: float = 0.0,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     fused_attn: Final[bool]
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         qk_scale: Optional[float] = None,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = qk_scale or self.head_dim**-0.5
+         self.fused_attn = use_fused_attn()
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity()
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         B, N, C = x.shape
+         qkv: Tensor = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)
+
+         if self.fused_attn:
+             dropout_p = getattr(self.attn_drop, "p", 0.0) if self.training else 0.0
+             x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+             x = attn @ v
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         drop_path: float = 0.0,
+         act_layer: Callable[[], nn.Module] = nn.GELU,
+         norm_layer: Callable[[], nn.Module] = nn.LayerNorm,
+     ):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+         )
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = x + self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """Image to Patch Embedding"""
+
+     def __init__(
+         self,
+         img_size: int | tuple[int, int] = 224,
+         patch_size: int | tuple[int, int] = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         bias: bool = True,
+         dynamic_pad: bool = False,
+     ):
+         super().__init__()
+         self.img_size = ensure_tuple(img_size)
+         self.patch_size = ensure_tuple(patch_size)
+         # use the normalized tuples so both int and tuple inputs are handled
+         self.num_patches = (self.img_size[0] // self.patch_size[0]) * (self.img_size[1] // self.patch_size[1])
+
+         self.dynamic_pad = dynamic_pad
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         _, _, H, W = x.shape
+         if self.dynamic_pad:
+             pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+             pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+             x = F.pad(x, (0, pad_w, 0, pad_h))
+         x = self.proj(x)
+         x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+         return x
+
+
+ class VisionTransformer(nn.Module):
+     """Vision Transformer"""
+
+     def __init__(
+         self,
+         img_size: int | tuple[int, int] = 224,
+         patch_size: int | tuple[int, int] = 16,
+         in_chans: int = 3,
+         num_classes: int = 0,
+         embed_dim: int = 768,
+         depth: int = 12,
+         num_heads: int = 12,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         pre_norm: bool = False,
+         drop_rate: float = 0.0,
+         attn_drop_rate: float = 0.0,
+         drop_path_rate: float = 0.0,
+         norm_layer: Callable[[], nn.Module] = nn.LayerNorm,
+         act_layer: Callable[[], nn.Module] = nn.GELU,
+         skip_init: bool = False,
+         dynamic_pad: bool = False,
+         **kwargs,
+     ):
+         super().__init__()
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim
+         self.depth = depth
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size,
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=embed_dim,
+             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
+             dynamic_pad=dynamic_pad,
+         )
+         num_patches = self.patch_embed.num_patches
+         embed_len = num_patches + 1  # num_patches + 1 for the [CLS] token
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, embed_dim))
+         self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
+         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]  # stochastic depth decay rule
+         self.blocks: list[Block] = nn.ModuleList(
+             [
+                 Block(
+                     dim=embed_dim,
+                     num_heads=num_heads,
+                     mlp_ratio=mlp_ratio,
+                     qkv_bias=qkv_bias,
+                     drop=drop_rate,
+                     attn_drop=attn_drop_rate,
+                     drop_path=dpr[i],
+                     act_layer=act_layer,
+                     norm_layer=norm_layer,
+                 )
+                 for i in range(self.depth)
+             ]
+         )
+         self.norm = norm_layer(embed_dim)
+
+         # Classifier head
+         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+         if not skip_init:
+             self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.trunc_normal_(self.cls_token, std=0.02)
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+         self.apply(vit_weights_init)
+
+     def interpolate_pos_encoding(self, x: Tensor, w: int, h: int) -> Tensor:
+         npatch = x.shape[1] - 1
+         N = self.pos_embed.shape[1] - 1
+         if npatch == N and w == h:
+             return self.pos_embed
+         class_pos_embed = self.pos_embed[:, 0]
+         patch_pos_embed = self.pos_embed[:, 1:]
+         dim = x.shape[-1]
+         w0 = w // self.patch_embed.patch_size[0]
+         h0 = h // self.patch_embed.patch_size[1]
+         # we add a small number to avoid floating point error in the interpolation
+         # see discussion at https://github.com/facebookresearch/dino/issues/8
+         w0, h0 = w0 + 0.1, h0 + 0.1
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+             mode="bicubic",
+         )
+         if int(w0) != patch_pos_embed.shape[-2] or int(h0) != patch_pos_embed.shape[-1]:
+             raise ValueError("Error in positional encoding interpolation.")
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+     def prepare_tokens(self, x: Tensor) -> Tensor:
+         B, _, W, H = x.shape
+         x = self.patch_embed(x)  # patch linear embedding
+
+         # add the [CLS] token to the embed patch tokens
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # add positional encoding to each token
+         x = x + self.interpolate_pos_encoding(x, W, H)
+
+         return self.pos_drop(x)
+
+     def forward(self, x: Tensor, norm: bool = True) -> Tensor:
+         x = self.forward_features(x, norm=norm)
+         x = self.forward_head(x)
+         return x
+
+     def forward_features(self, x: Tensor, norm: bool = True) -> Tensor:
+         x = self.prepare_tokens(x)
+         x = self.norm_pre(x)
+         for blk in self.blocks:
+             x = blk(x)
+         if norm:
+             x = self.norm(x)
+         return x[:, 0]
+
+     def forward_head(self, x: Tensor) -> Tensor:
+         x = self.head(x)
+         return x
+
+     def get_intermediate_layers(
+         self,
+         x: Tensor,
+         n: int | Sequence[int] = 1,
+         norm: bool = True,
+     ) -> list[Tensor]:
+         # we return the output tokens from the `n` last blocks
+         outputs = []
+         layer_indices = set(range(self.depth - n, self.depth) if isinstance(n, int) else n)
+         x = self.prepare_tokens(x)
+         x = self.norm_pre(x)
+
+         for idx, blk in enumerate(self.blocks):
+             x = blk(x)
+             if idx in layer_indices:
+                 outputs.append(x)
+         if norm:
+             outputs = [self.norm(x) for x in outputs]
+         return outputs
+
+
+ def vit_base_dreamsim(
+     patch_size: int = 16,
+     layer_norm_eps: float = 1e-6,
+     num_classes: int = 512,
+     act_layer: str | Callable[[], nn.Module] = "gelu",
+     **kwargs,
+ ):
+     if isinstance(act_layer, str):
+         act_layer = get_act_layer(act_layer)
+
+     model = VisionTransformer(
+         patch_size=patch_size,
+         num_classes=num_classes,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=layer_norm_eps),
+         act_layer=act_layer,
+         **kwargs,
+     )
+     return model
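Finally, the builder can be used on its own, outside the DreamSim wrappers. A small sketch with randomly initialized weights, using the CLIP-style options to show how extra arguments are forwarded:

import torch

from .vit import vit_base_dreamsim  # adjust the import to the actual package layout

vit = vit_base_dreamsim(
    patch_size=16,
    layer_norm_eps=1e-5,
    num_classes=512,
    pre_norm=True,           # forwarded via **kwargs to VisionTransformer
    act_layer="quick_gelu",
)
x = torch.rand(2, 3, 224, 224)
feats = vit(x, norm=True)    # CLS token passed through the 512-dim head
print(feats.shape)           # torch.Size([2, 512])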