Suri-SFT
Suri-SFT is a version of mistralai/Mistral-7B-Instruct-v0.2 fine-tuned with supervised fine-tuning (SFT) and LoRA. Please see our paper for more details on the method.
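For readers unfamiliar with LoRA, here is a minimal pure-Python sketch of the idea. It is not the training code used for Suri-SFT. The pretrained weight `W` stays frozen, and a trainable low-rank path `B @ A`, scaled by `alpha / r`, is added to its output. The helper names `matvec`, `lora_forward`, and `lora_param_count` are hypothetical. The 4096 × 4096 shape is only an illustration (Mistral-7B's hidden size), since this card does not state which projections were adapted.

```python
def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(mi * vi for mi, vi in zip(row, v)) for row in m]


def lora_forward(x, W, A, B, alpha, r):
    """h = W x + (alpha / r) * B (A x); W is frozen, A (r x d_in) and B (d_out x r) are trained."""
    base = matvec(W, x)               # frozen pretrained projection
    update = matvec(B, matvec(A, x))  # low-rank path: d_in -> r -> d_out
    scale = alpha / r                 # with lora_r = 16 and lora_alpha = 16 this is 1.0
    return [h + scale * u for h, u in zip(base, update)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix."""
    return r * (d_in + d_out)


# Illustrative shape only (hypothetical choice of a square 4096 x 4096 projection)
full = 4096 * 4096                       # 16,777,216 frozen weights
lora = lora_param_count(4096, 4096, 16)  # 131,072 trainable weights
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.4f}")
```

Because `B` is usually initialized to zeros, the adapted layer starts out identical to the base model and only drifts as `A` and `B` are trained.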
Model Details
Model Description
- Language(s) (NLP): English
- License: Apache-2.0
- Finetuned from model: mistralai/Mistral-7B-Instruct-v0.2
Model Sources
- Repository: GitHub repository (contains code to reconstruct the Books3 subset)
- Paper: Link
- Demo: Website
Getting Started
Use the code in the GitHub repository for training and inference.
Training Details
Training Data
Training Procedure
Configurations | Values |
---|---|
Hardware (Training and Inference) | 4xA100s |
Tracking | wandb |
lora_r | 16 |
lora_alpha | 16 |
lora_dropout | 0.05 |
gradient_accumulation_steps | 1 |
gradient_checkpointing | True |
learning_rate | 5.0e-5 |
lr_scheduler_type | cosine |
max_length | 15024 |
max_completion_length | 15000 |
max_prompt_length | 5000 |
num_train_epochs | 2 |
optim | adamw_torch |
per_device_train_batch_size | 1 |
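As a rough illustration of two of the settings above, the sketch below computes the effective batch size and the cosine learning-rate curve. Two assumptions are involved. First, the four A100s are used for data-parallel training, giving 1 × 1 × 4 = 4 sequences per optimizer step. Second, no warmup is used, since none is listed. The decay formula follows the half-cosine used by Hugging Face's `get_cosine_schedule_with_warmup`. The function names here are hypothetical.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_devices):
    """Sequences per optimizer step (assumes plain data parallelism)."""
    return per_device_batch * grad_accum_steps * num_devices


def cosine_lr(step, total_steps, base_lr=5.0e-5, warmup_steps=0):
    """Linear warmup (none by default), then half-cosine decay from base_lr to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


print(effective_batch_size(1, 1, 4))  # per_device_train_batch_size x gradient_accumulation_steps x GPUs
for step in (0, 25, 50, 75, 100):     # 100 is an arbitrary illustrative run length
    print(step, f"{cosine_lr(step, 100):.2e}")
```

In a real run, `total_steps` would be `num_train_epochs` times the number of optimizer steps per epoch, which depends on the training-set size (not given in this table).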
Software
Training code is adapted from the Alignment Handbook and TRL.
Inference
```python
import os

# Disable tokenizer parallelism (avoids fork-related warnings); set before tokenizers are used
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

model_name = "chtmp223/suri-sft"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Load the base model, then attach the Suri-SFT LoRA adapter on top of it
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
model = PeftModel.from_pretrained(base_model, model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Placeholder: replace with your own instruction
user_prompt = "<your instruction here>"

prompt = [
    {
        "role": "user",
        "content": user_prompt,
    }
]
input_context = tokenizer.apply_chat_template(
    prompt, add_generation_prompt=True, tokenize=False
)
# The chat template already includes the BOS token, so don't add special tokens again
input_ids = tokenizer.encode(
    input_context, return_tensors="pt", add_special_tokens=False
).to(model.device)
# max_length counts prompt tokens plus generated tokens
output = model.generate(
    input_ids=input_ids, max_length=10000, do_sample=True, use_cache=True
).cpu()
print(tokenizer.decode(output[0]))
```
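To show why the inference code passes `add_special_tokens=False`, here is a pure-Python approximation of the Mistral-7B-Instruct-v0.2 chat template as originally shipped. The rendered text already starts with the `<s>` BOS token, so encoding it with special tokens would add a second one. This is a sketch only. The function `format_mistral_instruct` is hypothetical, and `tokenizer.apply_chat_template` remains authoritative, because the template may differ between tokenizer revisions.

```python
def format_mistral_instruct(messages, bos="<s>", eos="</s>"):
    """Approximate Mistral-Instruct v0.2 prompt format for alternating user/assistant turns."""
    text = bos
    for i, message in enumerate(messages):
        # The template requires turns to alternate, starting with the user
        if (message["role"] == "user") != (i % 2 == 0):
            raise ValueError("roles must alternate user/assistant, starting with user")
        if message["role"] == "user":
            text += "[INST] " + message["content"] + " [/INST]"
        else:
            text += message["content"] + eos
    return text


print(format_mistral_instruct([{"role": "user", "content": "Write a story."}]))
# <s>[INST] Write a story. [/INST]
```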
Citation
@misc{pham2024surimulticonstraintinstructionfollowing,
title={Suri: Multi-constraint Instruction Following for Long-form Text Generation},
author={Chau Minh Pham and Simeng Sun and Mohit Iyyer},
year={2024},
eprint={2406.19371},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2406.19371},
}
Framework versions
- PEFT 0.11.1