---
license: mit
pipeline_tag: text-generation
---
UltraGist is a context compression method that can **flexibly**, **effectively**, and **efficiently** handle various context lengths and compression ratios. We apply UltraGist to [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).

## Usage

```python
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/ultragist-llama2-7b-chat"

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # load the entire model on the default gpu
    device_map={"": "cuda"},
    # you can manually set the compression ratio, otherwise the model will automatically choose the most suitable compression ratio from [2,4,8,16,32]
    # ultragist_ratio=[8],
).eval()

# long context: read the example once and close the file before generation starts.
# The example dict is expected to provide "context", "input" and "answers" keys.
with open("data/nqa.json", encoding="utf-8") as f:
    example = json.load(f)

content = (
    f"Read this article:\n\n{example['context']}\n\n"
    "Now, answer the question based on the above context.\n"
    f"Question:\n{example['input']}"
)
messages = [{"role": "user", "content": content}]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to("cuda")
prompt_length = inputs["input_ids"].shape[1]

with torch.no_grad():
    # reset memory before new compression task
    model.memory.reset()

    # directly call generate to progressively compress the context while generating next tokens.
    # top_p/temperature are pinned to 1, presumably to override any sampling defaults in the
    # model's generation config so greedy decoding (do_sample=False) runs without warnings.
    outputs = model.generate(
        **inputs,
        do_sample=False,
        top_p=1,
        temperature=1,
        max_new_tokens=40,
    )[:, prompt_length:]  # keep only the newly generated tokens

    print("*" * 20)
    print(f"Input size: {prompt_length}")
    print(f"Question: {example['input']}")
    print(f"Answers: {example['answers']}")
    print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print("*" * 20)

    # extract the compressed memory (including the generated tokens)
    compressed_memory = model.memory.get_memory()
    ultragist_size, raw_size, sink_size = model.memory.get_memory_size()
    print(f"UltraGist size: {ultragist_size}")
    print(f"Raw size: {raw_size}")
    print(f"Sink size: {sink_size}")
    # compressed_memory[0][0] is presumably the first layer's key cache — confirm against the model code
    print(f"Memory: {compressed_memory[0][0].shape}")
    print("*" * 20)
```