"""
Functions in this file are courtesty of @ashen-sensored on GitHub - thankyou so much! <3
Used to merge DreamSim LoRA weights into the base ViT models manually, so we don't need
to use an ancient version of PeFT that is no longer supported (and kind of broken)
"""
import logging
from os import PathLike
from pathlib import Path
import torch
from safetensors.torch import load_file
from torch import Tensor, nn
from .model import DreamsimModel
logger = logging.getLogger(__name__)
@torch.no_grad()
def calculate_merged_weight(
    lora_a: Tensor,
    lora_b: Tensor,
    base: Tensor,
    scale: float,
    qkv_switches: list[bool],
) -> Tensor:
    """Merge one LoRA A/B pair into a fused qkv weight, updating only the q/k/v groups
    enabled in `qkv_switches`."""
    n_switches = len(qkv_switches)
    n_groups = sum(qkv_switches)

    # boolean mask over output rows selecting the q/k/v groups that have LoRA weights
    qkv_mask = torch.tensor(qkv_switches, dtype=torch.bool).reshape(n_switches, -1)
    qkv_mask = qkv_mask.broadcast_to((-1, base.shape[0] // n_switches)).reshape(-1)

    lora_b = lora_b.squeeze()
    delta_w = base.new_zeros(lora_b.shape[0], base.shape[1])

    # compute the low-rank update group by group: delta_w = B @ A for each enabled group
    grp_in_ch = lora_a.shape[0] // n_groups
    grp_out_ch = lora_b.shape[0] // n_groups
    for i in range(n_groups):
        islice = slice(i * grp_in_ch, (i + 1) * grp_in_ch)
        oslice = slice(i * grp_out_ch, (i + 1) * grp_out_ch)
        delta_w[oslice, :] = lora_b[oslice, :] @ lora_a[islice, :]

    # scatter the per-group updates back into a full-size delta, then scale and apply
    delta_w_full = base.new_zeros(base.shape)
    delta_w_full[qkv_mask, :] = delta_w

    merged = base + scale * delta_w_full
    return merged.to(base)
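
# Shape note (illustrative, assuming a ViT-B backbone with embed dim 768 and LoRA rank 16):
# the fused qkv weight is (2304, 768) with q, k and v stacked along dim 0. With
# qkv_switches=[True, False, True] only q and v carry LoRA updates, so lora_a is
# (2 * 16, 768) and lora_b squeezes to (2 * 768, 16). Each enabled group contributes a
# (768, 768) block lora_b[group] @ lora_a[group]; qkv_mask scatters those blocks back
# into a zero (2304, 768) delta, which is scaled and added to the base weight.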
@torch.no_grad()
def merge_dreamsim_lora(
    base_model: nn.Module,
    lora_path: PathLike,
    torch_device: torch.device | str = torch.device("cpu"),
):
    """Merge the DreamSim LoRA weights at `lora_path` into `base_model`'s qkv projections."""
    lora_path = Path(lora_path)
    # make sure model is on device
    base_model = base_model.eval().requires_grad_(False).to(torch_device)

    # load the lora
    if lora_path.suffix.lower() in [".pt", ".pth", ".bin"]:
        lora_sd = torch.load(lora_path, map_location=torch_device, weights_only=True)
    elif lora_path.suffix.lower() == ".safetensors":
        lora_sd = load_file(lora_path)
    else:
        raise ValueError(f"Unsupported file extension '{lora_path.suffix}'")

    # these loras were created by a cursed PEFT version, okay? so we have to do some crimes.
    group_prefix = "base_model.model.base_model.model.model."
    # get all lora weights for qkv layers, stripping the insane prefix
    group_weights = {k.replace(group_prefix, ""): v for k, v in lora_sd.items() if k.startswith(group_prefix)}
    # strip ".lora_X.weight" from keys to match against base model keys
    group_layers = {k.rsplit(".", 2)[0] for k in group_weights.keys()}

    base_weights = base_model.state_dict()
    for key in [x for x in base_weights.keys() if "attn.qkv.weight" in x]:
        param_name = key.rsplit(".", 1)[0]
        if param_name not in group_layers:
            logger.warning(f"QKV param '{param_name}' not found in lora weights")
            continue
        new_weight = calculate_merged_weight(
            group_weights[f"{param_name}.lora_A.weight"],
            group_weights[f"{param_name}.lora_B.weight"],
            base_weights[key],
            0.5 / 16,  # scale = lora_alpha / lora_rank
            [True, False, True],  # LoRA is applied to q and v only; k is left untouched
        )
        base_weights[key] = new_weight

    base_model.load_state_dict(base_weights)
    return base_model.requires_grad_(False)
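
# Illustrative call (the checkpoint location is hypothetical): merge the LoRA into a base
# ViT extractor, e.g.
#   vit = merge_dreamsim_lora(vit, "checkpoints/dino_vitb16_lora/adapter_model.safetensors")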
def remap_clip(state_dict: dict[str, Tensor], variant: str) -> dict[str, Tensor]:
    """Remap keys from the original DreamSim checkpoint to match new model structure."""

    def prepend_extractor(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        if variant.endswith("single"):
            return {f"extractor.{k}": v for k, v in state_dict.items()}
        return state_dict

    if "clip" not in variant:
        return prepend_extractor(state_dict)

    if "patch_embed.proj.bias" in state_dict:
        _ = state_dict.pop("patch_embed.proj.bias", None)
    if "pos_drop.weight" in state_dict:
        state_dict["norm_pre.weight"] = state_dict.pop("pos_drop.weight")
        state_dict["norm_pre.bias"] = state_dict.pop("pos_drop.bias")
    if "head.weight" in state_dict and "head.bias" not in state_dict:
        state_dict["head.bias"] = torch.zeros(state_dict["head.weight"].shape[0])

    return prepend_extractor(state_dict)
def convert_dreamsim_single(
    ckpt_path: PathLike,
    variant: str,
    ensemble: bool = False,
) -> DreamsimModel:
    """Build a DreamsimModel for `variant` and load its merged checkpoint from `ckpt_path`."""
    ckpt_path = Path(ckpt_path)
    if ckpt_path.exists():
        if ckpt_path.is_dir():
            ckpt_path = ckpt_path.joinpath("ensemble" if ensemble else variant)
            ckpt_path = ckpt_path.joinpath(f"{variant}_merged.safetensors")

    # defaults are for dino, overridden as needed below
    patch_size = 16
    layer_norm_eps = 1e-6
    pre_norm = False
    act_layer = "gelu"

    match variant:
        case "open_clip_vitb16" | "open_clip_vitb32" | "clip_vitb16" | "clip_vitb32":
            patch_size = 32 if "b32" in variant else 16
            layer_norm_eps = 1e-5
            pre_norm = True
            img_mean = (0.48145466, 0.4578275, 0.40821073)
            img_std = (0.26862954, 0.26130258, 0.27577711)
            act_layer = "quick_gelu" if variant.startswith("clip_") else "gelu"
        case "dino_vitb16":
            img_mean = (0.485, 0.456, 0.406)
            img_std = (0.229, 0.224, 0.225)
        case _:
            raise NotImplementedError(f"Unsupported model variant '{variant}'")

    model = DreamsimModel(
        image_size=224,
        patch_size=patch_size,
        layer_norm_eps=layer_norm_eps,
        pre_norm=pre_norm,
        act_layer=act_layer,
        img_mean=img_mean,
        img_std=img_std,
    )
    state_dict = load_file(ckpt_path, device="cpu")
    state_dict = remap_clip(state_dict, variant)
    model.extractor.load_state_dict(state_dict)
    return model
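
# Example usage (illustrative; file locations are hypothetical). With a directory layout
# like checkpoints/dreamsim/<variant>/<variant>_merged.safetensors this builds the DINO
# variant and re-saves the remapped weights:
#
#   from safetensors.torch import save_file
#
#   model = convert_dreamsim_single("checkpoints/dreamsim", "dino_vitb16")
#   save_file(model.state_dict(), "dreamsim_dino_vitb16.safetensors")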