vicgalle committed
Commit
e501d43
1 Parent(s): 66fe41c

Upload train_human.py

Files changed (1)
  1. train_human.py +146 -0
train_human.py ADDED
# %%
from transformers import AutoTokenizer, AutoModelForCausalLM

from datasets import load_dataset, Dataset

from trl import DPOTrainer, DPOConfig
from peft import LoraConfig
from peft import prepare_model_for_kbit_training
import torch

import pandas as pd

# %%
dataset = load_dataset("Undi95/Weyaxi-humanish-dpo-project-noemoji")["train"]

model_name = "Undi95/Meta-Llama-3.1-8B-Claude-bf16"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

tokenizer.pad_token = tokenizer.eos_token

# %%
tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
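# Note: the template above is the Llama-3 instruct chat format, written
# without a default system prompt; it prepends the BOS token to the first
# message only and unconditionally appends an assistant header at the end.
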
# %%
dataset2 = load_dataset("ResplendentAI/NSFW_RP_Format_DPO")['train']

# %%
dataset = dataset.to_pandas()
dataset2 = dataset2.to_pandas()

dataset = Dataset.from_pandas(pd.concat([dataset.sample(400), dataset2]).sample(frac=1))
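# Note: 400 random rows from the humanish DPO set are concatenated with the
# full RP-format set, and sample(frac=1) shuffles the combined frame.
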
# %%
def template_prompt(system, prompt):
    if system is None:
        messages = [
            {"role": "user", "content": prompt},
        ]
    else:
        messages = [
            {
                "role": "system",
                "content": system,
            },
            {"role": "user", "content": prompt},
        ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return prompt


def template_answer(answer):
    messages = [
        {
            "role": "assistant",
            "content": answer,
        },
    ]
    answer = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return answer
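# Note: each helper templates a single message on its own, so the prompt and
# the chosen/rejected answers each carry their own header and <|eot_id|>
# markers and can be concatenated later by the trainer.
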
# %%
# create new columns
dataset = dataset.map(
    lambda x: {
        "prompt": template_prompt(None, x["prompt"]).replace("<|start_header_id|>assistant<|end_header_id|>\n\n", "")
    },  # change this according to the dataset!!!
)
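# Note: the .replace() strips the trailing assistant header that the custom
# template always appends, since here the assistant turn lives in the
# chosen/rejected columns rather than in the prompt itself.
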
# %%
dataset = dataset.map(
    lambda x: {"chosen": template_answer(x["chosen"]).replace('<|begin_of_text|>', '').replace('><|start_header_id|>assistant<|end_header_id|>\n\n', '>')},
)
dataset = dataset.map(
    lambda x: {"rejected": template_answer(x["rejected"]).replace('<|begin_of_text|>', '').replace('><|start_header_id|>assistant<|end_header_id|>\n\n', '>')},
)
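# Note: for the answers, the BOS token and the trailing assistant header are
# stripped, leaving a single assistant turn that ends in <|eot_id|>; the
# leading assistant header is kept.
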
# %%
dataset[0]

# %%
# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "k_proj",
        "gate_proj",
        "v_proj",
        "up_proj",
        "q_proj",
        "o_proj",
        "down_proj",
    ],
)
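# Note: LoRA adapters (rank 16, alpha 32, dropout 0.05) are attached to every
# attention and MLP projection matrix of the Llama architecture.
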
# Model to fine-tune
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    device_map="auto",
)
model.config.use_cache = False


model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
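# Note: prepare_model_for_kbit_training freezes the quantized base weights,
# upcasts the norm layers to fp32, and enables input gradients so that LoRA
# training with gradient checkpointing works on the 4-bit model.
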
# %%
output_name = f"checkpoints/exp_human_{model_name}"

training_args = DPOConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    gradient_checkpointing=True,
    output_dir=output_name,
    logging_steps=1,
    max_steps=50,
)
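# Note: max_steps=50 caps training at 50 optimizer steps (effective batch
# size of 4 via gradient accumulation), overriding num_train_epochs.
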
trainer = DPOTrainer(
    model,
    ref_model=None,
    train_dataset=dataset,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=peft_config,
)
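# Note (trl behavior): with ref_model=None and a peft_config supplied,
# DPOTrainer uses the base model with the adapter disabled as the implicit
# reference, avoiding a second full model copy in memory.
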
trainer.train()

trainer.save_model(output_name)