# Spaces: Running on Zero
# Derived from https://github.com/microsoft/LoRA
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F

import lit_llama.model as llama
class LoRALayer():
    """Mixin holding the LoRA hyper-parameters shared by LoRA layers.

    Args:
        r: LoRA rank.
        lora_alpha: numerator of the LoRA scaling factor.
        lora_dropout: dropout probability applied to the LoRA input;
            ``0`` disables dropout.
        merge_weights: stored as ``self.merge_weights``; subclasses use it to
            decide whether to fold the LoRA update into their base weight.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout. ``nn.Identity`` (rather than a lambda) keeps the
        # layer picklable, e.g. for ``torch.save(model)``.
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class MergedLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` with a grouped LoRA update on selected output groups.

    The ``out_features`` outputs are split into ``len(enable_lora)`` equal
    groups (e.g. a fused projection with several stacked outputs). Each
    enabled group gets its own rank-``r`` pair of ``lora_A``/``lora_B``
    factors, and their product, scaled by ``lora_alpha / r``, is added to
    that group's outputs only. Disabled groups behave like a plain linear
    layer.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample; must be divisible by
            ``len(enable_lora)``.
        r: LoRA rank; ``0`` disables the LoRA path.
        lora_alpha: numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: dropout probability on the input of the LoRA path.
        enable_lora: one flag per output group; a copy is stored.
        fan_in_fan_out: store the base weight transposed, i.e. as
            ``(in_features, out_features)``.
        merge_weights: fold the LoRA update into ``weight`` in ``eval()`` and
            subtract it again in ``train()``.
        **kwargs: forwarded to ``nn.Linear`` (e.g. ``bias``, ``dtype``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        # Copy so instances never alias the caller's list or the shared default.
        self.enable_lora = list(enable_lora)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(self.enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(self.enable_lora), in_features)))
            # Weights for Conv1D with groups=sum(enable_lora).
            self.lora_B = nn.Parameter(
                self.weight.new_zeros(
                    (out_features // len(self.enable_lora) * sum(self.enable_lora), r))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over the outputs: True for positions in enabled groups.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(self.enable_lora), -1)
            self.lora_ind[self.enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Reset the base weights and (if present) the LoRA factors."""
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        """Scatter the last dim of ``x`` into the enabled output positions.

        ``x`` has ``out_features // len(enable_lora) * sum(enable_lora)``
        entries in its last dim. The result has ``out_features`` entries
        there, with zeros at the positions of disabled groups.
        """
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features))

    def _merged_delta(self) -> torch.Tensor:
        """Return the scaled LoRA weight update laid out like ``self.weight``."""
        # Grouped conv computes B_g @ A_g for every enabled group at once:
        # shape (enabled_out, in_features).
        delta_w = F.conv1d(
            self.lora_A.data.unsqueeze(0),
            self.lora_B.data.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        # zero_pad scatters along the LAST dim, so pad the transposed update:
        # (in_features, out_features) is already the fan_in_fan_out layout;
        # the default nn.Linear layout needs it transposed back.
        padded = self.zero_pad(delta_w.T * self.scaling)
        return padded if self.fan_in_fan_out else padded.T

    def train(self, mode: bool = True):
        """Set the training mode and undo a previous merge of the LoRA update.

        A merged layer is un-merged for either value of ``mode``. This
        includes ``train(False)``, which ``nn.Module.eval()`` on a parent
        calls for each child. Merging only happens in this layer's ``eval()``.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0 and any(self.enable_lora):
                self.weight.data -= self._merged_delta()
            self.merged = False
        return self

    def eval(self):
        """Switch to evaluation mode and fold the LoRA update into ``weight``.

        Returns:
            ``self``, like ``nn.Module.eval``.
        """
        # Set the mode directly: going through ``self.train(False)`` would
        # un-merge an already merged layer only to merge it again below.
        nn.Linear.train(self, False)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0 and any(self.enable_lora):
                self.weight.data += self._merged_delta()
            self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base projection plus, if not merged, the LoRA update."""
        def T(w):
            return w.T if self.fan_in_fan_out else w
        result = F.linear(x, T(self.weight), bias=self.bias)
        # The LoRA factors only exist when r > 0 and some group is enabled;
        # when merged, the update is already inside ``self.weight``.
        if not self.merged and self.r > 0 and any(self.enable_lora):
            after_A = F.linear(self.lora_dropout(x), self.lora_A)
            after_B = F.conv1d(
                after_A.transpose(-2, -1),
                self.lora_B.unsqueeze(-1),
                groups=sum(self.enable_lora)
            ).transpose(-2, -1)
            result += self.zero_pad(after_B) * self.scaling
        return result
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every parameter of ``model`` except LoRA and projection ones.

    Parameters whose name contains ``'lora_'``, ``'motion_proj'`` or
    ``'llama_proj'`` keep their current ``requires_grad`` flag; every other
    parameter is frozen.

    Args:
        model: module whose parameters are updated in place.
        bias: which biases to re-enable afterwards. ``'none'`` keeps them
            frozen. ``'all'`` re-enables every parameter whose name contains
            ``'bias'``. ``'lora_only'`` re-enables the ``bias`` of each
            ``LoRALayer`` module.

    Raises:
        NotImplementedError: for any other ``bias`` value. It is raised
            before the model is modified.
    """
    if bias not in ('none', 'all', 'lora_only'):
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    trainable_keys = ('lora_', 'motion_proj', 'llama_proj')
    for n, p in model.named_parameters():
        if not any(key in n for key in trainable_keys):
            p.requires_grad = False
    if bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, LoRALayer) and getattr(m, 'bias', None) is not None:
                m.bias.requires_grad = True
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return the subset of ``model.state_dict()`` needed to restore LoRA fine-tuning.

    Keys containing ``'lora_'``, ``'llama_proj'`` or ``'motion_proj'`` are
    always kept. This mirrors the name filter used by
    ``mark_only_lora_as_trainable``.

    Args:
        model: module whose state dict is filtered.
        bias: ``'none'`` keeps only the keys above. ``'all'`` also keeps
            every key containing ``'bias'``. ``'lora_only'`` also keeps the
            ``bias`` entry of each module that owns LoRA parameters.

    Returns:
        A new dict mapping the selected keys to their tensors.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias not in ('none', 'all', 'lora_only'):
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    my_state_dict = model.state_dict()
    keep_keys = ('lora_', 'llama_proj', 'motion_proj')
    if bias == 'all':
        keep_keys += ('bias',)
    to_return = {
        k: v for k, v in my_state_dict.items() if any(key in k for key in keep_keys)
    }
    if bias == 'lora_only':
        # 'prefix.lora_A' -> 'prefix.bias': the bias of the module owning the LoRA factor.
        for k in list(to_return):
            if 'lora_' in k:
                bias_name = k.split('lora_')[0] + 'bias'
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
    return to_return
@dataclass
class LoRAConfig:
    """LoRA hyper-parameters that ``lora()`` installs on ``CausalSelfAttention``.

    The ``@dataclass`` decorator is required: ``lora()`` constructs this with
    keyword arguments ``r``, ``alpha`` and ``dropout``.
    """

    r: int = 0  # LoRA rank, forwarded as MergedLinear's ``r``; 0 disables LoRA
    alpha: float = 1.0  # forwarded as ``lora_alpha`` (scaling = alpha / r)
    dropout: float = 0.0  # forwarded as ``lora_dropout``
class CausalSelfAttention(llama.CausalSelfAttention):
    """LLaMA causal self-attention whose fused input projection is a LoRA ``MergedLinear``.

    ``lora_config`` is a class attribute that the ``lora`` context manager
    sets; it must hold a config when this class is instantiated, since
    ``__init__`` reads its ``r``, ``alpha`` and ``dropout``.
    """

    # Set by ``lora()`` for the duration of its ``with`` block, None otherwise.
    lora_config = None

    def __init__(self, config: llama.LLaMAConfig) -> None:
        # Skip the parent's __init__ entirely to avoid its allocations;
        # only nn.Module's own bookkeeping is initialised.
        nn.Module.__init__(self)
        assert config.n_embd % config.n_head == 0

        width = config.n_embd
        lora_cfg = self.lora_config
        # Fused projection producing three stacked width-sized outputs. LoRA is
        # enabled on the first and third (presumably query and value - confirm
        # against the parent's split order), with no bias.
        self.c_attn = MergedLinear(
            in_features=width,
            out_features=3 * width,
            r=lora_cfg.r,
            lora_alpha=lora_cfg.alpha,
            lora_dropout=lora_cfg.dropout,
            enable_lora=[True, False, True],
            fan_in_fan_out=False,
            merge_weights=True,
            bias=False,
        )
        # Output projection back to the embedding width.
        self.c_proj = nn.Linear(width, width, bias=False)
        # Attention geometry copied from the model config.
        self.n_head = config.n_head
        self.n_embd = width
        self.block_size = config.block_size
        self.rope_cache = None
@contextmanager
def lora(r, alpha, dropout, enabled: bool = True):
    """A context manager under which you can instantiate the model with LoRA.

    While active, ``llama.CausalSelfAttention`` is replaced by this module's
    LoRA ``CausalSelfAttention``, whose ``lora_config`` is set from ``r``,
    ``alpha`` and ``dropout``. Code that resolves ``llama.CausalSelfAttention``
    inside the ``with`` block (presumably the model's constructor) therefore
    builds the LoRA variant. Both changes are undone on exit, even if the
    block raises.

    Args:
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        dropout: dropout probability on the LoRA input.
        enabled: when false, the context manager does nothing.
    """
    if not enabled:
        yield
        return

    CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
    causal_self_attention = llama.CausalSelfAttention
    llama.CausalSelfAttention = CausalSelfAttention
    try:
        yield
    finally:
        # Restore the global monkey-patch even when model construction fails.
        llama.CausalSelfAttention = causal_self_attention
        CausalSelfAttention.lora_config = None