---
license: mit
pipeline_tag: text-generation
---

<div align="center">
<h1>UltraGist for Llama-2-7b-chat</h1>

[<a href="https://arxiv.org/abs/2405.16635">Paper</a>] [<a href="https://github.com/namespace-Pt/UltraGist">Github</a>]
</div>

UltraGist is a context compression method that can **flexibly**, **effectively**, and **efficiently** handle various context lengths and compression ratios. We apply UltraGist to [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).

## Usage
```python
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/ultragist-llama2-7b-chat"

tokenizer = AutoTokenizer.from_pretrained(
  model_id, 
  trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
  model_id, 
  trust_remote_code=True, 
  torch_dtype=torch.bfloat16, 
  attn_implementation="sdpa",
  # load the entire model on the default gpu
  device_map={"": "cuda"}, 
  # you can manually set the compression ratio, otherwise the model will automatically choose the most suitable compression ratio from [2,4,8,16,32]
  # ultragist_ratio=[8],
).eval()


with torch.no_grad():
  # long context
  with open("data/nqa.json", encoding="utf-8") as f:
    example = json.load(f)
    content = f"Read this article:\n\n{example['context']}\n\nNow, answer the question based on the above context.\nQuestion:\n{example['input']}"
  messages = [{"role": "user", "content": content}]
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")

  # reset memory before new compression task
  model.memory.reset()

  # directly call generate to progressively compress the context while generating next tokens
  outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=40)[:, inputs["input_ids"].shape[1]:]
  print("*"*20)
  print(f"Input size:       {inputs['input_ids'].shape[1]}")
  print(f"Question:         {example['input']}")
  print(f"Answers:          {example['answers']}")
  print(f"Prediction:       {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
  print("*"*20)

  # extract the compressed memory (including the generated tokens)
  compressed_memory = model.memory.get_memory()
  ultragist_size, raw_size, sink_size = model.memory.get_memory_size()
  print(f"UltraGist size:   {ultragist_size}")
  print(f"Raw size:         {raw_size}")
  print(f"Sink size:        {sink_size}")
  print(f"Memory:           {compressed_memory[0][0].shape}")
  print("*"*20)
```
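
If you want to pin the compression ratio instead of letting the model pick one automatically, pass `ultragist_ratio` at load time (it appears commented out in the example above). A minimal sketch, assuming the same `model_id` and remote code as above:

```python
# Sketch: load the model with a fixed compression ratio of 8 instead of
# letting it choose automatically from [2, 4, 8, 16, 32].
model_fixed = AutoModelForCausalLM.from_pretrained(
  model_id,
  trust_remote_code=True,
  torch_dtype=torch.bfloat16,
  attn_implementation="sdpa",
  device_map={"": "cuda"},
  ultragist_ratio=[8],
).eval()

# As in the example above, reset the memory before each new compression task.
model_fixed.memory.reset()
```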