diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0de20a0f10935c72b944606b48991ce167c2e003 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +**/*.pyc +**/__pycache__ +__pycache__/ +cache_dir/ +checkpoints/ +feedback/ +temp/ +models--LanguageBind--Video-LLaVA-7B/ +*.jsonl +*.json +linghao +run.sh +examples/ +assets/task.gif \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bc7bcf743ff62ca85597e95ad7d4a36667eb42db --- /dev/null +++ b/LICENSE @@ -0,0 +1,9 @@ +License for Non-commercial Scientific Research Purposes + +IDEA grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under IDEA’s copyright interests to reproduce, distribute, and create derivative works of the text, videos, codes solely for your non-commercial research purposes. + +Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. + +Text and visualization results are owned by International Digital Economy Academy (IDEA). + +You also need to obey the original license of the dependency models/data used in this service. \ No newline at end of file diff --git a/README copy.md b/README copy.md new file mode 100644 index 0000000000000000000000000000000000000000..5534781233c8393d75839dd22eb24327bb3898a5 --- /dev/null +++ b/README copy.md @@ -0,0 +1,133 @@ +# MotionLLM: Understanding Human Behaviors from Human Motions and Videos + +![task](./assets/task.gif) + +[Ling-Hao Chen](https://lhchen.top)😎 1, 3, +[Shunlin Lu](https://shunlinlu.github.io)😎 2, 3, +[Ailing Zeng](https://ailingzeng.sit)3, +[Hao Zhang](https://haozhang534.github.io/)3, 4, +[Benyou Wang](https://wabyking.github.io/old.html)2, +[Ruimao Zhang](http://zhangruimao.site)2, +[Lei Zhang](https://leizhang.org)🤗 3 + +😎Co-first author. Listing order is random. +🤗Corresponding author. + +1Tsinghua University, +2School of Data Science, The Chinese University of Hong Kong, Shenzhen (CUHK-SZ), +3International Digital Economy Academy (IDEA), +4The Hong Kong University of Science and Technology + +

+ + + + + + + + + + + + + + + + + + +

+
+# 🤩 Abstract
+
+This study delves into the realm of multi-modality (i.e., video and motion modalities) human behavior understanding by leveraging the powerful capabilities of Large Language Models (LLMs). Diverging from recent LLMs designed for video-only or motion-only understanding, we argue that understanding human behavior necessitates joint modeling from both videos and motion sequences (e.g., SMPL sequences) to capture nuanced body part dynamics and semantics effectively. In light of this, we present MotionLLM, a straightforward yet effective framework for human motion understanding, captioning, and reasoning. Specifically, MotionLLM adopts a unified video-motion training strategy that leverages the complementary advantages of existing coarse video-text data and fine-grained motion-text data to glean rich spatial-temporal insights. Furthermore, we collect a substantial dataset, MoVid, comprising diverse videos, motions, captions, and instructions. Additionally, we propose MoVid-Bench, with careful manual annotations, for better evaluation of human behavior understanding on video and motion. Extensive experiments show the superiority of MotionLLM in captioning, spatial-temporal comprehension, and reasoning.
+
+## 🤩 Highlight Applications
+
+![application](./assets/application.png)
+
+## 🔧 Technical Solution
+
+![system](./assets/system.png)
+
+## 💻 Try it
+
+We provide a simple online [demo](https://demo.humotionx.com/) for you to try MotionLLM. Below is a guide to deploying the demo on your local machine.
+
+### Step 1: Set up the environment
+
+```bash
+pip install -r requirements.txt
+```
+
+### Step 2: Download the pre-trained model
+
+
+ 2.1 Download the LLM
+
+Please follow the instructions of [Lit-GPT](https://github.com/Lightning-AI/litgpt) to prepare the LLM (Vicuna v1.5-7B). The resulting files will be:
+```bash
+./checkpoints/vicuna-7b-v1.5
+├── generation_config.json
+├── lit_config.json
+├── lit_model.pth
+├── pytorch_model-00001-of-00002.bin
+├── pytorch_model-00002-of-00002.bin
+├── pytorch_model.bin.index.json
+├── tokenizer_config.json
+└── tokenizer.model
+```
+
+If anything is unclear, we will provide more detailed instructions in a couple of days.
+
+
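+
+Optionally, you can sanity-check the converted checkpoint before launching the demo. The sketch below is an optional helper (not part of the original instructions) that only verifies the files listed above exist in the default `./checkpoints/vicuna-7b-v1.5` location:
+```python
+from pathlib import Path
+
+# Check that the converted Vicuna checkpoint contains the files listed above.
+ckpt_dir = Path("./checkpoints/vicuna-7b-v1.5")
+required = [
+    "lit_config.json",
+    "lit_model.pth",
+    "tokenizer_config.json",
+    "tokenizer.model",
+]
+missing = [name for name in required if not (ckpt_dir / name).is_file()]
+if missing:
+    raise FileNotFoundError(f"Missing files in {ckpt_dir}: {missing}")
+print("Checkpoint directory looks complete.")
+```
+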
+ +
+ 2.2 Download the LoRA and the projection layer of MotionLLM
+
+We now release one version of the MotionLLM checkpoints, namely `v1.0` (download [here](https://drive.google.com/drive/folders/1d_5vaL34Hs2z9ACcMXyPEfZNyMs36xKx?usp=sharing)). Suggestions are welcome; please send them to Ling-Hao Chen and Shunlin Lu.
+
+```bash
+wget xxx
+```
+Keep them in a folder and remember their paths (`LINEAR_V` and `LORA`).
+
+
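+
+For reference, `app.py` consumes these two files roughly as sketched below. This is a simplified, illustrative view of the loading code in this repository; the two checkpoint paths are placeholders for wherever you stored the downloads:
+```python
+import torch
+from lit_gpt.utils import lazy_load
+
+LORA = "checkpoints/motionllm/lora.pth"          # placeholder: downloaded LoRA weights
+LINEAR_V = "checkpoints/motionllm/linear_v.pth"  # placeholder: downloaded projection (MLP) weights
+
+# The LoRA weights are overlaid on the base Vicuna weights before loading them into the model.
+base_ckpt = lazy_load("checkpoints/vicuna-7b-v1.5/lit_model.pth")
+lora_ckpt = lazy_load(LORA)
+merged_state_dict = {**base_ckpt, **lora_ckpt}  # passed to model.load_state_dict(..., strict=True)
+
+# The projection layer maps video features into the LLM embedding space and is loaded separately.
+mlp_state_dict = torch.load(LINEAR_V, map_location="cpu")  # passed to linear_proj.load_state_dict(...)
+```
+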
+
+### 2.3 Run the demo
+
+```bash
+GRADIO_TEMP_DIR=temp python app.py --lora_path $LORA --mlp_path $LINEAR_V
+```
+If you run into errors while downloading the Hugging Face model, you can try the following command, which uses a Hugging Face mirror.
+```bash
+HF_ENDPOINT=https://hf-mirror.com GRADIO_TEMP_DIR=temp python app.py --lora_path $LORA --mlp_path $LINEAR_V
+```
+`GRADIO_TEMP_DIR=temp` sets `./temp` as the temporary directory where Gradio stores its data. You can change it to your own path.
+
+After this, open your browser and visit the local host address printed in the command-line output. If the page does not load, replace the IP address with your local IP address (found via the `ifconfig` command).
+
+
+## 💼 To-Do
+
+- [x] Release the video demo of MotionLLM.
+- [ ] Release the motion demo of MotionLLM.
+- [ ] Release the MoVid dataset and MoVid-Bench.
+- [ ] Release the tuning instructions of MotionLLM.
+
+
+## 💋 Acknowledgement
+
+
+The author team would like to thank many people. Qing Jiang helped a lot with parts of the manual annotation on MoVid-Bench and resolved some ethics issues of MotionLLM. Jingcheng Hu provided technical suggestions for efficient training. Shilong Liu and Bojia Zi provided significant technical suggestions on LLM tuning. Jiale Liu, Wenhao Yang, and Chenlai Qian provided significant suggestions for polishing the paper. Hongyang Li helped us a lot with the figure design. Yiren Pang provided GPT API keys when our keys were temporarily out of quota. The code is built on [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA), [HumanTOMATO](https://lhchen.top/HumanTOMATO/), [MotionGPT](https://github.com/qiqiApink/MotionGPT), [lit-gpt](https://github.com/Lightning-AI/litgpt), and [HumanML3D](https://github.com/EricGuo5513/HumanML3D). Thanks to all contributors!
+
+
+## 📚 License
+
+This code is distributed under the [IDEA LICENSE](LICENSE). Note that our code depends on other libraries and datasets, each of which has its own license that must also be followed.
+
+
+If you have any questions, please contact: thu [DOT] lhchen [AT] gmail [DOT] com AND shunlinlu0803 [AT] gmail [DOT] com. 
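+
+As an optional convenience (not part of the original instructions), the launch commands from section 2.3 can also be wrapped in a small Python launcher. The sketch below assumes `LORA` and `LINEAR_V` are exported as environment variables pointing to the checkpoints from section 2.2:
+```python
+import os
+import subprocess
+
+env = dict(os.environ)
+env["GRADIO_TEMP_DIR"] = "temp"                 # Gradio stores temporary data under ./temp
+# env["HF_ENDPOINT"] = "https://hf-mirror.com"  # uncomment to download models via the mirror
+
+subprocess.run(
+    ["python", "app.py", "--lora_path", env["LORA"], "--mlp_path", env["LINEAR_V"]],
+    env=env,
+    check=True,
+)
+```
+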
+ diff --git a/app copy.py b/app copy.py new file mode 100644 index 0000000000000000000000000000000000000000..664d53fd960e0a17b875c8e06bb6abe9e792216a --- /dev/null +++ b/app copy.py @@ -0,0 +1,661 @@ +import shutil +import subprocess + +import torch +import gradio as gr +from fastapi import FastAPI +import os +from PIL import Image +import tempfile +from decord import VideoReader, cpu +import uvicorn +from transformers import TextStreamer + +import hashlib +import os +import sys +import time +import warnings +from pathlib import Path +from typing import Optional +from typing import Dict, List, Literal, Optional, Tuple +from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable + +import lightning as L +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from generate import generate as generate_ +from lit_llama import Tokenizer, LLaMA, LLaMAConfig +from lit_llama.lora import lora +from lit_llama.utils import EmptyInitOnDevice +from lit_gpt.utils import lazy_load +from scripts.video_dataset.prepare_video_dataset_video_llava import generate_prompt_mlp +from options import option +import imageio +from tqdm import tqdm + +from models.multimodal_encoder.builder import build_image_tower, build_video_tower +from models.multimodal_projector.builder import build_vision_projector + + +title_markdown = ("""
+

MotionLLM: Understanding Human Behaviors from Human Motions and Videos

+

+ Ling-Hao Chen😎 1, 3, + Shunlin Lu😎 2, 3, +
+ Ailing Zeng3, + Hao Zhang3, 4, + Benyou Wang2, + Ruimao Zhang2, + Lei Zhang🤗 3 +

+

😎Co-first author. Listing order is random.🤗Corresponding author.

+

+ 1THU   + 2CUHK (SZ)   + 3IDEA Research   + 4HKUST +

+
+
+ MotionLLM +
+ +""") + +block_css = """ +#buttons button { + min-width: min(120px,100%); +} +""" + + +tos_markdown = (""" +*We are now working to support the motion branch of the MotionLLM model. + +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. +It is forbidden to use the service to generate content that is illegal, harmful, violent, racist, or sexual +The usage of this service is subject to the IDEA License. +""") + + +learn_more_markdown = (""" +### License +License for Non-commercial Scientific Research Purposes + +IDEA grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under IDEA’s copyright interests to reproduce, distribute, and create derivative works of the text, videos, codes solely for your non-commercial research purposes. + +Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. + +Text and visualization results are owned by International Digital Economy Academy (IDEA). + +You also need to obey the original license of the dependency models/data used in this service. +""") + + + +class LlavaMetaModel: + + def __init__(self, config, pretrained_checkpoint): + super(LlavaMetaModel, self).__init__() + # import pdb; pdb.set_trace() + if hasattr(config, "mm_image_tower") or hasattr(config, "image_tower"): + self.image_tower = build_image_tower(config, delay_load=True) + self.mm_projector = build_vision_projector(config) + if hasattr(config, "mm_video_tower") or hasattr(config, "video_tower"): + self.video_tower = build_video_tower(config, delay_load=True) + self.mm_projector = build_vision_projector(config) + self.load_video_tower_pretrained(pretrained_checkpoint) + + def get_image_tower(self): + image_tower = getattr(self, 'image_tower', None) + if type(image_tower) is list: + image_tower = image_tower[0] + return image_tower + + def get_video_tower(self): + video_tower = getattr(self, 'video_tower', None) + + if type(video_tower) is list: + video_tower = video_tower[0] + return video_tower + + + def get_all_tower(self, keys): + tower = {key: getattr(self, f'get_{key}_tower') for key in keys} + return tower + + + def load_video_tower_pretrained(self, pretrained_checkpoint): + self.mm_projector.load_state_dict(pretrained_checkpoint, strict=True) + + + def initialize_image_modules(self, model_args, fsdp=None): + image_tower = model_args.image_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + + self.config.mm_image_tower = image_tower + + image_tower = build_image_tower(model_args) + + if fsdp is not None and len(fsdp) > 0: + self.image_tower = [image_tower] + else: + self.image_tower = image_tower + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = image_tower.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + + self.mm_projector = build_vision_projector(self.config) + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword 
+ '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + + def initialize_video_modules(self, model_args, fsdp=None): + video_tower = model_args.video_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + + self.config.mm_video_tower = video_tower + + video_tower = build_video_tower(model_args) + + if fsdp is not None and len(fsdp) > 0: + self.video_tower = [video_tower] + else: + self.video_tower = video_tower + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = video_tower.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + + self.mm_projector = build_vision_projector(self.config) + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + + def encode_images(self, images): + image_features = self.get_image_tower()(images) + image_features = self.mm_projector(image_features) + return image_features + + def encode_videos(self, videos): + # import pdb; pdb.set_trace() + # videos: torch.Size([1, 3, 8, 224, 224]) + video_features = self.get_video_tower()(videos) # torch.Size([1, 2048, 1024]) + video_features = self.mm_projector(video_features.float()) # torch.Size([1, 2048, 4096]) + return video_features + + def get_multimodal_embeddings(self, X_modalities): + Xs, keys= X_modalities + + X_features = getattr(self, f'encode_{keys[0]}s')(Xs) # expand to get batchsize + + return X_features + + +class Projection(nn.Module): + def __init__(self, ): + super().__init__() + self.linear_proj = nn.Linear(512, 4096) + def forward(self, x): + return self.linear_proj(x) + + +class ProjectionNN(nn.Module): + def __init__(self, ): + super().__init__() + self.proj = nn.Sequential( + nn.Linear(512, 4096), + nn.GELU(), + nn.Linear(4096, 4096) + ) + def forward(self, x): + return self.proj(x) + + +class Conversation(): + def __init__(self, output=None, input_prompt=None, prompt=None): + if output is None: + self.messages = [] + else: + self.messages = [] + self.append_message(prompt, input_prompt, output) + + def append_message(self, output, input_prompt, prompt, show_images): + # print(output) + # print(input_prompt) + # print(prompt) + # print(show_images) + self.messages.append((output, input_prompt, prompt, show_images)) + + def to_gradio_chatbot(self, show_images=None, output_text=None): + # return a list + if show_images is None: + show_images = self.messages[-1][3] + output_text = self.messages[-1][0] + return [ + [show_images, output_text] + ] + + def get_info(self): + return self.messages[-1][0], self.messages[-1][1] + + +class ConversationBuffer(): + def __init__(self, input_text): + self.buffer_ = [] + self.buffer.append(input_text) + + +def init_conv(): + conv = Conversation() + return conv + + +def get_processor(X, config, device, pretrained_checkpoint_tower, model_path = 'LanguageBind/MotionLLM-7B'): + mm_backbone_mlp_model = LlavaMetaModel(config, pretrained_checkpoint_tower) + + processor = {} + if 'Image' in X: + image_tower 
= mm_backbone_mlp_model.get_image_tower() # LanguageBindImageTower() + if not image_tower.is_loaded: + image_tower.load_model() + image_tower.to(device=device, dtype=torch.float16) + image_processor = image_tower.image_processor + processor['image'] = image_processor + if 'Video' in X: + video_tower = mm_backbone_mlp_model.get_video_tower() + if not video_tower.is_loaded: + video_tower.load_model() + video_tower.to(device=device, dtype=torch.float16) + video_processor = video_tower.video_processor + processor['video'] = video_processor + + return mm_backbone_mlp_model, processor + + +def motionllm( + args, + input_video_path: str, + text_en_in: str, + quantize: Optional[str] = None, + dtype: str = "float32", + max_new_tokens: int = 200, + top_k: int = 200, + temperature: float = 0.8, + accelerator: str = "auto",): + + video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values'] + + if type(video_tensor) is list: + tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor] + else: + tensor = video_tensor.to('cuda', dtype=torch.float16) # (1,3,8,224,224) + + X_modalities = [tensor,['video']] + video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities) + prompt = text_en_in + input_prompt = prompt + + sample = {"instruction": prompt, "input": input_video_path} + + prefix = generate_prompt_mlp(sample) + pre = torch.cat((tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1, -1), tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1)), dim=1) + + prompt = (pre, ". ASSISTANT: ") + encoded = (prompt[0], video_feature[0], tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1)) + + t0 = time.perf_counter() + + output_seq = generate_( + model, + idx=encoded, + max_seq_length=4096, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + tokenizer = tokenizer, + ) + outputfull = tokenizer.decode(output_seq) + output = outputfull.split("ASSISTANT:")[-1].strip() + print("================================") + print(output) + print("================================") + + return output, input_prompt, prompt + + +def save_image_to_local(image): + filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg') + image = Image.open(image) + image.save(filename) + # print(filename) + return filename + + +def save_video_to_local(video_path): + filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4') + shutil.copyfile(video_path, filename) + return filename + + +def generate(image1, video, textbox_in, first_run, state, images_tensor): + flag = 1 + + image1 = image1 if image1 else "none" + video = video if video else "none" + + if type(state) is not Conversation: + state = init_conv() + images_tensor = [[], []] + + first_run = False if len(state.messages) > 0 else True + text_en_in = textbox_in.replace("picture", "image") + output, input_prompt, prompt = motionllm(args, video, text_en_in) + + text_en_out = output + textbox_out = text_en_out + + show_images = "" + if os.path.exists(image1): + filename = save_image_to_local(image1) + show_images += f'' + + if os.path.exists(video): + filename = save_video_to_local(video) + show_images += f'' + + show_images = textbox_in + "\n" + show_images + state.append_message(output, input_prompt, prompt, show_images) + + torch.cuda.empty_cache() + + return (state, state.to_gradio_chatbot(show_images, output), False, 
gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(image1) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True)) + +def regenerate(state): + if len(state.messages) > 0: + tobot = state.to_gradio_chatbot() + tobot[-1][1] = None + textbox = state.messages[-1][1] + state.messages.pop(-1) + return state, tobot, False, textbox + return (state, [], True) + + +def clear_history(state): + state = init_conv() + try: + tgt = state.to_gradio_chatbot() + except: + tgt = [None, None] + return (gr.update(value=None, interactive=True), + gr.update(value=None, interactive=True),\ + gr.update(value=None, interactive=True),\ + True, state, tgt, [[], []]) + + +def get_md5(file_path): + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def logging_up(video, state): + try: + state.get_info() + except: + return False + action = "upvote" + # Get the current time + current_time = str(time.time()) + + # Create an md5 object + hash_object = hashlib.md5(current_time.encode()) + + # Get the hexadecimal representation of the hash + md5_hash = get_md5(video) + "-" + hash_object.hexdigest() + + command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4" + os.system(command) + with open (f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f: + out, prp = state.get_info() + f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n") + return True + + +def logging_down(video, state): + try: + state.get_info() + except: + return False + action = "downvote" + # Get the current time + current_time = str(time.time()) + + # Create an md5 object + hash_object = hashlib.md5(current_time.encode()) + + # Get the hexadecimal representation of the hash + md5_hash = get_md5(video) + "-" + hash_object.hexdigest() + + command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4" + os.system(command) + with open (f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f: + out, prp = state.get_info() + f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n") + return True + + +torch.set_float32_matmul_precision("high") +warnings.filterwarnings('ignore') +args = option.get_args_parser() + +conv_mode = "llava_v1" +model_path = 'LanguageBind/Video-LLaVA-7B' +device = 'cuda' +load_8bit = False +load_4bit = True +dtype = torch.float16 + +if not os.path.exists("temp"): + os.makedirs("temp") + +lora_path = Path(args.lora_path) +pretrained_llm_path = Path(f"./checkpoints/vicuna-7b-v1.5/lit_model.pth") +tokenizer_llm_path = Path("./checkpoints/vicuna-7b-v1.5/tokenizer.model") + +# assert lora_path.is_file() +assert pretrained_llm_path.is_file() +assert tokenizer_llm_path.is_file() + +accelerator = "auto" +fabric = L.Fabric(accelerator=accelerator, devices=1) + +dtype = "float32" +dt = getattr(torch, dtype, None) +if not isinstance(dt, torch.dtype): + raise ValueError(f"{dtype} is not a valid dtype.") +dtype = dt + +quantize = None +t0 = time.time() + +with EmptyInitOnDevice( + device=fabric.device, dtype=dtype, quantization_mode=quantize +), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True): + checkpoint_dir = Path("checkpoints/vicuna-7b-v1.5") + lora_query = True + lora_key = False + lora_value = True + lora_projection = False + lora_mlp = False + lora_head = False + config = Config.from_name( + name=checkpoint_dir.name, + r=args.lora_r, + alpha=args.lora_alpha, + 
dropout=args.lora_dropout, + to_query=lora_query, + to_key=lora_key, + to_value=lora_value, + to_projection=lora_projection, + to_mlp=lora_mlp, + to_head=lora_head, + ) + model = GPT(config).bfloat16() + +mlp_path = args.mlp_path +pretrained_checkpoint_mlp = torch.load(mlp_path) + +X = ['Video'] + +mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp, model_path = 'LanguageBind/Video-LLaVA-7B') +video_processor = processor['video'] + +linear_proj = mm_backbone_mlp_model.mm_projector + +# 1. Load the pretrained weights +pretrained_llm_checkpoint = lazy_load(pretrained_llm_path) +# 2. Load the fine-tuned LoRA weights +lora_checkpoint = lazy_load(lora_path) +# 3. merge the two checkpoints +model_state_dict = {**pretrained_llm_checkpoint, **lora_checkpoint} +model.load_state_dict(model_state_dict, strict=True) +print('Load llm base model from', pretrained_llm_path) +print('Load lora model from', lora_path) + +# load mlp again, to en sure, not neccessary actually +linear_proj.load_state_dict(pretrained_checkpoint_mlp) +linear_proj = linear_proj.cuda() +print('Load mlp model again from', mlp_path) +print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) + +model.eval() +model = fabric.setup_module(model) +linear_proj.eval() + +tokenizer = Tokenizer(tokenizer_llm_path) +print('Load tokenizer from', tokenizer_llm_path) + +print(torch.cuda.memory_allocated()) +print(torch.cuda.max_memory_allocated()) + + +app = FastAPI() + +textbox = gr.Textbox( + show_label=False, placeholder="Enter text and press ENTER", container=False + ) + +with gr.Blocks(title='MotionLLM', theme=gr.themes.Default(), css=block_css) as demo: + gr.Markdown(title_markdown) + state = gr.State() + buffer_ = gr.State() + first_run = gr.State() + images_tensor = gr.State() + + with gr.Row(): + with gr.Column(scale=3): + image1 = gr.State() + video = gr.Video(label="Input Video") + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples( + examples=[ + [ + f"{cur_dir}/examples/Play_Electric_guitar_16_clip1.mp4", + "why is the girl so happy", + ], + [ + f"{cur_dir}/examples/guoyoucai.mov", + "what is the feeling of him", + ], + [ + f"{cur_dir}/examples/sprint_run_18_clip1.mp4", + "Why is the man running so fast?", + ], + [ + f"{cur_dir}/examples/lift_weight.mp4", + "Assume you are a fitness coach, refer to the video of the professional athlete, please analyze specific action essentials in steps and give detailed instruction.", + ], + [ + f"{cur_dir}/examples/Shaolin_Kung_Fu_Wushu_Selfdefense_Sword_Form_Session_22_clip3.mp4", + "wow, can you teach me the motion, step by step in detail", + ], + [ + f"{cur_dir}/examples/mabaoguo.mp4", + "why is the video funny?", + ], + [ + f"{cur_dir}/examples/COBRA_PUSH_UPS_clip2.mp4", + "describe the body movement of the woman", + ], + [ + f"{cur_dir}/examples/sample_demo_1.mp4", + "Why is this video interesting?", + ], + ], + inputs=[video, textbox], + ) + + with gr.Column(scale=7): + chatbot = gr.Chatbot(label="MotionLLM", bubble_full_width=True).style(height=875) + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button( + value="Send", variant="primary", interactive=True + ) + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=True) + downvote_btn = gr.Button(value="👎 Downvote", interactive=True) + flag_btn = gr.Button(value="⚠️ Flag", interactive=True) + # stop_btn = gr.Button(value="⏹️ Stop Generation", 
interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) + + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + + tmp = gr.State() + upvote_btn.click(logging_up, [video, state], [tmp]) + + downvote_btn.click(logging_down, [video, state], [tmp]) + + submit_btn.click(generate, [image1, video, textbox, first_run, state, images_tensor], + [state, chatbot, first_run, textbox, images_tensor, image1, video]) + + regenerate_btn.click(regenerate, [state], [state, chatbot, first_run, textbox]).then( + generate, [image1, video, textbox, first_run, state, images_tensor], [state, chatbot, first_run, textbox, images_tensor, image1, video]) + + clear_btn.click(clear_history, [state], + [image1, video, textbox, first_run, state, chatbot, images_tensor]) + +app = gr.mount_gradio_app(app, demo, path="/") +uvicorn.run(app, host="0.0.0.0", port=6657) \ No newline at end of file diff --git a/assets/application.png b/assets/application.png new file mode 100644 index 0000000000000000000000000000000000000000..1b6c5f5505e4493d9835f4963bb6083a16641533 Binary files /dev/null and b/assets/application.png differ diff --git a/assets/compare.png b/assets/compare.png new file mode 100644 index 0000000000000000000000000000000000000000..ca5978b0e95397bee174b3a9a117f7e994bcbe90 Binary files /dev/null and b/assets/compare.png differ diff --git a/assets/highlight.png b/assets/highlight.png new file mode 100644 index 0000000000000000000000000000000000000000..137a762b8c034c47d330d67a9c859779a01de564 Binary files /dev/null and b/assets/highlight.png differ diff --git a/assets/logo.png b/assets/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..bf6998d5d78932bbcc0ff65c0f8db856593d55d7 Binary files /dev/null and b/assets/logo.png differ diff --git a/assets/system.png b/assets/system.png new file mode 100644 index 0000000000000000000000000000000000000000..45094e7f958443f4c444dafccc19651270779503 Binary files /dev/null and b/assets/system.png differ diff --git a/generate.py b/generate.py new file mode 100755 index 0000000000000000000000000000000000000000..677a1714aa1fdc4891d4196be39ce20695eb348a --- /dev/null +++ b/generate.py @@ -0,0 +1,199 @@ +import sys +import time +import warnings +from pathlib import Path +from typing import Optional + +import lightning as L +import torch + +from lit_llama import LLaMA, Tokenizer +from lit_llama.utils import EmptyInitOnDevice, lazy_load + + +@torch.no_grad() +def generate( + model: torch.nn.Module, + idx: torch.Tensor, + max_new_tokens: int, + max_seq_length: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + eos_id: Optional[int] = None, + tokenizer = None, +) -> torch.Tensor: + """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + + The implementation of this function is modified from A. Karpathy's nanoGPT. + + Args: + model: The model to use. + idx: Tensor of shape (T) with indices of the prompt sequence. + max_new_tokens: The number of new tokens to generate. + max_seq_length: The maximum sequence length allowed. 
+ temperature: Scales the predicted logits by 1 / temperature + top_k: If specified, only sample among the tokens with the k highest probabilities + eos_id: If specified, stop generating any more token once the token is triggered + """ + # create an empty tensor of the expected final shape and fill in the current tokens + # import pdb; pdb.set_trace() + if type(idx) == tuple: + # import pdb; pdb.set_trace() + T = idx[0].shape[-1] + idx[2].shape[-1] + len(idx[1]) + before_len = idx[0].shape[-1] + catted = torch.cat((idx[0], torch.zeros((1, len(idx[1]))).cuda(), idx[2]), dim=1).long() + idx = (catted, idx[1], before_len) + T_new = T + max_new_tokens + # import pdb; pdb.set_trace() + empty = torch.empty(T_new, dtype=idx[0].dtype, device=idx[0].device) + empty = torch.empty(T_new, dtype=idx[0].dtype, device=idx[0].device) + empty[:T] = idx[0] + idx = (empty, idx[1], [before_len]) + # import pdb; pdb.set_trace() + else: + # import pdb; pdb.set_trace() + T = idx.size(0) + T_new = T + max_new_tokens + empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device) + empty[:T] = idx + idx = empty + + # generate max_new_tokens tokens + # import pdb; pdb.set_trace() + for t in range(T, T_new): + if type(idx) == tuple: + idx_cond = idx[0][:t] + tmp = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:] + # import pdb; pdb.set_trace() + idx_cond = (tmp.view(1, -1), idx[1].unsqueeze(0), idx[2]) + else: + # ignore the not-filled-yet tokens + idx_cond = idx[:t] + # if the sequence context is growing too long we must crop it at max_seq_length + idx_cond = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:] + + # forward + if type(idx) == tuple: + logits = model(idx_cond, maxlen=idx_cond[0].size(1)) + else: + logits = model(idx_cond.view(1, -1)) + logits = logits[0, -1] / temperature + + # import pdb; pdb.set_trace() + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[[-1]]] = -float("Inf") + + probs = torch.nn.functional.softmax(logits, dim=-1) + idx_next = torch.multinomial(probs, num_samples=1) + + # concatenate the new generation + if type(idx) == tuple: + seq = idx[0] + seq[t] = idx_next + idx = (seq, idx[1], idx[2]) + else: + idx[t] = idx_next + + # if token is triggered, return the output (stop generation) + if idx_next == eos_id: + if type(idx) == tuple: + return idx[0][:t+1] + else: + return idx[:t + 1] # include the EOS token + if type(idx) == tuple: + return idx[0] + else: + return idx + + +def main( + prompt: str = "Hello, my name is", + *, + num_samples: int = 1, + max_new_tokens: int = 50, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Optional[Path] = None, + tokenizer_path: Optional[Path] = None, + model_size: str = "7B", + quantize: Optional[str] = None, +) -> None: + """Generates text samples based on a pre-trained LLaMA model and tokenizer. + + Args: + prompt: The prompt string to use for generating the samples. + num_samples: The number of text samples to generate. + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + checkpoint_path: The checkpoint path to load. + tokenizer_path: The tokenizer path to load. + model_size: The model size to load. 
+ quantize: Whether to quantize the model and using which method: + ``"llm.int8"``: LLM.int8() mode, + ``"gptq.int4"``: GPTQ 4-bit mode. + """ + if not checkpoint_path: + checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/lit-llama.pth") + if not tokenizer_path: + tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model") + assert checkpoint_path.is_file(), checkpoint_path + assert tokenizer_path.is_file(), tokenizer_path + + fabric = L.Fabric(accelerator="cuda", devices=1) + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + + print("Loading model ...", file=sys.stderr) + t0 = time.time() + with EmptyInitOnDevice( + device=fabric.device, dtype=dtype, quantization_mode=quantize + ): + model = LLaMA.from_name(model_size) + + checkpoint = lazy_load(checkpoint_path) + model.load_state_dict(checkpoint) + print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) + + model.eval() + model = fabric.setup_module(model) + + tokenizer = Tokenizer(tokenizer_path) + encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device) + + L.seed_everything(1234) + t0 = time.perf_counter() + + for _ in range(num_samples): + y = generate( + model, + encoded_prompt, + max_new_tokens, + model.config.block_size, # type: ignore[union-attr,arg-type] + temperature=temperature, + top_k=top_k, + ) + print(tokenizer.decode(y)) + + t = time.perf_counter() - t0 + print(f"\n\nTime for inference: {t:.02f} sec total, {num_samples * max_new_tokens / t:.02f} tokens/sec", file=sys.stderr) + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr) + + +if __name__ == "__main__": + from jsonargparse import CLI + + torch.set_float32_matmul_precision("high") + warnings.filterwarnings( + # Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31 + "ignore", + message="ComplexHalf support is experimental and many operators don't support it yet" + ) + warnings.filterwarnings( + # Triggered in bitsandbytes/autograd/_functions.py:298 + "ignore", + message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization", + ) + CLI(main) diff --git a/lit_gpt/__init__.py b/lit_gpt/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..dc1c439981a5b6a94a7dad991bed98a9bb224515 --- /dev/null +++ b/lit_gpt/__init__.py @@ -0,0 +1,15 @@ +from lit_gpt.model import GPT +from lit_gpt.config import Config +from lit_gpt.tokenizer import Tokenizer + +from lightning_utilities.core.imports import RequirementCache + +_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0") +# if not bool(_LIGHTNING_AVAILABLE): +# raise ImportError( +# "Lit-GPT requires lightning==2.1. 
Please run:\n" +# f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}" +# ) + + +__all__ = ["GPT", "Config", "Tokenizer"] diff --git a/lit_gpt/adapter.py b/lit_gpt/adapter.py new file mode 100755 index 0000000000000000000000000000000000000000..a99f85f6d48735ca5937b25f16607319c102f095 --- /dev/null +++ b/lit_gpt/adapter.py @@ -0,0 +1,165 @@ +"""Implementation of the paper: + +LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention +https://arxiv.org/abs/2303.16199 + +Port for Lit-GPT +""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from typing_extensions import Self + +from lit_gpt.config import Config as BaseConfig +from lit_gpt.model import GPT as BaseModel +from lit_gpt.model import Block as BaseBlock +from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention + + +@dataclass +class Config(BaseConfig): + adapter_prompt_length: int = 10 + adapter_start_layer: int = 2 + + +class GPT(BaseModel): + """The implementation is identical to `lit_gpt.model.GPT` with the exception that + the `Block` saves the layer index and passes it down to the attention layer.""" + + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 + ) -> Union[torch.Tensor, List[torch.Tensor]]: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = self.transformer.ln_f(x) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, CausalSelfAttention): + module.reset_parameters() + + +class Block(BaseBlock): + """The implementation is identical to `lit_gpt.model.Block` with the exception that + we replace the attention layer where adaption is implemented.""" + + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config, block_idx) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention + over the adaption prompt.""" + + def __init__(self, config: Config, block_idx: int) -> None: + super().__init__(config) + if block_idx >= config.adapter_start_layer: + # adapter embedding layer + self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) + # gate for adaption + self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) + # kv cache for inference + self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self.block_idx = block_idx + + def scaled_dot_product_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + y = super().scaled_dot_product_attention(q, k, v, mask) + if self.block_idx < self.config.adapter_start_layer: + return y + + aT = self.config.adapter_prompt_length + if self.adapter_kv_cache is not None: + # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av + # are the same every call + ak, av = self.adapter_kv_cache + else: + prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) + aqkv = self.attn(prefix) + q_per_kv = self.config.n_head // self.config.n_query_groups + aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) + aqkv = aqkv.permute(0, 2, 3, 1, 4) + _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) + if self.config.n_query_groups != 1: + # for MHA this is a no-op + ak = ak.repeat_interleave(q_per_kv, dim=2) + av = av.repeat_interleave(q_per_kv, dim=2) + ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) + av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) + self.adapter_kv_cache = (ak, av) + + T = q.size(2) + amask = torch.ones(T, aT, dtype=torch.bool, device=q.device) + ay = super().scaled_dot_product_attention(q, ak, av, amask) + return y + self.gating_factor * ay + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.gating_factor) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with older checkpoints.""" + if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: + state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def mark_only_adapter_as_trainable(model: GPT) -> None: + """Sets `requires_grad=False` for all non-adapter weights.""" + for name, param in model.named_parameters(): + param.requires_grad = adapter_filter(name, param) + + +def 
adapter_filter(key: str, value: Any) -> bool: + return "adapter_wte" in key or "gating_factor" in key diff --git a/lit_gpt/adapter_v2.py b/lit_gpt/adapter_v2.py new file mode 100755 index 0000000000000000000000000000000000000000..e9e4c69cc59ec729a73cab6412c9d5485a2b0ab1 --- /dev/null +++ b/lit_gpt/adapter_v2.py @@ -0,0 +1,197 @@ +"""Implementation of the paper: + +LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model +https://arxiv.org/abs/2304.15010 + +Port for Lit-GPT +""" +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Type + +import torch +import torch.nn as nn +from typing_extensions import Self + +import lit_gpt +from lit_gpt.adapter import GPT as BaseModel +from lit_gpt.adapter import Block as BaseBlock +from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention +from lit_gpt.adapter import Config as BaseConfig +from lit_gpt.model import KVCache +from lit_gpt.utils import map_old_state_dict_weights + + +@dataclass +class Config(BaseConfig): + @property + def mlp_class(self) -> Type: + return getattr(lit_gpt.adapter_v2, self._mlp_class) + + +def adapter_filter(key: str, value: Any) -> bool: + adapter_substrings = ( + # regular adapter v1 parameters + "adapter_wte", + "gating_factor", + # adapter v2: new bias and scale used in Linear + "adapter_scale", + "adapter_bias", + # adapter v2: Norm parameters are now trainable + "norm_1", + "norm_2", + "ln_f", + ) + return any(s in key for s in adapter_substrings) + + +class AdapterV2Linear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, **kwargs) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False) + self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.adapter_scale * (self.linear(x) + self.adapter_bias) + + def reset_parameters(self) -> None: + nn.init.zeros_(self.adapter_bias) + nn.init.ones_(self.adapter_scale) + + +class GPT(BaseModel): + def __init__(self, config: Config) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, AdapterV2Linear): + module.reset_parameters() + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = {"lm_head.weight": "lm_head.linear.weight"} + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class Block(BaseBlock): + """The implementation is identical to `lit_gpt.model.Block` with the exception that + we replace the attention layer where adaption is implemented.""" + + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config, block_idx) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class""" + + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) + # output projection + self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + if block_idx >= config.adapter_start_layer: + # adapter embedding layer + self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) + # gate for adaption + self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) + # kv cache for inference + self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + self.block_idx = block_idx + + self.config = config + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "attn.weight": "attn.linear.weight", + "attn.bias": "attn.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + # For compatibility with older checkpoints + if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: + state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + self.config = config + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc.weight": "fc.linear.weight", + "fc.bias": "fc.linear.bias", + "proj.weight": 
"proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LLaMAMLP(lit_gpt.model.LLaMAMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc_1.weight": "fc_1.linear.weight", + "fc_1.bias": "fc_1.linear.bias", + "fc_2.weight": "fc_2.linear.weight", + "fc_2.bias": "fc_2.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def mark_only_adapter_v2_as_trainable(model: GPT) -> None: + """Sets requires_grad=False for all non-adapter weights""" + for name, param in model.named_parameters(): + param.requires_grad = adapter_filter(name, param) diff --git a/lit_gpt/config.py b/lit_gpt/config.py new file mode 100755 index 0000000000000000000000000000000000000000..43c4dcc2a2b13eeb3855a335dcaf8bd989df7bfc --- /dev/null +++ b/lit_gpt/config.py @@ -0,0 +1,1040 @@ +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, Optional, Type, Union + +import torch +from typing_extensions import Self + +import lit_gpt.model +from lit_gpt.utils import find_multiple + + +@dataclass +class Config: + org: str = "Lightning-AI" + name: str = "lit-GPT" + block_size: int = 4096 + vocab_size: int = 50254 + padding_multiple: int = 512 + padded_vocab_size: Optional[int] = None + n_layer: int = 16 + n_head: int = 32 + n_embd: int = 4096 + rotary_percentage: float = 0.25 + parallel_residual: bool = True + bias: bool = True + lm_head_bias: bool = False + # to use multi-head attention (MHA), set this to `n_head` (default) + # to use multi-query attention (MQA), set this to 1 + # to use grouped-query attention (GQA), set this to a value in between + # Example with `n_head=4` + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ │ │ │ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + # MHA GQA MQA + # n_query_groups=4 n_query_groups=2 n_query_groups=1 + # + # credit https://arxiv.org/pdf/2305.13245.pdf + n_query_groups: Optional[int] = None + shared_attention_norm: bool = False + _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + norm_eps: float = 1e-5 + _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP" + gelu_approximate: str = "none" + intermediate_size: Optional[int] = None + rope_condense_ratio: int = 1 + rope_base: int = 10000 + + def __post_init__(self): + assert self.n_embd % self.n_head == 0 + 
self.head_size = self.n_embd // self.n_head + + # vocab size should be a power of 2 to be optimal on hardware. compute the closest value + if self.padded_vocab_size is None: + self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple) + else: + # vocab size shouldn't be larger than padded vocab size + self.vocab_size = min(self.vocab_size, self.padded_vocab_size) + + # compute the number of query groups + if self.n_query_groups is not None: + assert self.n_head % self.n_query_groups == 0 + else: + self.n_query_groups = self.n_head + + # compute the intermediate size for MLP if not set + if self.intermediate_size is None: + if self._mlp_class == "LLaMAMLP": + raise ValueError("The config needs to set the `intermediate_size`") + self.intermediate_size = 4 * self.n_embd + + self.rope_n_elem = int(self.rotary_percentage * self.head_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + conf_dict = name_to_config[name].copy() + if "condense_ratio" in kwargs: # legacy name + kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio") + conf_dict.update(kwargs) + return cls(**conf_dict) + + @classmethod + def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self: + with open(path, encoding="utf-8") as fp: + json_kwargs = json.load(fp) + if "condense_ratio" in json_kwargs: # legacy name + json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio") + if "condense_ratio" in kwargs: # legacy name + kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio") + json_kwargs.update(kwargs) + return cls(**json_kwargs) + + @property + def mlp_class(self) -> Type: + # `self._mlp_class` cannot be the type to keep the config json serializable + return getattr(lit_gpt.model, self._mlp_class) + + @property + def norm_class(self) -> Type: + # `self._norm_class` cannot be the type to keep the config json serializable + if self._norm_class == "RMSNorm": + from lit_gpt.rmsnorm import RMSNorm + + return RMSNorm + return getattr(torch.nn, self._norm_class) + + +######################## +# Stability AI StableLM +######################## +configs = [ + # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json + dict(org="stabilityai", name="stablelm-base-alpha-3b"), + # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json + dict(org="stabilityai", name="stablelm-base-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256), + # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json + dict(org="stabilityai", name="stablelm-tuned-alpha-3b", n_head=32), + # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json + dict(org="stabilityai", name="stablelm-tuned-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256), +] + +#################### +# EleutherAI Pythia +#################### +pythia = [ + # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json + dict(org="EleutherAI", name="pythia-70m", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128), + # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json + dict( + org="EleutherAI", name="pythia-160m", block_size=2048, n_layer=12, n_embd=768, n_head=12, padding_multiple=128 + ), + # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json + dict( + org="EleutherAI", name="pythia-410m", block_size=2048, n_layer=24, n_embd=1024, n_head=16, padding_multiple=128 + ), + # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json + 
dict(org="EleutherAI", name="pythia-1b", block_size=2048, n_embd=2048, n_head=8, padding_multiple=128), + # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json + dict( + org="EleutherAI", name="pythia-1.4b", block_size=2048, n_layer=24, n_embd=2048, n_head=16, padding_multiple=128 + ), + # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json + dict(org="EleutherAI", name="pythia-2.8b", block_size=2048, n_layer=32, n_embd=2560, padding_multiple=128), + # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json + dict(org="EleutherAI", name="pythia-6.9b", block_size=2048, n_layer=32, padding_multiple=256), + # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json + dict(org="EleutherAI", name="pythia-12b", block_size=2048, n_layer=36, n_embd=5120, n_head=40), +] +configs.extend(pythia) +for c in pythia: + copy = c.copy() + copy["name"] = f"{c['name']}-deduped" + configs.append(copy) + + +#################################### +# togethercomputer RedPajama INCITE +#################################### +redpajama_incite = [ + # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json + dict( + org="togethercomputer", + name="RedPajama-INCITE-{}-3B-v1", + block_size=2048, + n_layer=32, + n_embd=2560, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), + # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json + dict( + org="togethercomputer", + name="RedPajama-INCITE-7B-{}", + block_size=2048, + n_layer=32, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), + # this redirects to the checkpoint above. kept for those who had the old weights already downloaded + dict( + org="togethercomputer", + name="RedPajama-INCITE-{}-7B-v0.1", + block_size=2048, + n_layer=32, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), +] +for c in redpajama_incite: + for kind in ("Base", "Chat", "Instruct"): + copy = c.copy() + copy["name"] = c["name"].format(kind) + configs.append(copy) + + +################# +# TII UAE Falcon +################# +falcon = [ + # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json + dict( + org="tiiuae", + name="falcon-7b{}", + block_size=2048, + vocab_size=65024, + padded_vocab_size=65024, + n_layer=32, + n_head=71, + n_embd=4544, + rotary_percentage=1.0, + n_query_groups=1, + bias=False, + # this is not in the config, but in the original model implementation, only for this config + shared_attention_norm=True, + ), + # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json + dict( + org="tiiuae", + name="falcon-40b{}", + block_size=2048, + vocab_size=65024, + padded_vocab_size=65024, + n_layer=60, + n_head=128, + n_embd=8192, + rotary_percentage=1.0, + n_query_groups=8, + bias=False, + ), +] +for c in falcon: + for kind in ("", "-instruct"): + copy = c.copy() + copy["name"] = c["name"].format(kind) + configs.append(copy) + +# https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json +falcon180b = dict( + org="tiiuae", + name="falcon-180B{}", + block_size=2048, + vocab_size=65024, + padded_vocab_size=65024, + n_layer=80, + n_head=232, + n_embd=14848, + rotary_percentage=1.0, + n_query_groups=8, + bias=False, +) + +for kind in ("", "-chat"): + copy = falcon180b.copy() + copy["name"] = falcon180b["name"].format(kind) + configs.append(copy) + + +############################# +# OpenLM Research Open LLaMA +############################# +open_LLaMA = [ + # 
https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json + dict( + org="openlm-research", + name="open_llama_3b", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=26, + n_embd=3200, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=8640, + ), + # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json + dict( + org="openlm-research", + name="open_llama_7b", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json + dict( + org="openlm-research", + name="open_llama_13b", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), +] +configs.extend(open_LLaMA) + + +############### +# LMSYS Vicuna +############### +vicuna = [ + # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json + dict( + org="lmsys", + name="vicuna-7b-v1.3", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json + dict( + org="lmsys", + name="vicuna-13b-v1.3", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json + dict( + org="lmsys", + name="vicuna-33b-v1.3", + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=60, + n_head=52, + n_embd=6656, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=17920, + ), + # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json + dict( + org="lmsys", + name="vicuna-7b-v1.5", + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json + dict( + org="lmsys", + name="vicuna-7b-v1.5-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_condense_ratio=4, + ), + # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json + dict( + org="lmsys", + name="vicuna-13b-v1.5", + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json + dict( + org="lmsys", 
+ name="vicuna-13b-v1.5-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_condense_ratio=4, + ), +] +configs.extend(vicuna) + + +################# +# LMSYS LongChat +################# +long_chat = [ + # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json + dict( + org="lmsys", + name="longchat-7b-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_condense_ratio=8, + ), + # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json + dict( + org="lmsys", + name="longchat-13b-16k", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_condense_ratio=8, + ), +] +configs.extend(long_chat) + + +###################### +# NousResearch Hermes +###################### +nous_research = [ + # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json + dict( + org="NousResearch", + name="Nous-Hermes-llama-2-7b", + padded_vocab_size=32000, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json + dict( + org="NousResearch", + name="Nous-Hermes-13b", + block_size=2048, + vocab_size=32000, + padded_vocab_size=32001, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b + dict( + org="NousResearch", + name="Nous-Hermes-Llama2-13b", + vocab_size=32000, + padded_vocab_size=32032, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), +] +configs.extend(nous_research) + + +############### +# Meta LLaMA 2 +############### +llama_2 = [ + # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json + dict( + org="meta-llama", + name="Llama-2-7b{}-hf", + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json + dict( + org="meta-llama", + name="Llama-2-13b{}-hf", + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json + dict( + org="meta-llama", + name="Llama-2-70b{}-hf", + vocab_size=32000, + padding_multiple=64, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + 
parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), +] +for c in llama_2: + for kind in ("", "-chat"): + copy = c.copy() + copy["name"] = c["name"].format(kind) + configs.append(copy) + + +########################## +# Stability AI FreeWilly2 +########################## +freewilly_2 = [ + # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json + dict( + org="stabilityai", + name="FreeWilly2", + vocab_size=32000, + padding_multiple=64, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ) +] +configs.extend(freewilly_2) + + +################## +# Meta Code Llama +################## +code_llama = [ + # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-7b-hf", + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-13b-hf", + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-34b-hf", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=48, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=22016, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-7b-Python-hf", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-13b-Python-hf", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-34b-Python-hf", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=48, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=22016, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json + dict( + org="codellama", + name="CodeLlama-7b-Instruct-hf", + 
block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-13b-Instruct-hf", + block_size=2048, + vocab_size=32016, + padding_multiple=16, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json + dict( + org="codellama", + name="CodeLlama-34b-Instruct-hf", + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=48, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=22016, + rope_base=1000000, + ), +] +configs.extend(code_llama) + + +######################## +# garage-bAInd Platypus +######################## +platypus = [ + # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json + dict( + org="garage-bAInd", + name="Platypus-30B", + block_size=2048, + padded_vocab_size=32000, + n_layer=60, + n_head=52, + n_embd=6656, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-06, + _mlp_class="LLaMAMLP", + intermediate_size=17920, + ), + # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json + dict( + org="garage-bAInd", + name="Platypus2-7B", + padded_vocab_size=32000, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json + dict( + org="garage-bAInd", + name="Platypus2-13B", + padded_vocab_size=32000, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json + dict( + org="garage-bAInd", + name="Platypus2-70B", + padded_vocab_size=32000, + n_layer=80, + n_head=64, + n_embd=8192, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), + # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json + dict( + org="garage-bAInd", + name="Camel-Platypus2-13B", + padded_vocab_size=32000, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json + dict( + org="garage-bAInd", + name="Camel-Platypus2-70B", + padded_vocab_size=32000, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), + # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json + dict( + 
org="garage-bAInd", + name="Stable-Platypus2-13B", + padded_vocab_size=32000, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json + dict( + org="garage-bAInd", + name="Platypus2-70B-instruct", + padded_vocab_size=32000, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), +] +configs.extend(platypus) + + +########################## +# Stability AI StableCode +########################## +stablecode = [ + # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json + dict( + org="stabilityai", + name="stablecode-completion-alpha-3b", + block_size=16384, + vocab_size=49152, + n_layer=32, + n_embd=2560, + ), + # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json + dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k", vocab_size=49152, n_layer=32, n_embd=2560), + # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json + dict(org="stabilityai", name="stablecode-instruct-alpha-3b", vocab_size=49152, n_layer=32, n_embd=2560), +] +configs.extend(stablecode) + + +################################## +# togethercomputer LLaMA-2-7B-32K +################################## +together_llama2_32k = [ + # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json + dict( + org="togethercomputer", + name="LLaMA-2-7B-32K", + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_condense_ratio=8, + ) +] +configs.extend(together_llama2_32k) + + +################ +# Microsoft Phi +################ +phi = [ + # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json + dict( + org="microsoft", + name="phi-1_5", + vocab_size=50257, + padded_vocab_size=51200, + block_size=2048, + n_embd=2048, + n_layer=24, + rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64 + shared_attention_norm=True, + lm_head_bias=True, + gelu_approximate="tanh", + ) +] +configs.extend(phi) + + +############# +# Mistral AI +############# +mistral = [ + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + dict( + org="mistralai", + name="Mistral-7B-{}v0.1", + padded_vocab_size=32000, + block_size=4096, # should be 32768 but sliding window attention is not implemented + n_layer=32, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=14336, + ) +] +for c in mistral: + for kind in ("", "Instruct-"): + copy = c.copy() + copy["name"] = c["name"].format(kind) + configs.append(copy) + + +name_to_config = {config["name"]: config for config in configs} diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py new file mode 100755 index 0000000000000000000000000000000000000000..cce862be0e374364200fcf2d04d1a05953fa977f --- /dev/null +++ b/lit_gpt/lora.py @@ -0,0 +1,671 @@ +# Derived from https://github.com/microsoft/LoRA +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +r""" + Low Ranking Adaptation for LLMs scheme. + + ┌───────────────────┐ + ┆ h ┆ + └───────────────────┘ + ▲ + | + + + / \ + ┌─────────────────┐ ╭───────────────╮ Matrix initialization: + ┆ ┆ \ B / B = 0 + ┆ pretrained ┆ \ r*d / A = N(0, sigma^2) + ┆ weights ┆ ╰─────────╯ + ┆ ┆ | r | r - rank + ┆ W e R^(d*d) ┆ | ◀─────▶ | + ┆ ┆ ╭─────────╮ + └─────────────────┘ / A \ + ▲ / d*r \ + \ ╰───────────────╯ + \ ▲ + \ / + \ / + ┌───────────────────┐ + ┆ x ┆ + └───────────────────┘ + +With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d, +we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates +for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of +course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen +pretrained weights and thus fine-tune the model. + +The goal of this approach is to move weight updates into a separate matrix which is decomposed with +two matrices of a lower rank. +""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F +from typing_extensions import Self + +import lit_gpt +from lit_gpt.config import Config as BaseConfig +from lit_gpt.model import GPT as BaseModel +from lit_gpt.model import Block as BaseBlock +from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention +from lit_gpt.model import KVCache +from lit_gpt.utils import map_old_state_dict_weights + + +class LoRALayer(nn.Module): + def __init__(self, r: int, lora_alpha: int, lora_dropout: float): + """Store LoRA specific attributes in a class. + + Args: + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + """ + super().__init__() + assert r >= 0 + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + + +class LoRALinear(LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + """LoRA wrapper around linear class. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. 
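+        After calling `merge()`, the update is folded into the pretrained matrix and the effective weight becomes W + (alpha / r) * B @ A, so inference needs only a single matmul.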
+ + Args: + in_features: number of input features of the pretrained weights + out_features: number of output features of the pretrained weights + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + """ + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.zeros((r, in_features))) + self.lora_B = nn.Parameter(torch.zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset all the weights, even including pretrained ones.""" + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314 + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and not self.merged: + # Merge the weights and mark it + self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or self.merged: + return pretrained + lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return pretrained + lora + + +class LoRAQKVLinear(LoRALinear): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + n_head: int, + n_query_groups: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + enable_lora: Union[bool, Tuple[bool, bool, bool]] = False, + **kwargs, + ): + """LoRA wrapper around linear class that is used for calculation of q, k and v matrices. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. + + Args: + in_features: number of input features of the pretrained weights + out_features: number of output features of the pretrained weights + n_head: number of attention heads + n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`) + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. 
The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we + don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query` + and `value` but keep `key` without weight updates we should pass `[True, False, True]` + """ + super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + self.n_head = n_head + self.n_query_groups = n_query_groups + if isinstance(enable_lora, bool): + enable_lora = [enable_lora] * 3 + assert len(enable_lora) == 3 + self.enable_lora = enable_lora + + # Actual trainable parameters + # To better understand initialization let's imagine that we have such parameters: + # ⚬ in_features: 128 (embeddings_size) + # ⚬ out_features: 384 (3 * embedding_size) + # ⚬ r: 2 + # ⚬ enable_lora: [True, False, True] + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features))) # (4, 128) + enable_q, enable_k, enable_v = enable_lora + self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups) + # qkv_shapes will be used to split a tensor with weights correctly + qkv_shapes = ( + self.linear.in_features * enable_q, + self.kv_embd_size * enable_k, + self.kv_embd_size * enable_v, + ) + self.qkv_shapes = [s for s in qkv_shapes if s] + self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r)) # (256, 2)) + # Notes about shapes above + # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices; + # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in + # F.linear function weights are automatically transposed. In addition conv1d requires channels to + # be before seq length + # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is + # 128*2; 2 tells to have two channels per group for group convolution + + # Scaling: + # This balances the pretrained model`s knowledge and the new task-specific adaptation + # https://lightning.ai/pages/community/tutorial/lora-llm/ + # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set + # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can + # tune these values to your needs. This value can be even slightly greater than 1.0! + # https://github.com/cloneofsimo/lora + self.scaling = self.lora_alpha / self.r + + # Compute the indices + # Indices are needed to properly pad weight updates with zeros. 
If we want to fine-tune queries and values, + # but not keys, then the weights update should be: + # + # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], + # [....................................], + # [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] + # ↑ ↑ ↑ + # ________________________________________ + # | query | key | value | + # ---------------------------------------- + self.lora_ind = [] + if enable_q: + self.lora_ind.extend(range(0, self.linear.in_features)) + if enable_k: + self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size)) + if enable_v: + self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features)) + self.reset_parameters() + + def zero_pad(self, x: torch.Tensor) -> torch.Tensor: + """Properly pad weight updates with zeros. + + If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, + then the weights update should be: + + [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], + [....................................], + [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] + ↑ ↑ ↑ + ________________________________________ + | query | key | value | + ---------------------------------------- + + Args: + x: tensor with weights update that will be padded with zeros if necessary + + Returns: + A tensor with weight updates and zeros for deselected q, k or v + """ + # we need to do zero padding only if LoRA is disabled for one of QKV matrices + if all(self.enable_lora): + return x + + # Let's image that: + # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size) + # ⚬ embeddings_size: 128 + # ⚬ self.linear.out_features: 384 (3 * embeddings_size) + # ⚬ enable_lora: [True, False, True] + # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected + # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but + # only for key updates (this is where self.lora_ind comes in handy) + # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors + # for example when we want to merge/unmerge LoRA weights and pretrained weights + x = x.transpose(0, 1) + result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384) + result = result.view(-1, self.linear.out_features) # (4096, 384) + result = result.index_copy( + 1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes)) + ) # (4096, 256) + return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384) + + def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries. + + If the number of heads is equal to the number of query groups - grouped queries are disabled + (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized + query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the + input and weight matrices will be splitted in equally sized parts and applied separately (like having multiple + conv layers side by side). + + Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually, + apply each part of the weight matrix to the corresponding input's part and concatenate the result. 
+ + Args: + input: input matrix of shape (B, C, T) + weight: weight matrix of shape (C_output, rank, 1). + "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class). + + Returns: + A tensor with a shape (B, C_output, T) + + """ + if self.n_head == self.n_query_groups: + return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T) + + # Notation: + # ⚬ N: number of enabled LoRA layers (self.enable_lora) + # ⚬ C_output': embeddings size for each LoRA layer (not equal in size) + # ⚬ r: rank of all LoRA layers (equal in size) + + input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) + weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) + return torch.cat( + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T) + ) # (B, C_output, T) + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + + # Let's assume that: + # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + if self.r > 0 and any(self.enable_lora) and not self.merged: + delta_w = self.conv1d( + self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) + self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) + ).squeeze( + 0 + ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + # W = W + delta_W (merge) + self.linear.weight.data += self.zero_pad(delta_w * self.scaling) # (256, 128) after zero_pad (384, 128) + self.merged = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Do the forward pass. + + If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication. + If not, then multiply pretrained weights with input, apply LoRA on input and do summation. + + Args: + x: input tensor of shape (batch_size, context_length, embedding_size) + + Returns: + Output tensor of shape (batch_size, context_length, 3 * embedding_size) + """ + + # Let's assume that: + # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size) + # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + + # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or not any(self.enable_lora) or self.merged: + return pretrained + after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4) + # For F.conv1d: + # ⚬ input: input tensor of shape (mini-batch, in_channels, iW) + # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW) + after_B = self.conv1d( + after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) + self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) + ).transpose( + -2, -1 + ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) + return pretrained + lora + + +def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: + """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights. 
+ + Args: + model: model with LoRA layers + bias: + ``"none"``: all bias weights will be frozen, + ``"lora_only"``: only bias weight for LoRA layers will be unfrozen, + ``"all"``: all bias weights will be unfrozen. + + Raises: + NotImplementedError: if `bias` not in ["none", "lora_only", "all"] + """ + # freeze all layers except LoRA's + for n, p in model.named_parameters(): + if "lora_" not in n: + p.requires_grad = False + + # depending on the `bias` value unfreeze bias weights + if bias == "none": + return + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + for m in model.modules(): + if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_filter(key: str, value: Any) -> bool: + return "lora_" in key + + +@dataclass +class Config(BaseConfig): + """ + Args: + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + to_*: either apply LoRA to the specified weights or not + """ + + r: int = 0 + alpha: int = 1 + dropout: float = 0.0 + to_query: bool = False + to_key: bool = False + to_value: bool = False + to_projection: bool = False + to_mlp: bool = False + to_head: bool = False + + @property + def mlp_class(self) -> Type: + return getattr(lit_gpt.lora, self._mlp_class) + + +class GPT(BaseModel): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = LoRALinear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + r=(config.r if config.to_head else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0, maxlen: int = None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + T = idx.size(1) if maxlen is None else maxlen + if self.max_seq_length < T: + raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") + # import pdb; pdb.set_trace() + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + if type(idx) is tuple: + # import pdb; pdb.set_trace() + stack_before_tokens_x, motion_tokens, before_len = idx + # stack_before_tokens_x = stack_before_tokens_x.unsqueeze(0) + # motion_tokens = motion_tokens.unsqueeze(0) + # stack_before_tokens_x[0][before_len[0]: 
before_len[0] + len(motion_tokens[0])] = 1 + # import pdb; pdb.set_trace() + x = self.transformer.wte(stack_before_tokens_x) + # import pdb; pdb.set_trace() + for i in range(len(x)): + x[i][before_len[i]: before_len[i] + len(motion_tokens[i])] = motion_tokens[i] + else: + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = self.transformer.ln_f(x) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] + return self.lm_head(x) # (B, T, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, LoRALinear): + module.reset_parameters() + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = {"lm_head.weight": "lm_head.linear.weight"} + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class Block(BaseBlock): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + def __init__(self, config: Config) -> None: + # Skip the parent class __init__ altogether and replace it to avoid + # useless allocations + nn.Module.__init__(self) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = LoRAQKVLinear( + in_features=config.n_embd, + out_features=shape, + r=config.r, + lora_alpha=config.alpha, + lora_dropout=config.dropout, + enable_lora=(config.to_query, config.to_key, config.to_value), + bias=config.bias, + # for MQA/GQA support + n_head=config.n_head, + n_query_groups=config.n_query_groups, + ) + # output projection + self.proj = LoRALinear( + config.n_embd, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_projection else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + self.config = config + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "attn.weight": "attn.linear.weight", + "attn.bias": "attn.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + 
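+        # Both the up-projection (`fc`) above and the down-projection (`proj`) below are wrapped in LoRALinear;
+        # the adapters are only active when `config.to_mlp` is True, otherwise `r=0` keeps them as plain linear layers.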
self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + self.config = config + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc.weight": "fc.linear.weight", + "fc.bias": "fc.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LLaMAMLP(lit_gpt.model.LLaMAMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc_1 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.fc_2 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc_1.weight": "fc_1.linear.weight", + "fc_1.bias": "fc_1.linear.bias", + "fc_2.weight": "fc_2.linear.weight", + "fc_2.bias": "fc_2.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def merge_lora_weights(model: GPT) -> None: + """Merge LoRA weights into the full-rank weights to speed up inference.""" + for module in model.modules(): + if isinstance(module, LoRALinear): + module.merge() diff --git a/lit_gpt/model.py b/lit_gpt/model.py new file mode 100755 index 0000000000000000000000000000000000000000..bc9a4e58352ea1f528b3f8e296de747da5280217 --- /dev/null +++ b/lit_gpt/model.py @@ -0,0 +1,355 @@ +"""Full definition of a GPT NeoX Language Model, all of it in this single file. + +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and +https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. 
+""" +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Self + +from lit_gpt.config import Config + + +class GPT(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + @property + def max_seq_length(self) -> int: + return self._max_seq_length + + @max_seq_length.setter + def max_seq_length(self, value: int) -> None: + """ + When doing inference, the sequences used might be shorter than the model's context length. + This allows setting a smaller number to avoid allocating unused memory + """ + if value > self.config.block_size: + raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}") + self._max_seq_length = value + if not hasattr(self, "cos"): + # first call + cos, sin = self.rope_cache() + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + elif value != self.cos.size(0): + # override + self.cos, self.sin = self.rope_cache(device=self.cos.device) + # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know + # if the kv cache is expected + + def reset_parameters(self) -> None: + # Trigger resetting the rope-cache + self.max_seq_length = self.config.block_size + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`.""" + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, maxlen: int = None) -> torch.Tensor: + T = idx.size(1) if maxlen is None else maxlen + # print(T, end=', ') + if self.max_seq_length < T: + raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") + + # import pdb; pdb.set_trace() + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + if type(idx) is tuple: + stack_before_tokens_x, motion_tokens, before_len = idx + # stack_before_tokens_x = stack_before_tokens_x.unsqueeze(0) + # motion_tokens = motion_tokens.unsqueeze(0) + # stack_before_tokens_x[0][before_len[0]: before_len[0] + len(motion_tokens[0])] = 1 + x = self.transformer.wte(stack_before_tokens_x.cuda()) + # import pdb; pdb.set_trace() + for i in range(len(x)): + x[i][before_len[i]: before_len[i] + len(motion_tokens[i])] = motion_tokens[i].cuda() + else: + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = 
self.transformer.ln_f(x) + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return build_rope_cache( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + ) + + def set_kv_cache( + self, + batch_size: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + if rope_cache_length is None: + rope_cache_length = self.cos.size(-1) + max_seq_length = self.max_seq_length + + # initialize the kv cache for all blocks + for block in self.transformer.h: + block.attn.kv_cache = block.attn.build_kv_cache( + batch_size, max_seq_length, rope_cache_length, device, dtype + ) + + if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: + # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 + ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) + self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0) + + def clear_kv_cache(self) -> None: + self.mask_cache = None + for block in self.transformer.h: + block.attn.kv_cache = None + + +class Block(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + n_1 = self.norm_1(x) + h = self.attn(n_1, cos, sin, mask, input_pos) + if self.config.parallel_residual: + n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) + x = self.mlp(n_2) + h + x + else: + if self.config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." 
+ ) + x = h + x + x = self.mlp(self.norm_2(x)) + x + return x + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.attn(x) + + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) + qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # maybe repeat k and v if for the non multi-head attention cases + # training: flash attention requires it + # inference: multi-query would require a full kv cache so avoid it to limit its memory usage + if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): + k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + + if input_pos is not None: + if not isinstance(self.kv_cache, KVCache): + raise TypeError("You need to call `gpt.set_kv_cache()`") + k, v = self.kv_cache(input_pos, k, v) + + y = self.scaled_dot_product_attention(q, k, v, mask) + + y = y.reshape(B, T, C) # re-assemble all head outputs side by side + + # output projection + return self.proj(y) + + def scaled_dot_product_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + scale = 1.0 / math.sqrt(self.config.head_size) + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, + # scale=scale, + is_causal=mask is None + ) + return y.transpose(1, 2) + + def build_kv_cache( + self, + batch_size: int, + max_seq_length: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> "KVCache": + heads = 1 if self.config.n_query_groups == 1 else self.config.n_head + v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + if rope_cache_length is None: + if self.config.rotary_percentage != 1.0: + raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` 
value") + k_shape = v_shape + else: + k_shape = ( + batch_size, + heads, + max_seq_length, + rope_cache_length + self.config.head_size - self.config.rope_n_elem, + ) + return KVCache(k_shape, v_shape, device=device, dtype=dtype) + + +class GptNeoxMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate) + return self.proj(x) + + +class LLaMAMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + return self.proj(x) + + +def build_rope_cache( + seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + + return torch.cos(idx_theta), torch.sin(idx_theta) + + +def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + roped = (x * cos) + (rotated * sin) + return roped.type_as(x) + + +class KVCache(nn.Module): + def __init__( + self, + k_shape: Tuple[int, int, int, int], + v_shape: Tuple[int, int, int, int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) + self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) + + def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(k.dtype) + self.v = self.v.to(v.dtype) + # update the cache + k = self.k.index_copy_(2, input_pos, k) + v = self.v.index_copy_(2, input_pos, v) + return k, v diff --git a/lit_gpt/packed_dataset.py b/lit_gpt/packed_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..12f85f714d05be3a92c98cae90ee9b8fde2608b4 --- /dev/null +++ b/lit_gpt/packed_dataset.py @@ -0,0 +1,235 
@@ +# Very loosely inspired by indexed_dataset in Fairseq, Megatron +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py + + +import os +import random +import struct + +import numpy as np +import torch +from torch.utils.data import IterableDataset, get_worker_info + +dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16} + + +def code(dtype): + for k in dtypes: + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +HDR_MAGIC = b"LITPKDS" +HDR_SIZE = 24 # bytes + + +class PackedDataset(IterableDataset): + def __init__( + self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0 + ): + self._filenames = filenames + self._n_chunks = n_chunks + self._block_size = block_size + self._seed = seed + self._shuffle = shuffle + self._wrap = wrap + self._num_processes = num_processes + self._process_rank = process_rank + + def __iter__(self): + worker_info = get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + worker_id = worker_info.id if worker_info is not None else 0 + num_shards = num_workers * self._num_processes + shard_id = self._process_rank * num_workers + worker_id + + max_num_files = len(self._filenames) // num_shards * num_shards + filenames = self._filenames[shard_id:max_num_files:num_shards] + + return PackedDatasetIterator( + filenames=filenames, + n_chunks=self._n_chunks, + block_size=self._block_size, + seed=self._seed, + shuffle=self._shuffle, + wrap=self._wrap, + ) + + +class PackedDatasetBuilder(object): + def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None): + if dtype == "auto": + if vocab_size is None: + raise ValueError("vocab_size cannot be None when dtype='auto'") + if vocab_size is not None and vocab_size < 65500: + self._dtype = np.uint16 + else: + self._dtype = np.int32 + else: + self._dtype = dtype + self._counter = 0 + self._chunk_size = chunk_size + self._outdir = outdir + self._prefix = prefix + self._sep_token = sep_token + self._arr = np.zeros(self._chunk_size, dtype=self._dtype) + self._arr.fill(self._sep_token) + self._idx = 0 + self._version = 1 + self._filenames = [] + + def _write_chunk(self): + filename = f"{self._prefix}_{self._counter:010d}.bin" + filename = os.path.join(self._outdir, filename) + + with open(filename, "wb") as f: + f.write(HDR_MAGIC) + f.write(struct.pack(" self._chunk_size: + part_len = self._chunk_size - self._idx + self._arr[self._idx : self._idx + part_len] = arr[:part_len] + self._write_chunk() + arr = arr[part_len:] + + arr_len = arr.shape[0] + self._arr[self._idx : self._idx + arr_len] = arr + self._idx += arr_len + + def write_reminder(self): + self._write_chunk() + + +class PackedDatasetIterator: + def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): + self._seed = seed + self._shuffle = shuffle + self._rng = np.random.default_rng(seed) if shuffle else None + self._block_idxs = None + + self._wrap = wrap + + # TODO: instead of filenames, we could have a single text stream + # (or text file) with the sequence of all files to be + # fetched/loaded. 
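`PackedDataset.__iter__` above hands whole chunk files out to `num_workers * num_processes` shards so that every DDP rank and every dataloader worker reads a disjoint subset of files. A minimal sketch of that arithmetic, using hypothetical file names and rank/worker values:

```python
# Illustrative only: mirrors the shard selection done in PackedDataset.__iter__.
filenames = [f"train_{i:010d}.bin" for i in range(10)]  # hypothetical chunk files
num_processes, process_rank = 2, 0   # e.g. DDP world size and this process' rank
num_workers, worker_id = 2, 1        # DataLoader workers per process

num_shards = num_workers * num_processes
shard_id = process_rank * num_workers + worker_id
max_num_files = len(filenames) // num_shards * num_shards  # drop the remainder
shard_files = filenames[shard_id:max_num_files:num_shards]
print(shard_files)  # ['train_0000000001.bin', 'train_0000000005.bin']
```

Files beyond `max_num_files` are simply skipped, which keeps every shard the same length at the cost of ignoring a few trailing chunks.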
+ self._filenames = filenames + self._file_idx = 0 + + self._n_chunks = n_chunks + + self._dtype = None + self._block_size = block_size + self._n_blocks = None + + self._mmaps = [] + self._buffers = [] + + self._block_idxs = [] + self._curr_idx = 0 + + self._load_n_chunks() + + def _read_header(self, path): + with open(path, "rb") as f: + magic = f.read(len(HDR_MAGIC)) + assert magic == HDR_MAGIC, "File doesn't match expected format." + version = struct.unpack(" len(self._filenames[self._file_idx :]): + if not self._wrap: + raise StopIteration + self._file_idx = 0 + + for i in range(self._n_chunks): + filename = self._filenames[self._file_idx + i] + if self._dtype is None: + self._dtype, self._chunk_size = self._read_header(filename) + self._n_blocks = self._chunk_size // self._block_size + # TODO: check header matches with previous files + mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) + self._mmaps.append(mmap) + self._buffers.append(memoryview(mmap)) + + self._file_idx += self._n_chunks + n_all_blocks = self._n_chunks * self._n_blocks + + self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) + + self._curr_idx = 0 + + def __del__(self): + self._close_mmaps() + del self._mmaps + del self._buffers + + def __iter__(self): + return self + + def __next__(self): + if self._curr_idx >= len(self._block_idxs): + self._load_n_chunks() + # TODO: trigger fetching next next n_chunks if remote + block_idx = self._block_idxs[self._curr_idx] + chunk_id = block_idx // self._n_blocks + buffer = self._buffers[chunk_id] + elem_id = (block_idx % self._n_blocks) * self._block_size + offset = np.dtype(self._dtype).itemsize * elem_id + arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) + self._curr_idx += 1 + return torch.from_numpy(arr.astype(np.int64)) + + +class CombinedDataset(IterableDataset): + def __init__(self, datasets, seed, weights=None): + self._seed = seed + self._datasets = datasets + self._weights = weights + n_datasets = len(datasets) + if weights is None: + self._weights = [1 / n_datasets] * n_datasets + + def __iter__(self): + return CombinedDatasetIterator(self._datasets, self._seed, self._weights) + + +class CombinedDatasetIterator: + def __init__(self, datasets, seed, weights): + self._datasets = [iter(el) for el in datasets] + self._weights = weights + self._rng = random.Random(seed) + + def __next__(self): + (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) + return next(dataset) diff --git a/lit_gpt/rmsnorm.py b/lit_gpt/rmsnorm.py new file mode 100755 index 0000000000000000000000000000000000000000..e2580fcbe9c6b713ad19eee5a346981f0039dcd8 --- /dev/null +++ b/lit_gpt/rmsnorm.py @@ -0,0 +1,26 @@ +import torch + + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. 
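`CombinedDataset` above mixes several iterable datasets by drawing each sample from one of them with `random.Random(seed).choices`, so `weights` sets the sampling ratio and defaults to uniform. A small sketch with stand-in datasets instead of real `PackedDataset` instances:

```python
from torch.utils.data import IterableDataset

from lit_gpt.packed_dataset import CombinedDataset


class Repeat(IterableDataset):
    """Stand-in dataset that yields its tag forever (illustrative only)."""

    def __init__(self, tag):
        self.tag = tag

    def __iter__(self):
        while True:
            yield self.tag


combined = CombinedDataset([Repeat("web"), Repeat("code")], seed=42, weights=[0.8, 0.2])
it = iter(combined)
samples = [next(it) for _ in range(10)]
print(samples)  # roughly 80% "web", 20% "code"
```

In practice the wrapped elements are `PackedDataset` instances, but the mixing logic is unchanged.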
+ """ + + def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + x = x.float() + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return (self.weight * x_normed).to(dtype=dtype) + + def reset_parameters(self) -> None: + torch.nn.init.ones_(self.weight) diff --git a/lit_gpt/speed_monitor.py b/lit_gpt/speed_monitor.py new file mode 100755 index 0000000000000000000000000000000000000000..4d4b1b0001be07ca08842dacae5686bb5dd706e9 --- /dev/null +++ b/lit_gpt/speed_monitor.py @@ -0,0 +1,425 @@ +import time +from collections import deque +from contextlib import nullcontext +from typing import Any, Callable, Deque, Dict, Optional + +import torch +from lightning import Callback, Fabric, LightningModule, Trainer +from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 +from lightning.fabric.plugins import ( + BitsandbytesPrecision, + DoublePrecision, + FSDPPrecision, + HalfPrecision, + MixedPrecision, + Precision, + TransformerEnginePrecision, + XLAPrecision, +) +from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only +from lightning.pytorch.plugins import ( + DoublePrecisionPlugin, + FSDPPrecisionPlugin, + HalfPrecisionPlugin, + MixedPrecisionPlugin, + XLAPrecisionPlugin, +) +from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only +from torch.utils.flop_counter import FlopCounterMode + +from lit_gpt import GPT +from lit_gpt.utils import num_parameters + +GPU_AVAILABLE_FLOPS = { + # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet + # nvidia publishes spec sheet with a 2x sparsity factor + "h100-sxm": { + torch.float64: 67e12, + torch.float32: 67e12, + torch.bfloat16: 1.979e15 / 2, + torch.float16: 1.979e15 / 2, + torch.int8: 3.958e15 / 2, + }, + "h100-pcie": { + torch.float64: 51e12, + torch.float32: 51e12, + torch.bfloat16: 1.513e15 / 2, + torch.float16: 1.513e15 / 2, + torch.int8: 3.026e15 / 2, + }, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + # sxm and pcie have same flop counts + "a100": {torch.float64: 19.5e12, torch.float32: 19.5e12, torch.bfloat16: 312e12, torch.float16: 312e12}, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf + "a10g": {torch.float32: 31.2e12, torch.bfloat16: 125e12, torch.float16: 125e12}, + # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf + "v100-sxm": {torch.float64: 7.8e12, torch.float32: 15.7e12, torch.float16: 125e12}, + "v100-pcie": {torch.float64: 7e12, torch.float32: 14e12, torch.float16: 112e12}, + "v100s-pcie": {torch.float64: 8.2e12, torch.float32: 16.4e12, torch.float16: 130e12}, + # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf + # sxm and pcie have same flop counts + "t4": {torch.float32: 8.1e12, torch.float16: 65e12, torch.int8: 130e12}, + # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf + "quadro rtx 5000": {torch.float32: 11.2e12, 
torch.float16: 89.2e12}, +} + +TPU_AVAILABLE_FLOPS = { + # flop count for each TPU generation is the same for all precisions + # since bfloat16 precision is always used for performing matrix operations + # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16 + # source: https://arxiv.org/pdf/1907.10701.pdf + "v2": 45e12, + # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3 + "v3": 123e12, + # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4 + "v4": 275e12, + # source: https://cloud.google.com/tpu/docs/v5e-training + "v5litepod": 197e12, +} + + +def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]: + if device.type == "cuda": + device_name = torch.cuda.get_device_name(device).lower() + if "h100" in device_name and "hbm3" in device_name: + device_name = "h100-sxm" + elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name): + device_name = "h100-pcie" + elif "a100" in device_name: + device_name = "a100" + elif "a10g" in device_name: + device_name = "a10g" + elif "v100-sxm" in device_name: + device_name = "v100-sxm" + elif "v100-pcie" in device_name: + device_name = "v100-pcie" + elif "t4" in device_name: + device_name = "t4" + elif "quadro rtx 5000" in device_name: + device_name = "quadro rtx 5000" + else: + device_name = None + + if device_name is not None: + try: + return int(GPU_AVAILABLE_FLOPS[device_name][dtype]) + except KeyError: + raise KeyError( + f"flop count not found for {device_name} with dtype: {dtype}; " + "MFU cannot be calculated and reported." + ) + elif device.type == "xla": + if _XLA_GREATER_EQUAL_2_1: + from torch_xla._internal import tpu + else: + from torch_xla.experimental import tpu + + device_name = tpu.get_tpu_env()["TYPE"].lower() + try: + return int(TPU_AVAILABLE_FLOPS[device_name]) + except KeyError: + raise KeyError( + f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported." + ) + + return None + + +# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py + + +class SpeedMonitorBase: + """Logs the training throughput and utilization. + + +-------------------------------------+-----------------------------------------------------------+ + | Key | Logged data | + +=====================================+===========================================================+ + | | Rolling average (over `window_size` most recent | + | `throughput/batches_per_sec` | batches) of the number of batches processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/samples_per_sec` | batches) of the number of samples processed per second | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | Rolling average (over `window_size` most recent | + | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. 
| + | | This may include padding depending on dataset | + +-------------------------------------+-----------------------------------------------------------+ + | | Estimates flops by `flops_per_batch * batches_per_sec` | + | `throughput/flops_per_sec` | | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size | + +-------------------------------------+-----------------------------------------------------------+ + | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/tokens_per_sec` divided by world size. This | + | `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/flops_per_sec` divided by world size. Only | + | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | | `throughput/device/flops_per_sec` divided by world size. | + | `throughput/device/mfu` | | + | | | + +-------------------------------------+-----------------------------------------------------------+ + | `time/train` | Total elapsed training time | + +-------------------------------------+-----------------------------------------------------------+ + | `time/val` | Total elapsed validation time | + +-------------------------------------+-----------------------------------------------------------+ + | `time/total` | Total elapsed time (time/train + time/val) | + +-------------------------------------+-----------------------------------------------------------+ + + Notes: + - The implementation assumes that devices are homogeneous as it normalizes by the world size. + - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or + batches/sec to measure throughput under this circumstance. + - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``. + There is no widespread, realistic, and reliable implementation to compute them. + We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which + will almost always be an overestimate when compared to the true value. + + Args: + window_size (int, optional): Number of batches to use for a rolling average of throughput. + Defaults to 100. + time_unit (str, optional): Time unit to use for `time` logging. Can be one of + 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'. 
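Concretely, every `throughput/*` key in the table above is a rate over the rolling window, and `throughput/device/mfu` divides the achieved per-device FLOP rate by the peak from `GPU_AVAILABLE_FLOPS` / `TPU_AVAILABLE_FLOPS`. A back-of-the-envelope version of that arithmetic with made-up numbers (A100 bf16 peak of 312e12 FLOPs, and the rolling window simplified to a plain average):

```python
world_size = 8
elapsed_batches = 100                   # batches in the rolling window (per device)
elapsed_samples = 100 * 8               # 8 samples per device-batch
elapsed_wct = 500.0                     # wall-clock seconds spent on the window
flops_per_batch = 8e14                  # per device, e.g. from measure_flops()

samples_per_sec = elapsed_samples * world_size / elapsed_wct             # 12.8
device_flops_per_sec = flops_per_batch * elapsed_batches / elapsed_wct   # 1.6e14
mfu = device_flops_per_sec / 312e12                                      # ~0.51
print(f"{samples_per_sec:.1f} samples/s, MFU {mfu:.1%}")
```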
+ """ + + def __init__( + self, + flops_available: float, + log_dict: Callable[[Dict, int], None], + window_size: int = 100, + time_unit: str = "hours", + ): + self.flops_available = flops_available + self.log_dict = log_dict + + # Track the batch num samples and wct to compute throughput over a window of batches + self.history_samples: Deque[int] = deque(maxlen=window_size + 1) + self.history_wct: Deque[float] = deque(maxlen=window_size + 1) + self.history_lengths: Deque[int] = deque(maxlen=window_size + 1) + self.history_flops: Deque[int] = deque(maxlen=window_size + 1) + + self.divider = 1 + if time_unit == "seconds": + self.divider = 1 + elif time_unit == "minutes": + self.divider = 60 + elif time_unit == "hours": + self.divider = 60 * 60 + elif time_unit == "days": + self.divider = 60 * 60 * 24 + else: + raise ValueError( + f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".' + ) + + # Keep track of time spent evaluating + self.total_eval_wct = 0.0 + self.step = -1 + + def on_train_batch_end( + self, + samples: int, # total samples seen (per device) + train_elapsed: float, # total training time (seconds) + world_size: int, + flops_per_batch: Optional[int] = None, # (per device) + lengths: Optional[int] = None, # total length of the samples seen (per device) + ) -> None: + self.step += 1 + step = self.step + metrics = {} + + self.history_samples.append(samples) + if lengths is not None: + self.history_lengths.append(lengths) + # if lengths are passed, there should be as many values as samples + assert len(self.history_samples) == len(self.history_lengths) + self.history_wct.append(train_elapsed) + if len(self.history_wct) == self.history_wct.maxlen: + elapsed_batches = len(self.history_samples) - 1 + elapsed_samples = self.history_samples[-1] - self.history_samples[0] + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + samples_per_sec = elapsed_samples * world_size / elapsed_wct + dev_samples_per_sec = elapsed_samples / elapsed_wct + metrics.update( + { + "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct, + "throughput/samples_per_sec": samples_per_sec, + "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct, + "throughput/device/samples_per_sec": dev_samples_per_sec, + } + ) + if lengths is not None: + elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0]) + avg_length = elapsed_lengths / elapsed_batches + metrics.update( + { + "throughput/tokens_per_sec": samples_per_sec * avg_length, + "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length, + } + ) + + if flops_per_batch is not None: + # sum of flops per batch across ranks + self.history_flops.append(flops_per_batch * world_size) + if len(self.history_flops) == self.history_flops.maxlen: + elapsed_flops = sum(self.history_flops) - self.history_flops[0] + elapsed_wct = self.history_wct[-1] - self.history_wct[0] + flops_per_sec = elapsed_flops / elapsed_wct + device_flops_per_sec = flops_per_sec / world_size + metrics.update( + {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec} + ) + if self.flops_available: + metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available + + metrics.update( + { + "time/train": train_elapsed / self.divider, + "time/val": self.total_eval_wct / self.divider, + "time/total": (train_elapsed + self.total_eval_wct) / self.divider, + "samples": samples, + } + ) + + self.log_dict(metrics, step) + + def eval_end(self, 
eval_elapsed: float) -> None: + self.total_eval_wct += eval_elapsed # seconds + + +def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype: + if isinstance(plugin, BitsandbytesPrecision): + return plugin.dtype + if isinstance(plugin, (HalfPrecision, MixedPrecision, HalfPrecisionPlugin)): + return plugin._desired_input_dtype + if isinstance(plugin, MixedPrecisionPlugin): + return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half + if isinstance(plugin, (DoublePrecision, DoublePrecisionPlugin)): + return torch.double + if isinstance(plugin, (XLAPrecision, XLAPrecisionPlugin)): + return plugin._desired_dtype + if isinstance(plugin, TransformerEnginePrecision): + return torch.int8 + if isinstance(plugin, (FSDPPrecision, FSDPPrecisionPlugin)): + return plugin.mixed_precision_config.reduce_dtype + if isinstance(plugin, Precision): + return torch.float32 + raise NotImplementedError(plugin) + + +class SpeedMonitorFabric(SpeedMonitorBase): + def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None: + dtype = plugin_to_compute_dtype(fabric.strategy.precision) + flops_available = get_flops_available(fabric.device, dtype) + super().__init__(flops_available, fabric.log_dict, *args, **kwargs) + + @fabric_rank_zero_only + def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None: + super().on_train_batch_end(*args, **kwargs) + + +class SpeedMonitorCallback(Callback): + def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None: + super().__init__() + self.speed_monitor: Optional[SpeedMonitorBase] = None + self.speed_monitor_kwargs = kwargs + self.length_fn = length_fn + self.batch_size = batch_size + self.eval_t0: int = 0 + self.train_t0: int = 0 + self.total_lengths: int = 0 + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + if self.speed_monitor is not None: + return # already setup + dtype = plugin_to_compute_dtype(trainer.precision_plugin) + flops_available = get_flops_available(trainer.strategy.root_device, dtype) + self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs) + + @trainer_rank_zero_only + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + if trainer.fit_loop._should_accumulate(): + return + + self.train_t0 = time.perf_counter() + + @trainer_rank_zero_only + def on_train_batch_end( + self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int + ) -> None: + self.total_lengths += self.length_fn(batch) + if trainer.fit_loop._should_accumulate(): + return + train_elapsed = time.perf_counter() - self.train_t0 + assert self.speed_monitor is not None + iter_num = trainer.fit_loop.total_batch_idx + assert (measured_flops := pl_module.measured_flops) is not None + self.speed_monitor.on_train_batch_end( + (iter_num + 1) * self.batch_size, + train_elapsed, + # this assumes that device FLOPs are the same and that all devices have the same batch size + trainer.world_size, + flops_per_batch=measured_flops, + lengths=self.total_lengths, + ) + + @trainer_rank_zero_only + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self.eval_t0 = time.perf_counter() + + @trainer_rank_zero_only + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + eval_elapsed = time.perf_counter() - self.eval_t0 + assert self.speed_monitor is not None + self.speed_monitor.eval_end(eval_elapsed) + + +def flops_per_param(max_seq_length: 
int, n_layer: int, n_embd: int, n_params: int) -> int: + flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation + # this assumes that all samples have a fixed length equal to the block size + # which is most likely false during finetuning + flops_per_seq = flops_per_token * max_seq_length + attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2)) + return flops_per_seq + attn_flops_per_seq + + +def estimate_flops(model: GPT) -> int: + """Measures estimated FLOPs for MFU. + + Refs: + * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 + * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 + """ + # using all parameters for this is a naive over estimation because not all model parameters actually contribute to + # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage + # (~10%) compared to the measured FLOPs, making those lower but more realistic. + # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper. + n_trainable_params = num_parameters(model, requires_grad=True) + trainable_flops = flops_per_param( + model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params + ) + # forward + backward + gradients (assumes no gradient accumulation) + ops_per_step = 3 if model.training else 1 + n_frozen_params = num_parameters(model, requires_grad=False) + frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params) + # forward + backward + frozen_ops_per_step = 2 if model.training else 1 + return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops + + +def measure_flops(model: GPT, x: torch.Tensor) -> int: + """Measures real FLOPs for HFU""" + flop_counter = FlopCounterMode(model, display=False) + ctx = nullcontext() if model.training else torch.no_grad() + with ctx, flop_counter: + y = model(x) + if model.training: + y.sum().backward() + return flop_counter.get_total_flops() diff --git a/lit_gpt/tokenizer.py b/lit_gpt/tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..1907ca796bbd2fec2b53724975e8c836513f2c90 --- /dev/null +++ b/lit_gpt/tokenizer.py @@ -0,0 +1,103 @@ +import json +from pathlib import Path +from typing import Optional + +import torch + + +class Tokenizer: + def __init__(self, checkpoint_dir: Path) -> None: + self.use_bos = self.check_if_bos_token_used(checkpoint_dir) + self.bos_id = None + self.eos_id = None + + # some checkpoints have both files, `.model` takes precedence + if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): + from sentencepiece import SentencePieceProcessor + + self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) + self.backend = "sentencepiece" + self.bos_id = self.processor.bos_id() + self.eos_id = self.processor.eos_id() + + elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): + from tokenizers import Tokenizer as HFTokenizer + + self.processor = HFTokenizer.from_file(str(vocabulary_path)) + self.backend = "huggingface" + + if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file(): + with open(special_tokens_path) as fp: + config = json.load(fp) + bos_token = config.get("bos_token") + self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None + eos_token = config.get("eos_token") + self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None + if (special_tokens_path := checkpoint_dir / 
"generation_config.json").is_file(): + with open(special_tokens_path) as fp: + config = json.load(fp) + if self.bos_id is None: + self.bos_id = config.get("bos_token_id") + if self.eos_id is None: + self.eos_id = config.get("eos_token_id") + else: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + if self.backend == "huggingface": + return self.processor.get_vocab_size(with_added_tokens=False) + if self.backend == "sentencepiece": + return self.processor.vocab_size() + raise RuntimeError + + def token_to_id(self, token: str) -> int: + if self.backend == "huggingface": + id_ = self.processor.token_to_id(token) + elif self.backend == "sentencepiece": + id_ = self.processor.piece_to_id(token) + else: + raise RuntimeError + if id_ is None: + raise ValueError(f"token {token!r} not found in the collection.") + return id_ + + def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: + if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file(): + return False + with open(tokenizer_config_path) as fp: + config = json.load(fp) + if any(config.get(check, False) for check in ("add_bos_token", "add_prefix_space")): + return True + # for examples that also use the Llama tokenizer, but do not have or set add_bos_token to True. + # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2 + return config.get("add_bos_token") is None and config.get("tokenizer_class") == "LlamaTokenizer" + + def encode( + self, + string: str, + device: Optional[torch.device] = None, + bos: Optional[bool] = None, + eos: bool = False, + max_length: int = -1, + ) -> torch.Tensor: + if self.backend == "huggingface": + tokens = self.processor.encode(string).ids + elif self.backend == "sentencepiece": + tokens = self.processor.encode(string) + else: + raise RuntimeError + if bos or (bos is None and self.use_bos): + bos_id = self.bos_id + if bos_id is None: + raise NotImplementedError("This tokenizer does not have a defined a bos token") + tokens = [bos_id] + tokens + if eos: + tokens = tokens + [self.eos_id] + if max_length > 0: + tokens = tokens[:max_length] + return torch.tensor(tokens, dtype=torch.int, device=device) + + def decode(self, tensor: torch.Tensor) -> str: + tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() + return self.processor.decode(tokens) diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..cfe0bc26ed9b90f1153f8363f33624530973570e --- /dev/null +++ b/lit_gpt/utils.py @@ -0,0 +1,311 @@ +"""Utility functions for training and inference.""" +import math +import pickle +import sys +from contextlib import nullcontext +from io import BytesIO +from pathlib import Path +from typing import ContextManager, Dict, List, Mapping, Optional, TypeVar, Union + +import lightning as L +import torch +import torch.nn as nn +import torch.utils._device +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities.load import _lazy_load as lazy_load +from torch.serialization import normalize_storage_type + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n + k - (n % k) + + +def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: + total = 0 + for p in module.parameters(): + if requires_grad is None or p.requires_grad == requires_grad: + if hasattr(p, "quant_state"): + # bitsandbytes 4bit layer support + total += math.prod(p.quant_state[1]) + else: + total += p.numel() 
+ return total + + +def gptq_quantization(enabled: bool = False) -> ContextManager: + if not enabled: + return nullcontext() + + from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager + + from quantize.gptq import ColBlockQuantizedLinear + + class QuantizedLinear(ColBlockQuantizedLinear): + def __init__(self, *args, **kwargs): + super().__init__(*args, bits=4, tile_cols=-1, **kwargs) + + return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear}) + + +def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: + files = { + "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), + "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), + "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( + checkpoint_dir / "tokenizer.model" + ).is_file(), + "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), + } + if checkpoint_dir.is_dir(): + if all(files.values()): + # we're good + return + problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" + else: + problem = " is not a checkpoint directory" + + # list locally available checkpoints + available = list(Path("checkpoints").glob("*/*")) + if available: + options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available]) + extra = f"\nYou have downloaded locally:{options}\n" + else: + extra = "" + + error_message = ( + f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." + "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n" + f"{extra}\nSee all download options by running:\n python scripts/download.py" + ) + print(error_message, file=sys.stderr) + raise SystemExit(1) + + +class SavingProxyForStorage: + def __init__(self, obj, saver, protocol_version=5): + self.protocol_version = protocol_version + self.saver = saver + if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): + raise TypeError(f"expected storage, not {type(obj)}") + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + if isinstance(obj, torch.storage.TypedStorage): + # PT upstream wants to deprecate this eventually... 
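`num_parameters` above counts parameters, optionally filtered by `requires_grad`, and recovers the original shape of bitsandbytes 4-bit layers from `quant_state`. A tiny sketch of the trainable/frozen split, which is the same split `estimate_flops` in `speed_monitor.py` relies on:

```python
import torch.nn as nn

from lit_gpt.utils import num_parameters

layer = nn.Linear(16, 4)               # stand-in module: 64 weight + 4 bias params
layer.weight.requires_grad = False     # pretend the weight is frozen

print(num_parameters(layer, requires_grad=True))   # 4  (bias only)
print(num_parameters(layer, requires_grad=False))  # 64 (frozen weight)
print(num_parameters(layer))                       # 68 (everything)
```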
+ storage = obj._untyped_storage + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + else: + storage = obj + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + storage_key = saver._write_storage_and_return_key(storage) + location = torch.serialization.location_tag(storage) + + self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) + + def __reduce_ex__(self, protocol_version): + assert False, "this should be handled with out of band" + + +class SavingProxyForTensor: + def __init__(self, tensor, saver, protocol_version=5): + self.protocol_version = protocol_version + self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version) + if reduce_args[0] == torch._utils._rebuild_tensor_v2: + # for Tensors with Python attributes + (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args + assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" + storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) + self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args) + else: + (storage, *other_reduce_args) = reduce_args + assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" + storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) + self.reduce_args = (storage_proxy, *other_reduce_args) + + def __reduce_ex__(self, protocol_version): + if protocol_version != self.protocol_version: + raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") + return self.reduce_ret_fn, self.reduce_args + + +class IncrementalPyTorchPickler(pickle.Pickler): + def __init__(self, saver, *args, **kwargs): + super().__init__(*args, **kwargs) + self.storage_dtypes = {} + self.saver = saver + self.id_map = {} + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + def persistent_id(self, obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, SavingProxyForStorage): + return obj.storage_info + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. 
If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in self.storage_dtypes: + if storage_dtype != self.storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that view the same data as different types" + ) + else: + self.storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = self.id_map.get(storage._cdata) + if storage_key is None: + storage_key = self.saver._write_storage_and_return_key(storage) + self.id_map[storage._cdata] = storage_key + location = torch.serialization.location_tag(storage) + + return ("storage", storage_type, storage_key, location, storage_numel) + + return None + + +class incremental_save: + def __init__(self, name): + self.name = name + self.zipfile = torch._C.PyTorchFileWriter(str(name)) + self.has_saved = False + self.next_key = 0 + + def __enter__(self): + return self + + def store_early(self, tensor): + if isinstance(tensor, torch.Tensor): + return SavingProxyForTensor(tensor, self) + raise TypeError(f"can only store tensors early, not {type(tensor)}") + + def save(self, obj): + if self.has_saved: + raise RuntimeError("have already saved") + # Write the pickle data for `obj` + data_buf = BytesIO() + pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) + pickler.dump(obj) + data_value = data_buf.getvalue() + self.zipfile.write_record("data.pkl", data_value, len(data_value)) + self.has_saved = True + + def _write_storage_and_return_key(self, storage): + if self.has_saved: + raise RuntimeError("have already saved") + key = self.next_key + self.next_key += 1 + name = f"data/{key}" + if storage.device.type != "cpu": + storage = storage.cpu() + num_bytes = storage.nbytes() + self.zipfile.write_record(name, storage.data_ptr(), num_bytes) + return key + + def __exit__(self, type, value, traceback): + self.zipfile.write_end_of_file() + + +T = TypeVar("T") + + +def chunked_cross_entropy( + logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128 +) -> torch.Tensor: + # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate + # the memory usage in fine-tuning settings with low number of parameters. 
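`incremental_save`, together with the two pickling proxies above, writes a checkpoint tensor-by-tensor instead of materializing the whole state dict in memory: `store_early` streams a tensor's storage into the zip file immediately and leaves a lightweight proxy behind, and the final `save` call pickles only the remaining small objects. A hedged usage sketch (the output path is arbitrary):

```python
import torch

from lit_gpt.utils import incremental_save

state_dict = {"weight": torch.randn(4, 4), "step": 7}

with incremental_save("converted.pth") as saver:       # arbitrary output file
    # Stream the tensor's bytes now; only a small proxy stays in the dict.
    state_dict["weight"] = saver.store_early(state_dict["weight"])
    saver.save(state_dict)                              # pickle the remainder

reloaded = torch.load("converted.pth")
print(reloaded["weight"].shape, reloaded["step"])
```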
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing + # the memory spike's magnitude + + # lm_head was chunked (we are fine-tuning) + if isinstance(logits, list): + # don't want to chunk cross entropy + if chunk_size == 0: + logits = torch.cat(logits, dim=1) + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) + + # chunk cross entropy + logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] + target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] + loss_chunks = [ + torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + return torch.cat(loss_chunks).mean() + + # no chunking at all + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + if chunk_size == 0: + return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) + + # lm_head wasn't chunked, chunk cross entropy + logit_chunks = logits.split(chunk_size) + target_chunks = targets.split(chunk_size) + loss_chunks = [ + torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + return torch.cat(loss_chunks).mean() + + +def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: + for checkpoint_name, attribute_name in mapping.items(): + full_checkpoint_name = prefix + checkpoint_name + if full_checkpoint_name in state_dict: + full_attribute_name = prefix + attribute_name + state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) + return state_dict + + +def get_default_supported_precision(training: bool) -> str: + """Return default precision that is supported by the hardware: either `bf16` or `16`. 
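`chunked_cross_entropy` above flattens the logits and computes the loss in chunks so that the large intermediate buffers created for `cross_entropy`'s backward pass never exist all at once; targets equal to `-1` are ignored. A minimal sketch on random data:

```python
import torch

from lit_gpt.utils import chunked_cross_entropy

B, T, V = 2, 16, 50                         # illustrative batch/sequence/vocab sizes
logits = torch.randn(B, T, V, requires_grad=True)
targets = torch.randint(0, V, (B, T))
targets[:, 0] = -1                          # e.g. mask out prompt or padding positions

loss = chunked_cross_entropy(logits, targets, chunk_size=8)
loss.backward()                             # gradients flow through every chunk
print(loss.item())
```

One caveat worth knowing: the chunked path uses `reduction="none"` followed by `.mean()`, so masked positions contribute zeros but still count in the denominator, unlike the unchunked `chunk_size=0` path, which averages only over unmasked targets.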
+ + Args: + training: `-mixed` or `-true` version of the precision to use + + Returns: + default precision that is suitable for the task and is supported by the hardware + """ + from lightning.fabric.accelerators import MPSAccelerator + + if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()): + return "16-mixed" if training else "16-true" + return "bf16-mixed" if training else "bf16-true" + + +def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None: + if isinstance(fabric.strategy, FSDPStrategy): + fabric.load_raw(checkpoint_path, model, strict=strict) + else: + state_dict = lazy_load(checkpoint_path) + state_dict = state_dict.get("model", state_dict) + model.load_state_dict(state_dict, strict=strict) diff --git a/lit_llama/__init__.py b/lit_llama/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..c169d4c6c201d9085d5b0b95f9a1a5ca85eab4c0 --- /dev/null +++ b/lit_llama/__init__.py @@ -0,0 +1,2 @@ +from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope +from lit_llama.tokenizer import Tokenizer diff --git a/lit_llama/adapter.py b/lit_llama/adapter.py new file mode 100755 index 0000000000000000000000000000000000000000..f743c1945beb85bef610b34add8e0703388f3da6 --- /dev/null +++ b/lit_llama/adapter.py @@ -0,0 +1,151 @@ +"""Implementation of the paper: + +LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention +https://arxiv.org/abs/2303.16199 +""" +# mypy: ignore-errors +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F +import lit_llama.model as llama +from lit_llama.model import build_rope_cache, apply_rope, RMSNorm, MLP + + +@dataclass +class LLaMAConfig(llama.LLaMAConfig): + adapter_prompt_length: int = 10 + adapter_start_layer: int = 2 + + +class CausalSelfAttention(nn.Module): + """A modification of `lit_llama.model.CausalSelfAttention` that adds the attention + over the adaption prompt.""" + + def __init__(self, config: LLaMAConfig, block_idx: int) -> None: + super().__init__() + assert config.n_embd % config.n_head == 0 + + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) + + if block_idx >= config.adapter_start_layer: + # adapter embedding layer + self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) + # gate for adaption + self.gating_factor = torch.nn.Parameter(torch.zeros(1)) + + self.n_head = config.n_head + self.n_embd = config.n_embd + self.block_size = config.block_size + self.block_idx = block_idx + self.adapter_prompt_length = config.adapter_prompt_length + self.adapter_start_layer = config.adapter_start_layer + self.rope_cache = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + + head_size = C // self.n_head + k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + + if self.rope_cache is None: + # cache for 
future forward calls + self.rope_cache = build_rope_cache( + seq_len=self.block_size, + n_elem=self.n_embd // self.n_head, + dtype=x.dtype, + device=x.device, + ) + + q = apply_rope(q, self.rope_cache) + k = apply_rope(k, self.rope_cache) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + # att = F.softmax(att, dim=-1) + # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + + # efficient attention using Flash Attention CUDA kernels + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) + + if self.block_idx >= self.adapter_start_layer: + prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd) + + aT = prefix.size(1) + _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2) + ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) + av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2) + + amask = torch.ones(q.shape[-2], ak.shape[-2], dtype=torch.bool, device=x.device) + ay = F.scaled_dot_product_attention(q, ak, av, attn_mask=amask, dropout_p=0.0, is_causal=False) + y = y + self.gating_factor * ay + + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.c_proj(y) + + return y + + +class Block(nn.Module): + """The implementation is identical to `lit_llama.model.Block` with the exception that + we replace the attention layer where adaption is implemented.""" + + def __init__(self, config: LLaMAConfig, block_idx: int) -> None: + super().__init__() + self.rms_1 = RMSNorm(config.n_embd) + self.attn = CausalSelfAttention(config, block_idx) + self.rms_2 = RMSNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.rms_1(x)) + x = x + self.mlp(self.rms_2(x)) + return x + + +class LLaMA(llama.LLaMA): + """The implementation is identical to `lit_llama.model.LLaMA` with the exception that + the `Block` saves the layer index and passes it down to the attention layer.""" + + def __init__(self, config: LLaMAConfig) -> None: + nn.Module.__init__(self) + assert config.vocab_size is not None + assert config.block_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + h=nn.ModuleList([Block(config, i) for i in range(config.n_layer)]), + ln_f=RMSNorm(config.n_embd), + ) + ) + + @classmethod + def from_name(cls, name: str): + return cls(LLaMAConfig.from_name(name)) + + +def mark_only_adapter_as_trainable(model: LLaMA) -> None: + """Sets `requires_grad=False` for all non-adapter weights.""" + for name, param in model.named_parameters(): + param.requires_grad = "adapter_wte" in name or "gating_factor" in name + + +def adapter_state_from_state_dict(state_dict: dict) -> dict: + """Returns the model state dict with only the adapter weights for saving.""" + return {name: param for name, param in state_dict.items() if "adapter_wte" in name or "gating_factor" in name} diff --git a/lit_llama/indexed_dataset.py b/lit_llama/indexed_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..c93af7d0a35a347d85e5622f8c893d86c49cf090 --- /dev/null +++ b/lit_llama/indexed_dataset.py @@ -0,0 +1,588 @@ +# 
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of the FairSeq source tree. + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +from functools import lru_cache +import os +import shutil +import struct +from itertools import accumulate + +import numpy as np +import torch + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ['lazy', 'cached', 'mmap'] + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return 'cached' + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return 'mmap' + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == 'mmap': + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == 'infer': + impl = infer_dataset_impl(path) + if impl == 'lazy' and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == 'cached' and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == 'mmap' and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == 'mmap': + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float32, + 7: np.float64, + 8: np.uint16 +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + '.idx' + + +def data_file_path(prefix_path): + return prefix_path + '.bin' + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + _HDR_MAGIC = b'TNTIDX\x00\x00' + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), 'rb') as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + 'Index file doesn\'t match expected format. 
' + 'Make sure that --dataset-impl is configured properly.' + ) + version = f.read(8) + assert struct.unpack('= self._len: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return ( + os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx: ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx: ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8 + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, 'wb') + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / 
self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + doc_offset = len(self.sizes) + + begin = self.data_offsets[-1] + for data_offset in index.data_offsets[1:]: + self.data_offsets.append(begin + data_offset) + self.sizes.extend(index.sizes) + + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + self.doc_idx.extend((doc_offset + index.doc_idx)[1:]) + + with open(data_file_path(another_file), 'rb') as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, 'wb') + index.write(b'TNTIDX\x00\x00') + index.write(struct.pack(' 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + +class MergedLinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + enable_lora: List[bool] = [False], + fan_in_fan_out: bool = False, + merge_weights: bool = True, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + assert out_features % len(enable_lora) == 0, \ + 'The length of enable_lora must divide out_features' + self.enable_lora = enable_lora + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter( + self.weight.new_zeros((r * sum(enable_lora), in_features))) + self.lora_B = nn.Parameter( + self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r)) + ) # weights for Conv1D with groups=sum(enable_lora) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + # Compute the indices + self.lora_ind = self.weight.new_zeros( + (out_features, ), dtype=torch.bool + ).view(len(enable_lora), -1) + self.lora_ind[enable_lora, :] = True + self.lora_ind = self.lora_ind.view(-1) + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.T + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def zero_pad(self, x): + result = x.new_zeros((*x.shape[:-1], self.out_features)) + result = result.view(-1, self.out_features) + result[:, self.lora_ind] = x.reshape( + -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora) + ) + return result.view((*x.shape[:-1], self.out_features)) + + def train(self, mode: bool = True): + def T(w): + return w.T if self.fan_in_fan_out else w + nn.Linear.train(self, mode) + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + if self.r > 0 and any(self.enable_lora): + delta_w = F.conv1d( + self.lora_A.data.unsqueeze(0), + self.lora_B.data.unsqueeze(-1), + groups=sum(self.enable_lora) + 
).squeeze(0) + self.weight.data -= self.zero_pad(T(delta_w * self.scaling)) + self.merged = False + + def eval(self): + def T(w): + return w.T if self.fan_in_fan_out else w + nn.Linear.eval(self) + if self.merge_weights and not self.merged: + # Merge the weights and mark it + if self.r > 0 and any(self.enable_lora): + delta_w = F.conv1d( + self.lora_A.data.unsqueeze(0), + self.lora_B.data.unsqueeze(-1), + groups=sum(self.enable_lora) + ).squeeze(0) + self.weight.data += self.zero_pad(T(delta_w * self.scaling)) + self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.T if self.fan_in_fan_out else w + if self.merged: + return F.linear(x, T(self.weight), bias=self.bias) + else: + result = F.linear(x, T(self.weight), bias=self.bias) + if self.r > 0: + after_A = F.linear(self.lora_dropout(x), self.lora_A) + after_B = F.conv1d( + after_A.transpose(-2, -1), + self.lora_B.unsqueeze(-1), + groups=sum(self.enable_lora) + ).transpose(-2, -1) + result += self.zero_pad(after_B) * self.scaling + return result + + +def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: + # import pdb; pdb.set_trace() + for n, p in model.named_parameters(): + if 'lora_' not in n and 'motion_proj' not in n and 'llama_proj' not in n: + p.requires_grad = False + if bias == 'none': + return + elif bias == 'all': + for n, p in model.named_parameters(): + if 'bias' in n: + p.requires_grad = True + elif bias == 'lora_only': + for m in model.modules(): + if isinstance(m, LoRALayer) and \ + hasattr(m, 'bias') and \ + m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]: + my_state_dict = model.state_dict() + if bias == 'none': + return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'llama_proj' in k or 'motion_proj' in k} + elif bias == 'all': + return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k or 'llama_proj' in k or 'motion_proj' in k} + elif bias == 'lora_only': + to_return = {} + for k in my_state_dict: + if 'lora_' in k: + to_return[k] = my_state_dict[k] + bias_name = k.split('lora_')[0]+'bias' + if bias_name in my_state_dict: + to_return[bias_name] = my_state_dict[bias_name] + return to_return + else: + raise NotImplementedError + + +@dataclass +class LoRAConfig: + r: float = 0.0 + alpha: float = 1.0 + dropout: float = 0.0 + + +class CausalSelfAttention(llama.CausalSelfAttention): + lora_config = None + + def __init__(self, config: llama.LLaMAConfig) -> None: + # Skip the parent class __init__ altogether and replace it to avoid + # useless allocations + nn.Module.__init__(self) + assert config.n_embd % config.n_head == 0 + + # key, query, value projections for all heads, but in a batch + self.c_attn = MergedLinear( + in_features=config.n_embd, + out_features=3 * config.n_embd, + r=self.lora_config.r, + lora_alpha=self.lora_config.alpha, + lora_dropout=self.lora_config.dropout, + enable_lora=[True, False, True], + fan_in_fan_out = False, + merge_weights=True, + bias=False) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) + # regularization + self.n_head = config.n_head + self.n_embd = config.n_embd + self.block_size = config.block_size + self.rope_cache = None + + +@contextmanager +def lora(r, alpha, dropout, enabled: bool = True): + """A context manager under which you can instantiate the model with LoRA.""" + if not enabled: + yield + return + + 
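Note that `mark_only_lora_as_trainable` and `lora_state_dict` deliberately whitelist `motion_proj` and `llama_proj` alongside the LoRA matrices, so only the adapter weights and the two modality projections are trained and checkpointed while the LLaMA backbone stays frozen. A minimal sketch, assuming `model` is any module built with the `MergedLinear` layers above and that the output path is a placeholder:

```python
import torch

# Freeze everything except LoRA matrices and the motion/LLM projection layers.
mark_only_lora_as_trainable(model, bias="none")
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
# -> only names containing 'lora_', 'llama_proj', or 'motion_proj'

# Checkpoint just those weights; the result is a small adapter file.
torch.save(lora_state_dict(model, bias="none"), "adapter.pth")
# Restore later with: model.load_state_dict(torch.load("adapter.pth"), strict=False)
```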
CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout) + + causal_self_attention = llama.CausalSelfAttention + llama.CausalSelfAttention = CausalSelfAttention + yield + llama.CausalSelfAttention = causal_self_attention + + CausalSelfAttention.lora_config = None diff --git a/lit_llama/model.py b/lit_llama/model.py new file mode 100755 index 0000000000000000000000000000000000000000..83cd976a5f157b21b365665a9d5c599f6b3efcdd --- /dev/null +++ b/lit_llama/model.py @@ -0,0 +1,246 @@ +"""Full definition of a LLaMA Language Model, all of it in this single file. + +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. +""" +# mypy: ignore-errors +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F +from typing_extensions import Self + + +@dataclass +class LLaMAConfig: + block_size: int = 4096 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + n_embd: int = 4096 + + @classmethod + def from_name(cls, name: str) -> Self: + return cls(**llama_configs[name]) + + +llama_configs = { + "7B": dict(n_layer=32, n_head=32, n_embd=4096), + "13B": dict(n_layer=40, n_head=40, n_embd=5120), + "30B": dict(n_layer=60, n_head=52, n_embd=6656), + "65B": dict(n_layer=80, n_head=64, n_embd=8192), +} + + +class LLaMA(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=RMSNorm(config.n_embd), + ) + ) + # self.llama_proj = nn.Sequential( + # nn.Linear(256, 1024), + # nn.ReLU(), + # nn.Linear(1024, config.n_embd) + # ) + self.llama_proj = nn.Linear(512, config.n_embd) + # self.motion_proj = nn.Sequential( + # nn.Linear(config.n_embd, 1024), + # nn.ReLU(), + # nn.Linear(1024, 256) + # ) + self.motion_proj = nn.Linear(config.n_embd, 512) + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) + + def forward(self, idx: torch.Tensor) -> torch.Tensor: + # import pdb; pdb.set_trace() + _, t = idx.size() + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the LLaMA model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + + logits = self.lm_head(x) # (b, t, vocab_size) + + return logits + + @classmethod + def from_name(cls, name: str) -> Self: + return cls(LLaMAConfig.from_name(name)) + + +class Block(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + super().__init__() + self.rms_1 = RMSNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.rms_2 = RMSNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.rms_1(x)) + x = x + self.mlp(self.rms_2(x)) + return x + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + 
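Putting the pieces together: the `lora(...)` context manager swaps the LoRA-enabled attention in for `llama.CausalSelfAttention` only while the model is being instantiated, and the `llama_proj` / `motion_proj` linears bridge the 512-dimensional motion latent space and the transformer width. A sketch of the intended wiring; the toy config and hyper-parameter values are illustrative (a real run would use `LLaMA.from_name("7B")`):

```python
import torch

config = LLaMAConfig(block_size=256, vocab_size=32000, n_layer=2, n_head=4, n_embd=128)
with lora(r=8, alpha=16, dropout=0.05, enabled=True):
    model = LLaMA(config)              # c_attn layers are built as MergedLinear
mark_only_lora_as_trainable(model)

idx = torch.randint(0, config.vocab_size, (1, 16))
logits = model(idx)                    # (1, 16, vocab_size)

motion_embedding = model.llama_proj(torch.randn(1, 512))          # 512-d latent -> n_embd
motion_latent = model.motion_proj(torch.randn(1, config.n_embd))  # n_embd -> 512-d latent
```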
super().__init__() + assert config.n_embd % config.n_head == 0 + + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) + + self.n_head = config.n_head + self.n_embd = config.n_embd + self.block_size = config.block_size + self.rope_cache = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + + head_size = C // self.n_head + k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) + + if self.rope_cache is None: + # cache for future forward calls + self.rope_cache = build_rope_cache( + seq_len=self.block_size, + n_elem=self.n_embd // self.n_head, + dtype=x.dtype, + device=x.device, + ) + + q = apply_rope(q, self.rope_cache) + k = apply_rope(k, self.rope_cache) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + # att = F.softmax(att, dim=-1) + # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + + # efficient attention using Flash Attention CUDA kernels + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True) + + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.c_proj(y) + + return y + + +class MLP(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + super().__init__() + hidden_dim = 4 * config.n_embd + n_hidden = int(2 * hidden_dim / 3) + N = 256 + # ensure n_hidden is multiple of N + n_hidden = ((n_hidden - 1) // N) * N + N + + self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False) + self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False) + self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.c_fc1(x)) * self.c_fc2(x) + x = self.c_proj(x) + return x + + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. + """ + + def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: + super().__init__() + self.scale = nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE: the original RMSNorm paper implementation is not equivalent + # norm_x = x.norm(2, dim=self.dim, keepdim=True) + # rms_x = norm_x * d_x ** (-1. / 2) + # x_normed = x / (rms_x + self.eps) + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return self.scale * x_normed + + +def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor: + """Enhanced Transformer with Rotary Position Embedding. 
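Two details above are easy to gloss over: the SwiGLU MLP rounds its hidden width up to a multiple of 256 (for `n_embd = 4096`, `int(2 * 16384 / 3) = 10922` becomes `11008`, the familiar LLaMA-7B FFN size), and `RMSNorm` rescales each feature vector to unit root-mean-square rather than subtracting the mean. A quick numeric check of the latter, assuming the `RMSNorm` class above:

```python
import torch

norm = RMSNorm(8)            # scale is initialised to ones
x = torch.randn(4, 8) * 3.0
y = norm(x)
rms = torch.sqrt(torch.mean(y * y, dim=-1))
print(rms)                    # ~1.0 everywhere (up to eps); the mean is not removed
```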
+ + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta) + + # Compute cache. Because polar only takes float32 or float64, we need to cast + # when working with 16 bit floats (float16 or bfloat16) + dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8] + working_dtype = ( + torch.float32 if dtype in dtypes_requiring_casting else dtype + ) + complex_dtype = ( + torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64 + ) + cache = torch.polar( + torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype) + ).to(complex_dtype) + return cache + + +def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + x = x.transpose(1, 2) + + # truncate to support variable sizes + T = x.size(1) + rope_cache = rope_cache[:T] + + # cast because `view_as_complex` does not support 16 bit tensors + xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3)) + x_out = torch.view_as_real(xc * rope_cache).flatten(3) + return x_out.transpose(1, 2).type_as(x) diff --git a/lit_llama/quantization.py b/lit_llama/quantization.py new file mode 100755 index 0000000000000000000000000000000000000000..bb09a1fc23177524142424734f2572d4a3788b5a --- /dev/null +++ b/lit_llama/quantization.py @@ -0,0 +1,281 @@ +import os +from contextlib import contextmanager +import warnings +import math + +import torch + +# configuration for bitsandbytes before import +os.environ["BITSANDBYTES_NOWELCOME"] = "1" +warnings.filterwarnings( + "ignore", + message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization" +) +warnings.filterwarnings( + "ignore", + message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization" +) +warnings.filterwarnings( + "ignore", + message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable." +) + +try: + import bitsandbytes as bnb # noqa: E402 +except: + bnb = None + +if bnb is not None: + class Linear8bitLt(bnb.nn.Linear8bitLt): + """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and + re-quantizaton when loading the state dict. + + + This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0) + # We quantize the initial weight here so we don't end up filling the device + # memory with float32 weights which could lead to OOM. 
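`build_rope_cache` precomputes one complex rotation per (position, channel-pair) up to the block size, and `apply_rope` rotates query/key channels pairwise, which is why the per-head size must be even. A shape sketch, assuming the two helpers above; the sizes are illustrative:

```python
import torch

B, n_head, T, head_size = 2, 32, 16, 128           # head_size = n_embd // n_head
q = torch.randn(B, n_head, T, head_size)

rope = build_rope_cache(seq_len=T, n_elem=head_size, dtype=q.dtype, device=q.device)
q_rot = apply_rope(q, rope)                         # rotated in pairs, same shape
assert q_rot.shape == (B, n_head, T, head_size)
```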
+ self._quantize_weight(self.weight.data) + + def _load_from_state_dict(self, local_state_dict, *args, **kwargs): + # There is only one key that ends with `*.weight`, the other one is the bias + weight_key = next((name for name in local_state_dict.keys() if name.endswith("weight")), None) + if weight_key is None: + return + + # Load the weight from the state dict and re-quantize it + weight = local_state_dict.pop(weight_key) + self._quantize_weight(weight) + + # If there is a bias, let nn.Module load it + if local_state_dict: + super()._load_from_state_dict(local_state_dict, *args, **kwargs) + + def _quantize_weight(self, weight: torch.Tensor) -> None: + # This code is taken and adapted from `bnb.nn.Int8Params.cuda()` + B = weight.contiguous().half().cuda() + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + del CBt + del SCBt + self.weight.data = CB + setattr(self.weight, "CB", CB) + setattr(self.weight, "SCB", SCB) + + +# for correctness but with terrible perf +class ColBlockQuantizedLinear(torch.nn.Module): + def __init__(self, in_features, out_features, bias: bool, *, bits, tile_cols): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.tile_cols = tile_cols if tile_cols != -1 else self.in_features + self.bits = bits + self.entries_per_byte = 8 // bits + assert self.entries_per_byte > 0 and self.entries_per_byte * self.bits == 8 + assert in_features % self.entries_per_byte == 0 + self.register_buffer("quant_weight", torch.empty((self.out_features, self.in_features // self.entries_per_byte), dtype=torch.uint8)) + self.register_buffer("scales", torch.empty((self.out_features, (self.in_features + self.tile_cols - 1) // self.tile_cols))) + self.register_buffer("zeros", torch.empty_like(self.scales)) + assert isinstance(bias, bool) + if bias: + self.register_buffer("bias", torch.empty((self.out_features,))) + else: + self.register_buffer("bias", None) + + def pack_weight(self, weight): + weight = weight.to(device=self.quant_weight.device, copy=True) + for j in range(self.scales.size(1)): + weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] /= self.scales[: , j: j+1] + weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] += self.zeros[: , j: j+1] + weight = weight.clamp_(min=0, max=2 ** self.bits - 1).to(dtype=torch.uint8) + self.quant_weight.zero_() + for nr in range(self.entries_per_byte): + self.quant_weight += weight[:, nr::self.entries_per_byte] << (nr * self.bits) + + def get_weight(self, dtype=torch.float): + weight = torch.empty((self.out_features, self.in_features), device=self.quant_weight.device, dtype=dtype) + mask = (1 << self.bits) - 1 + for nr in range(self.entries_per_byte): + weight[:, nr::self.entries_per_byte] = ((self.quant_weight >> (nr * self.bits)) & mask).float() + self.quant_weight.to(dtype) + for j in range(self.scales.size(1)): + weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] -= self.zeros[: , j: j+1] + weight[:, j * self.tile_cols: (j + 1) * self.tile_cols] *= self.scales[: , j: j+1] + return weight + + def forward(self, inp): + weight = self.get_weight(dtype=inp.dtype) + return torch.nn.functional.linear(inp, weight, self.bias) + + + + +class GPTQQuantizer: + # The algorithm and code has been taken from https://github.com/IST-DASLab/gptq/ + # E. Frantar et al GPTQ: Accurate Post-training Compression for GPT, arXiv:2210.17323 + # portions copyright by the authors licensed under the Apache License 2.0 + # All errors are our own.
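`ColBlockQuantizedLinear` packs several sub-byte codes into each `uint8` entry and dequantizes with per-tile affine `scales`/`zeros`; the GPTQ quantizer below fills those buffers from calibration statistics. As a self-contained illustration, the buffers can also be filled with a plain per-row min/max affine quantizer and the round trip checked (that choice of scale/zero is only for this sketch):

```python
import torch

out_f, in_f, bits = 8, 16, 4
w = torch.randn(out_f, in_f)

q_lin = ColBlockQuantizedLinear(in_f, out_f, bias=False, bits=bits, tile_cols=-1)
maxq = 2 ** bits - 1
wmin = w.min(dim=1, keepdim=True).values
wmax = w.max(dim=1, keepdim=True).values
q_lin.scales.copy_(((wmax - wmin) / maxq).clamp_min(1e-8))   # one tile per row
q_lin.zeros.copy_(torch.round(-wmin / q_lin.scales))
q_lin.pack_weight(w)                                         # two 4-bit codes per uint8

reconstruction_error = (q_lin.get_weight() - w).abs().max()  # on the order of one step
```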
+ + def __init__(self, linear_module, *, bits, perchannel=True, sym=False, blocksize=128, percdamp=.01, groupsize=-1, actorder=False): + assert isinstance(linear_module, torch.nn.Linear) + + self.linear_module = linear_module + self.dev = self.linear_module.weight.device + self.rows = linear_module.weight.shape[0] + self.columns = linear_module.weight.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.bits = bits + self.maxq = 2 ** bits - 1 + self.perchannel = perchannel + self.sym = sym + self.blocksize = blocksize + self.percdamp = percdamp + self.groupsize = groupsize + self.actorder = actorder + self.tile_cols = self.columns if groupsize == -1 else groupsize + self.scales = torch.zeros((self.rows, (self.columns + self.tile_cols - 1) // self.tile_cols), dtype=self.linear_module.weight.dtype, device = self.dev) + self.zeros = torch.zeros_like(self.scales) + assert not (self.actorder and self.groupsize != -1), "The permutation trick does not work for grouped quantization" + + @staticmethod + def quantize_weight(x, scale, zero, maxq): + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + x_rec = scale * (q - zero) + return x_rec + + def find_params_weight(self, x): + dev = x.device + + shape = x.shape + if self.perchannel: + x = x.flatten(1) + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + scale = (xmax - xmin) / self.maxq + if self.sym: + zero = torch.full_like(scale, (self.maxq + 1) / 2) + else: + zero = torch.round(-xmin / scale) + + if not self.perchannel: + tmp = shape[0] + scale = scale.repeat(tmp) + zero = zero.repeat(tmp) + + shape = [-1] + [1] * (len(shape) - 1) + scale = scale.reshape(shape) + zero = zero.reshape(shape) + return scale, zero + + def collect_input_stats(self, _1, inp, _2): + inp = inp[0].detach() + self.last_inp = inp + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def quantize(self): + W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True) + + scale, zero = self.find_params_weight(W) + self.scales[:] = scale + self.zeros[:] = zero + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + if self.actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = self.percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + for i1 in range(0, self.columns, self.blocksize): + i2 = min(i1 + self.blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d 
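The quantizer is driven in two phases: a forward hook (`collect_input_stats`) accumulates the Hessian proxy H from `x x^T` over calibration inputs, then `quantize()` runs the GPTQ column sweep and returns a packed `ColBlockQuantizedLinear`. A hypothetical driver for a single linear layer; the layer size, batch shapes, and `groupsize` are illustrative:

```python
import torch

layer = torch.nn.Linear(4096, 4096, bias=False)
quantizer = GPTQQuantizer(layer, bits=4, groupsize=128, actorder=False)

handle = layer.register_forward_hook(quantizer.collect_input_stats)
with torch.no_grad():
    for _ in range(4):                      # a few calibration batches
        layer(torch.randn(1, 8, 4096))
handle.remove()

quantized_layer, squared_error = quantizer.quantize()
# quantized_layer is a ColBlockQuantizedLinear that can replace `layer`.
```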
= Hinv1[i, i] + + if self.groupsize != -1: + if (i1 + i) % self.groupsize == 0: + scale, zero = self.find_params_weight(W[:, (i1 + i):(i1 + i + self.groupsize)]) + self.scales[:, (i1 + i) // self.groupsize] = scale + self.zeros[:, (i1 + i) // self.groupsize] = zero + + q = self.quantize_weight( + w.unsqueeze(1), scale, zero, self.maxq + ) + q = q.squeeze(1) + assert q.dim() == 1 + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + if self.actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + + weight = Q.reshape(self.linear_module.weight.shape).to(self.linear_module.weight.data.dtype) + error = torch.sum(Losses).item() + + q_module = ColBlockQuantizedLinear(self.linear_module.in_features, self.linear_module.out_features, self.linear_module.bias is not None, + bits=self.bits, tile_cols=self.groupsize).to(self.dev) + q_module.scales = self.scales + q_module.zeros = self.zeros + q_module.pack_weight(weight) + q_module.bias = self.linear_module.bias + return q_module, error diff --git a/lit_llama/tokenizer.py b/lit_llama/tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..fb681e3f51e697902cd3cb0bdcadee4ac3306f2d --- /dev/null +++ b/lit_llama/tokenizer.py @@ -0,0 +1,49 @@ +import os +from pathlib import Path +from typing import Optional + +import torch +from sentencepiece import SentencePieceProcessor, SentencePieceTrainer + + +class Tokenizer: + """Tokenizer for LLaMA.""" + + def __init__(self, model_path: Path) -> None: + self.processor = SentencePieceProcessor(model_file=str(model_path)) + self.bos_id = self.processor.bos_id() + self.eos_id = self.processor.eos_id() + self.pad_id = self.processor.pad_id() + + @property + def vocab_size(self) -> int: + return self.processor.vocab_size() + + def encode( + self, + string: str, + bos: bool = True, + eos: bool = False, + max_length: int = -1, + pad: bool = False, + device: Optional[torch.device] = None + ) -> torch.Tensor: + tokens = self.processor.encode(string) + if bos: + tokens = [self.bos_id] + tokens + if eos: + tokens = tokens + [self.eos_id] + if max_length > 0: + tokens = tokens[:max_length] + if pad and len(tokens) < max_length: + tokens += [self.pad_id] * (max_length - len(tokens)) + + return torch.tensor(tokens, dtype=torch.int, device=device) + + def decode(self, tokens: torch.Tensor) -> str: + return self.processor.decode(tokens.tolist()) + + @staticmethod + def train(input: str, destination: str, vocab_size=32000) -> None: + model_prefix = os.path.join(destination, "tokenizer") + SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size) diff --git a/lit_llama/utils.py b/lit_llama/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..3727d9c61f65ea613be6fcc2d51a9eabd94c2940 --- /dev/null +++ b/lit_llama/utils.py @@ -0,0 +1,244 @@ +"""Utility functions for training and inference.""" + +import functools +from pathlib import Path +import pickle +import warnings +from io import BytesIO + +import torch +import torch.utils._device +from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType + + +def save_model_checkpoint(fabric, model, file_path):
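The `Tokenizer` wrapper above is a thin layer over SentencePiece that handles BOS/EOS insertion, truncation, and padding in one call. A round-trip sketch; the path to `tokenizer.model` is a placeholder:

```python
import torch
from pathlib import Path

tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))  # hypothetical path

ids = tokenizer.encode("A person walks forward and waves.",
                       bos=True, eos=False, max_length=64, pad=True,
                       device=torch.device("cpu"))
text = tokenizer.decode(ids[ids != tokenizer.pad_id])   # drop padding before decoding
```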
+ """Handles boilerplate logic for retrieving and saving the state_dict. + + This will be upstreamed to Fabric soon. + """ + file_path = Path(file_path) + + if isinstance(fabric.strategy, DeepSpeedStrategy): + from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict + + fabric.save(file_path, {"model": model}) + fabric.barrier() + if fabric.global_rank == 0: + # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint + convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth")) + return + + if isinstance(fabric.strategy, FSDPStrategy): + save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): + state_dict = model._forward_module.state_dict() + else: + state_dict = model.state_dict() + + if fabric.global_rank == 0: + torch.save(state_dict, file_path) + fabric.barrier() + + +class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): + def __init__(self, device=None, dtype=None, quantization_mode=None): + """ + Create tensors with given device and dtype and don't run initialization + (but instead use "empty tensors", i.e. uninitialized memory). + + device: `torch.device` to work with + dtype: `torch.dtype` to work with + quantization_mode: optional string, quantization mode to work with, default `None`. + Available modes: `llm.int8` bitsnbytes LLM.int8 quantization (only on GPU) + `qptq.int4`, `gptq.int8`: GPTQ pre-quantized models + + Example:: + with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): + model = LLaMA.from_name('7B') + model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))""" + + self.quantization_mode = quantization_mode + self.quantized_linear_cls = None + if self.quantization_mode == 'llm.int8': + if device.type != "cuda": + raise ValueError("Quantization is only supported on the GPU.") + from .quantization import Linear8bitLt + self.quantized_linear_cls = Linear8bitLt + elif self.quantization_mode == 'gptq.int4': + from .quantization import ColBlockQuantizedLinear + self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1) + elif self.quantization_mode == 'gptq.int8': + from .quantization import ColBlockQuantizedLinear + self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1) + elif self.quantization_mode is not None: + raise RuntimeError(f"unknown quantization mode {self.quantization_mode}") + self.device = device + self.dtype = dtype + + def __enter__(self): + if self.quantized_linear_cls != None: + self.torch_linear_cls = torch.nn.Linear + torch.nn.Linear = self.quantized_linear_cls + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.quantized_linear_cls != None: + torch.nn.Linear = self.torch_linear_cls + return super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if getattr(func, "__module__", None) == "torch.nn.init": + if "tensor" in kwargs: + return kwargs["tensor"] + else: + return args[0] + if ( + self.device is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("device") is None + ): + kwargs["device"] = self.device + if ( + self.dtype is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("dtype") is None + ): + kwargs["dtype"] = self.dtype + return func(*args, **kwargs) + + +# this is taken from 
torchhacks https://github.com/lernapparat/torchhacks + + +class NotYetLoadedTensor: + def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): + self.metatensor = metatensor + self.archiveinfo = archiveinfo + self.storageinfo = storageinfo + self.rebuild_args = rebuild_args + + @classmethod + def rebuild( + cls, + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata=None, + archiveinfo=None, + ): + rebuild_args = ( + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata, + ) + metatensor = torch._utils._rebuild_tensor_v2( + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata, + ) + storageinfo = storage.archiveinfo + return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) + + def _load_tensor(self): + name, storage_cls, fn, device, size = self.storageinfo + dtype = self.metatensor.dtype + + uts = ( + self.archiveinfo.zipfile.get_storage_from_record( + f"data/{fn}", + size * torch._utils._element_size(dtype), + torch.UntypedStorage, + ) + ._typed_storage() + ._untyped_storage + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + storage = torch.storage.TypedStorage( + wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True + ) + tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) + return tensor + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + loaded_args = [ + (a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args + ] + res = func(*loaded_args, **kwargs) + # gc.collect would be costly here, maybe do it optionally + return res + + def __getattr__(self, name): + # properties + ## TODO: device, is_...?? + ## TODO: mH, mT, H, T, data, imag, real + ## name ??? 
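`NotYetLoadedTensor` defers reading tensor storage from the checkpoint archive until a value is actually used, which is what lets a 7B checkpoint be loaded without first materialising every tensor in CPU RAM. A sketch of the intended flow, using `lazy_load` (defined just below) together with `EmptyInitOnDevice` and `LLaMA` from above; the checkpoint path is a placeholder, and `strict=False` accounts for the extra `llama_proj`/`motion_proj` weights that a plain LLaMA checkpoint does not contain:

```python
import torch

checkpoint = lazy_load("checkpoints/lit-llama/7B/lit-llama.pth")   # dict of lazy tensors
with EmptyInitOnDevice(device=torch.device("cuda"), dtype=torch.bfloat16):
    model = LLaMA.from_name("7B")                                  # uninitialised weights
model.load_state_dict(checkpoint, strict=False)                    # materialises tensors on use
```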
+ if name in { + "dtype", + "grad", + "grad_fn", + "layout", + "names", + "ndim", + "output_nr", + "requires_grad", + "retains_grad", + "shape", + "volatile", + }: + return getattr(self.metatensor, name) + if name in {"size"}: + return getattr(self.metatensor, name) + # materializing with contiguous is needed for quantization + if name in {"contiguous"}: + return getattr(self._load_tensor(), name) + + raise AttributeError(f"{type(self)} does not have {name}") + + def __repr__(self): + return f"NotYetLoadedTensor({repr(self.metatensor)})" + + +class LazyLoadingUnpickler(pickle.Unpickler): + def __init__(self, file, zipfile): + super().__init__(file) + self.zipfile = zipfile + + def find_class(self, module, name): + if module == "torch._utils" and name == "_rebuild_tensor_v2": + res = super().find_class(module, name) + return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self) + return super().find_class(module, name) + + def persistent_load(self, pid): + name, cls, fn, device, size = pid + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") + s.archiveinfo = pid + return s + + +def lazy_load(fn): + zf = torch._C.PyTorchFileReader(str(fn)) + with BytesIO(zf.get_record("data.pkl")) as pkl: + mup = LazyLoadingUnpickler(pkl, zf) + sd = mup.load() + return sd diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/constants.py b/models/constants.py new file mode 100755 index 0000000000000000000000000000000000000000..f1bcfaedb217f9ed2f3c399f8e5ccae68d6d189a --- /dev/null +++ b/models/constants.py @@ -0,0 +1,18 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +X_TOKEN_INDEX = {'IMAGE': -200, 'VIDEO': -201, 'AUDIO': -202, 'THERMAL': -203, 'DEPTH': -204} +X_INDEX_TOKEN = {v: k for k, v in X_TOKEN_INDEX.items()} +# IMAGE_TOKEN_INDEX = -200 +DEFAULT_X_TOKEN = {'IMAGE': "", 'VIDEO': "