Transformers
Inference Endpoints
dreamsim / utils.py
neggles's picture
add code
f664757
raw
history blame contribute delete
No virus
5.79 kB
"""
Functions in this file are courtesty of @ashen-sensored on GitHub - thankyou so much! <3
Used to merge DreamSim LoRA weights into the base ViT models manually, so we don't need
to use an ancient version of PeFT that is no longer supported (and kind of broken)
"""
import logging
from os import PathLike
from pathlib import Path
import torch
from safetensors.torch import load_file
from torch import Tensor, nn
from .model import DreamsimModel
logger = logging.getLogger(__name__)
@torch.no_grad()
def calculate_merged_weight(
lora_a: Tensor,
lora_b: Tensor,
base: Tensor,
scale: float,
qkv_switches: list[bool],
) -> Tensor:
n_switches = len(qkv_switches)
n_groups = sum(qkv_switches)
qkv_mask = torch.tensor(qkv_switches, dtype=torch.bool).reshape(len(qkv_switches), -1)
qkv_mask = qkv_mask.broadcast_to((-1, base.shape[0] // n_switches)).reshape(-1)
lora_b = lora_b.squeeze()
delta_w = base.new_zeros(lora_b.shape[0], base.shape[1])
grp_in_ch = lora_a.shape[0] // n_groups
grp_out_ch = lora_b.shape[0] // n_groups
for i in range(n_groups):
islice = slice(i * grp_in_ch, (i + 1) * grp_in_ch)
oslice = slice(i * grp_out_ch, (i + 1) * grp_out_ch)
delta_w[oslice, :] = lora_b[oslice, :] @ lora_a[islice, :]
delta_w_full = base.new_zeros(base.shape)
delta_w_full[qkv_mask, :] = delta_w
merged = base + scale * delta_w_full
return merged.to(base)
@torch.no_grad()
def merge_dreamsim_lora(
base_model: nn.Module,
lora_path: PathLike,
torch_device: torch.device | str = torch.device("cpu"),
):
lora_path = Path(lora_path)
# make sure model is on device
base_model = base_model.eval().requires_grad_(False).to(torch_device)
# load the lora
if lora_path.suffix.lower() in [".pt", ".pth", ".bin"]:
lora_sd = torch.load(lora_path, map_location=torch_device, weights_only=True)
elif lora_path.suffix.lower() == ".safetensors":
lora_sd = load_file(lora_path)
else:
raise ValueError(f"Unsupported file extension '{lora_path.suffix}'")
# these loras were created by a cursed PEFT version, okay? so we have to do some crimes.
group_prefix = "base_model.model.base_model.model.model."
# get all lora weights for qkv layers, stripping the insane prefix
group_weights = {k.replace(group_prefix, ""): v for k, v in lora_sd.items() if k.startswith(group_prefix)}
# strip ".lora_X.weight" from keys to match against base model keys
group_layers = set([k.rsplit(".", 2)[0] for k in group_weights.keys()])
base_weights = base_model.state_dict()
for key in [x for x in base_weights.keys() if "attn.qkv.weight" in x]:
param_name = key.rsplit(".", 1)[0]
if param_name not in group_layers:
logger.warning(f"QKV param '{param_name}' not found in lora weights")
continue
new_weight = calculate_merged_weight(
group_weights[f"{param_name}.lora_A.weight"],
group_weights[f"{param_name}.lora_B.weight"],
base_weights[key],
0.5 / 16,
[True, False, True],
)
base_weights[key] = new_weight
base_model.load_state_dict(base_weights)
return base_model.requires_grad_(False)
def remap_clip(state_dict: dict[str, Tensor], variant: str) -> dict[str, Tensor]:
"""Remap keys from the original DreamSim checkpoint to match new model structure."""
def prepend_extractor(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
if variant.endswith("single"):
return {f"extractor.{k}": v for k, v in state_dict.items()}
return state_dict
if "clip" not in variant:
return prepend_extractor(state_dict)
if "patch_embed.proj.bias" in state_dict:
_ = state_dict.pop("patch_embed.proj.bias", None)
if "pos_drop.weight" in state_dict:
state_dict["norm_pre.weight"] = state_dict.pop("pos_drop.weight")
state_dict["norm_pre.bias"] = state_dict.pop("pos_drop.bias")
if "head.weight" in state_dict and "head.bias" not in state_dict:
state_dict["head.bias"] = torch.zeros(state_dict["head.weight"].shape[0])
return prepend_extractor(state_dict)
def convert_dreamsim_single(
ckpt_path: PathLike,
variant: str,
ensemble: bool = False,
) -> DreamsimModel:
ckpt_path = Path(ckpt_path)
if ckpt_path.exists():
if ckpt_path.is_dir():
ckpt_path = ckpt_path.joinpath("ensemble" if ensemble else variant)
ckpt_path = ckpt_path.joinpath(f"{variant}_merged.safetensors")
# defaults are for dino, overridden as needed below
patch_size = 16
layer_norm_eps = 1e-6
pre_norm = False
act_layer = "gelu"
match variant:
case "open_clip_vitb16" | "open_clip_vitb32" | "clip_vitb16" | "clip_vitb32":
patch_size = 32 if "b32" in variant else 16
layer_norm_eps = 1e-5
pre_norm = True
img_mean = (0.48145466, 0.4578275, 0.40821073)
img_std = (0.26862954, 0.26130258, 0.27577711)
act_layer = "quick_gelu" if variant.startswith("clip_") else "gelu"
case "dino_vitb16":
img_mean = (0.485, 0.456, 0.406)
img_std = (0.229, 0.224, 0.225)
case _:
raise NotImplementedError(f"Unsupported model variant '{variant}'")
model: DreamsimModel = DreamsimModel(
image_size=224,
patch_size=patch_size,
layer_norm_eps=layer_norm_eps,
pre_norm=pre_norm,
act_layer=act_layer,
img_mean=img_mean,
img_std=img_std,
)
state_dict = load_file(ckpt_path, device="cpu")
state_dict = remap_clip(state_dict)
model.extractor.load_state_dict(state_dict)
return model