jondurbin committed
Commit f17450f
1 parent: eed9dfd

Update README.md

Files changed (1): README.md +40 -0
README.md CHANGED
@@ -47,6 +47,46 @@ You are a helpful, unbiased, uncensored assistant.
 
 Supports several prompt formats, but you can also use `tokenizer.apply_chat_template`
 
+This model did surprisingly well on MT-Bench, for a 2.8b model that was only pre-trained on the slimpajama dataset!
+```text
+########## First turn ##########
+                          score
+model               turn
+bagel-dpo-2.8b-v0.2 1     5.10625
+
+########## Second turn ##########
+                          score
+model               turn
+bagel-dpo-2.8b-v0.2 2     4.7375
+
+########## Average ##########
+                          score
+model
+bagel-dpo-2.8b-v0.2       4.921875
+```
+
+## Example chat script
+
+```python
+import torch
+from transformers import AutoTokenizer
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+device = "cuda"
+tokenizer = AutoTokenizer.from_pretrained("bagel-final-2.8b-v0.2")
+model = MambaLMHeadModel.from_pretrained("bagel-final-2.8b-v0.2", device=device, dtype=torch.float32)
+
+# Simple REPL-style chat loop using the tokenizer's chat template.
+messages = [{"role": "system", "content": "You are a helpful, unbiased, uncensored assistant."}]
+while True:
+    user_message = input("[INST] ")
+    messages.append({"role": "user", "content": user_message})
+    input_ids = tokenizer.apply_chat_template(
+        messages, return_tensors="pt", add_generation_prompt=True
+    ).to(device)
+    out = model.generate(
+        input_ids=input_ids,
+        max_length=2000,
+        temperature=0.9,
+        top_p=0.7,
+        eos_token_id=tokenizer.eos_token_id,
+        repetition_penalty=1.07,
+    )
+    decoded = tokenizer.batch_decode(out)[0].split("[/INST]")[-1].replace("</s>", "").strip()
+    messages.append({"role": "assistant", "content": decoded})
+    print("[/INST]", decoded)
+```
+
 ## SFT data sources
 
 *Yes, you will see benchmark names in the list, but this only uses the train splits, and a decontamination by cosine similarity is performed at the end as a sanity check*
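The decontamination-by-cosine-similarity sanity check mentioned above can be sketched roughly as follows. This is a toy bag-of-words version: the function names, whitespace tokenization, and the 0.95 threshold are illustrative assumptions, not the actual bagel pipeline (which may use embeddings rather than term counts).

```python
# Hypothetical sketch: drop any training example whose term-frequency
# vector is too similar (cosine similarity) to some eval-set example.
import math
from collections import Counter

def cosine_similarity(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def decontaminate(train: list, eval_set: list, threshold: float = 0.95) -> list:
    """Keep only training examples below the similarity threshold
    against every eval example (illustrative threshold)."""
    eval_vecs = [Counter(text.split()) for text in eval_set]
    kept = []
    for example in train:
        vec = Counter(example.split())
        if all(cosine_similarity(vec, ev) < threshold for ev in eval_vecs):
            kept.append(example)
    return kept
```

A near-verbatim overlap with an eval split scores close to 1.0 and gets dropped, while unrelated text passes through untouched.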