add dataset scripts
Browse files- .gitignore +2 -0
- convert_files.py +17 -0
- get_data.sh +23 -0
- merge_datasets.py +12 -0
- prepare_data.sh +0 -0
- train.py +41 -215
- train.sh +22 -0
- wiki_sentences.py +46 -0
.gitignore
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
.vscode
|
2 |
venv
|
3 |
*.pyc
|
|
|
|
|
|
1 |
.vscode
|
2 |
venv
|
3 |
*.pyc
|
4 |
+
segment_*
|
5 |
+
dataset.csv
|
convert_files.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from tqdm import tqdm
|
3 |
+
from transformers import AutoTokenizer
|
4 |
+
|
5 |
+
tokenizer = AutoTokenizer.from_pretrained('gpt2')
|
6 |
+
|
7 |
+
for i in tqdm(range(298)):
|
8 |
+
|
9 |
+
with open(f'wikipedia_json_64_filtered/wikipedia.segmented.nltk.split.seq64.{i}.json', 'r') as f:
|
10 |
+
rows = json.load(f)
|
11 |
+
|
12 |
+
tokens = [row['gpt2_token'] for row in rows]
|
13 |
+
texts = tokenizer.batch_decode(tokens)
|
14 |
+
|
15 |
+
with open(f'wikipedia/{i}.txt', 'w') as f:
|
16 |
+
for txt in texts:
|
17 |
+
f.write(txt.strip() + '\n')
|
get_data.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD" -O segment_1.zip && rm -rf /tmp/cookies.txt
|
4 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4" -O segment_2.zip && rm -rf /tmp/cookies.txt
|
5 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN" -O segment_3.zip && rm -rf /tmp/cookies.txt
|
6 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q" -O segment_4.zip && rm -rf /tmp/cookies.txt
|
7 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt" -O segment_5.zip && rm -rf /tmp/cookies.txt
|
8 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW" -O segment_6.zip && rm -rf /tmp/cookies.txt
|
9 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu" -O segment_7.zip && rm -rf /tmp/cookies.txt
|
10 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp" -O segment_8.zip && rm -rf /tmp/cookies.txt
|
11 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc" -O segment_9.zip && rm -rf /tmp/cookies.txt
|
12 |
+
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU" -O segment_0.zip && rm -rf /tmp/cookies.txt
|
13 |
+
|
14 |
+
unzip segment_1.zip
|
15 |
+
unzip segment_2.zip
|
16 |
+
unzip segment_3.zip
|
17 |
+
unzip segment_4.zip
|
18 |
+
unzip segment_5.zip
|
19 |
+
unzip segment_6.zip
|
20 |
+
unzip segment_7.zip
|
21 |
+
unzip segment_8.zip
|
22 |
+
unzip segment_9.zip
|
23 |
+
|
merge_datasets.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
dfs = []
|
5 |
+
|
6 |
+
for i in range(10):
|
7 |
+
dfs.append(
|
8 |
+
datasets.ArrowReader.read_table(f'segment_{i}/dataset.arrow').to_pandas()
|
9 |
+
)
|
10 |
+
|
11 |
+
full_df = pd.concat(dfs, ignore_index=True)
|
12 |
+
full_df.to_csv('dataset.csv')
|
prepare_data.sh
ADDED
File without changes
|
train.py
CHANGED
@@ -17,7 +17,6 @@
|
|
17 |
- [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
|
18 |
'''
|
19 |
import logging
|
20 |
-
import math
|
21 |
import os
|
22 |
import sys
|
23 |
import time
|
@@ -31,6 +30,7 @@ from tqdm import tqdm
|
|
31 |
|
32 |
import jax
|
33 |
import jax.numpy as jnp
|
|
|
34 |
import optax
|
35 |
import transformers
|
36 |
from flax import jax_utils, traverse_util
|
@@ -44,7 +44,6 @@ from transformers import (
|
|
44 |
is_tensorboard_available,
|
45 |
)
|
46 |
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
|
47 |
-
from transformers.testing_utils import CaptureLogger
|
48 |
|
49 |
from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
|
50 |
from t5_vae_flax.src.config import T5VaeConfig
|
@@ -113,7 +112,7 @@ class ModelArguments:
|
|
113 |
@dataclass
|
114 |
class DataTrainingArguments:
|
115 |
"""
|
116 |
-
Arguments pertaining to what data we are going to input our model for training
|
117 |
"""
|
118 |
|
119 |
dataset_name: Optional[str] = field(
|
@@ -123,10 +122,6 @@ class DataTrainingArguments:
|
|
123 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
124 |
)
|
125 |
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
126 |
-
validation_file: Optional[str] = field(
|
127 |
-
default=None,
|
128 |
-
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
129 |
-
)
|
130 |
max_train_samples: Optional[int] = field(
|
131 |
default=None,
|
132 |
metadata={
|
@@ -134,21 +129,8 @@ class DataTrainingArguments:
|
|
134 |
"value if set."
|
135 |
},
|
136 |
)
|
137 |
-
max_eval_samples: Optional[int] = field(
|
138 |
-
default=None,
|
139 |
-
metadata={
|
140 |
-
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
141 |
-
"value if set."
|
142 |
-
},
|
143 |
-
)
|
144 |
overwrite_cache: bool = field(
|
145 |
-
default=False, metadata={"help": "Overwrite the cached training
|
146 |
-
)
|
147 |
-
validation_split_percentage: Optional[int] = field(
|
148 |
-
default=5,
|
149 |
-
metadata={
|
150 |
-
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
151 |
-
},
|
152 |
)
|
153 |
block_size: Optional[int] = field(
|
154 |
default=None,
|
@@ -162,7 +144,7 @@ class DataTrainingArguments:
|
|
162 |
default=False, metadata={"help": "Stream the dataset."}
|
163 |
)
|
164 |
overwrite_cache: bool = field(
|
165 |
-
default=False, metadata={"help": "Overwrite the cached training
|
166 |
)
|
167 |
preprocessing_num_workers: Optional[int] = field(
|
168 |
default=None,
|
@@ -170,15 +152,12 @@ class DataTrainingArguments:
|
|
170 |
)
|
171 |
|
172 |
def __post_init__(self):
|
173 |
-
if self.dataset_name is None and self.train_file is None
|
174 |
-
raise ValueError("Need either a dataset name or a training
|
175 |
else:
|
176 |
if self.train_file is not None:
|
177 |
extension = self.train_file.split(".")[-1]
|
178 |
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
179 |
-
if self.validation_file is not None:
|
180 |
-
extension = self.validation_file.split(".")[-1]
|
181 |
-
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
182 |
|
183 |
|
184 |
class TrainState(train_state.TrainState):
|
@@ -188,28 +167,19 @@ class TrainState(train_state.TrainState):
|
|
188 |
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
189 |
|
190 |
|
191 |
-
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int
|
192 |
"""
|
193 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
194 |
Shuffle batches if `shuffle` is `True`.
|
195 |
"""
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
205 |
-
|
206 |
-
for idx in batch_idx:
|
207 |
-
batch = dataset[idx]
|
208 |
-
batch = {k: jnp.array(v) for k, v in batch.items()}
|
209 |
-
|
210 |
-
batch = shard(batch)
|
211 |
-
|
212 |
-
yield batch
|
213 |
|
214 |
|
215 |
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
@@ -222,11 +192,6 @@ def write_train_metric(summary_writer, train_metrics, train_time, step):
|
|
222 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
223 |
|
224 |
|
225 |
-
def write_eval_metric(summary_writer, eval_metrics, step):
|
226 |
-
for metric_name, value in eval_metrics.items():
|
227 |
-
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
228 |
-
|
229 |
-
|
230 |
def create_learning_rate_fn(
|
231 |
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
232 |
) -> Callable[[int], jnp.array]:
|
@@ -284,9 +249,9 @@ def main():
|
|
284 |
transformers.utils.logging.set_verbosity_error()
|
285 |
|
286 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
287 |
-
logger.info(f"Training
|
288 |
|
289 |
-
#
|
290 |
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
291 |
# (the dataset will be downloaded automatically from the datasets Hub).
|
292 |
#
|
@@ -295,35 +260,7 @@ def main():
|
|
295 |
#
|
296 |
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
297 |
# download the dataset.
|
298 |
-
|
299 |
-
# Downloading and loading a dataset from the hub.
|
300 |
-
dataset = load_dataset(
|
301 |
-
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.streaming, keep_in_memory=False
|
302 |
-
)
|
303 |
-
|
304 |
-
if "validation" not in dataset.keys():
|
305 |
-
dataset["validation"] = load_dataset(
|
306 |
-
data_args.dataset_name,
|
307 |
-
data_args.dataset_config_name,
|
308 |
-
split=f"train[:{data_args.validation_split_percentage}%]",
|
309 |
-
cache_dir=model_args.cache_dir,
|
310 |
-
)
|
311 |
-
dataset["train"] = load_dataset(
|
312 |
-
data_args.dataset_name,
|
313 |
-
data_args.dataset_config_name,
|
314 |
-
split=f"train[{data_args.validation_split_percentage}%:]",
|
315 |
-
cache_dir=model_args.cache_dir,
|
316 |
-
)
|
317 |
-
else:
|
318 |
-
data_files = {}
|
319 |
-
if data_args.train_file is not None:
|
320 |
-
data_files["train"] = data_args.train_file
|
321 |
-
if data_args.validation_file is not None:
|
322 |
-
data_files["validation"] = data_args.validation_file
|
323 |
-
extension = data_args.train_file.split(".")[-1]
|
324 |
-
if extension == "txt":
|
325 |
-
extension = "text"
|
326 |
-
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
327 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
328 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
329 |
|
@@ -381,37 +318,6 @@ def main():
|
|
381 |
assert tokenizer.pad_token == '<PAD>'
|
382 |
|
383 |
# Preprocessing the datasets.
|
384 |
-
# First we tokenize all the texts.
|
385 |
-
if training_args.do_train:
|
386 |
-
column_names = dataset["train"].column_names
|
387 |
-
else:
|
388 |
-
column_names = dataset["validation"].column_names
|
389 |
-
text_column_name = "text" if "text" in column_names else column_names[0]
|
390 |
-
|
391 |
-
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
|
392 |
-
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
|
393 |
-
|
394 |
-
def tokenize_function(examples):
|
395 |
-
with CaptureLogger(tok_logger) as cl:
|
396 |
-
output = tokenizer(examples[text_column_name])
|
397 |
-
# clm input could be much much longer than block_size
|
398 |
-
if "Token indices sequence length is longer than the" in cl.out:
|
399 |
-
tok_logger.warning(
|
400 |
-
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
|
401 |
-
)
|
402 |
-
return output
|
403 |
-
|
404 |
-
# remove dataset tasks
|
405 |
-
for k in dataset.keys():
|
406 |
-
dataset[k].info.task_templates = []
|
407 |
-
|
408 |
-
tokenized_datasets = dataset.map(
|
409 |
-
tokenize_function,
|
410 |
-
batched=True,
|
411 |
-
num_proc=data_args.preprocessing_num_workers,
|
412 |
-
remove_columns=column_names,
|
413 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
414 |
-
)
|
415 |
|
416 |
if data_args.block_size > tokenizer.model_max_length:
|
417 |
logger.warning(
|
@@ -422,65 +328,27 @@ def main():
|
|
422 |
|
423 |
pad_token_id, start_token_id = tokenizer.pad_token_id, config.decoder_start_token_id
|
424 |
|
425 |
-
def
|
426 |
-
examples["
|
427 |
-
|
428 |
-
for i, input_ids in enumerate(examples["input_ids"]):
|
429 |
-
if len(input_ids) > block_size:
|
430 |
-
for k in examples.keys():
|
431 |
-
examples[k][i] = examples[k][i][:block_size]
|
432 |
-
elif len(input_ids) < block_size:
|
433 |
-
delta = block_size - len(input_ids)
|
434 |
-
examples['input_ids'][i] = examples['input_ids'][i] + [pad_token_id] * delta
|
435 |
-
examples['attention_mask'][i] = examples['attention_mask'][i] + [0] * delta
|
436 |
-
examples['labels'][i] = examples['labels'][i] + [-100] * delta
|
437 |
-
|
438 |
-
return examples
|
439 |
-
|
440 |
-
logger.info('clip_texts...')
|
441 |
-
clipped_lm_datasets = tokenized_datasets.map(
|
442 |
-
clip_texts,
|
443 |
-
batched=True,
|
444 |
-
num_proc=data_args.preprocessing_num_workers,
|
445 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
446 |
-
)
|
447 |
-
|
448 |
-
def add_decoder_input_ids(examples):
|
449 |
-
arr_input_ids = jnp.array(examples["input_ids"])
|
450 |
-
pad = pad_token_id * jnp.ones((arr_input_ids.shape[0], 1), dtype=jnp.int32)
|
451 |
-
arr_pad_input_ids = jnp.concatenate((arr_input_ids, pad), axis=1)
|
452 |
-
examples['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)
|
453 |
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
|
458 |
-
|
459 |
-
|
|
|
460 |
|
461 |
-
|
|
|
462 |
|
463 |
-
|
464 |
-
lm_datasets = clipped_lm_datasets.map(
|
465 |
-
add_decoder_input_ids,
|
466 |
-
batched=True,
|
467 |
-
num_proc=data_args.preprocessing_num_workers,
|
468 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
469 |
-
)
|
470 |
|
471 |
-
|
472 |
-
if "train" not in tokenized_datasets:
|
473 |
-
raise ValueError("--do_train requires a train dataset")
|
474 |
-
train_dataset = lm_datasets["train"]
|
475 |
-
if data_args.max_train_samples is not None:
|
476 |
-
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
477 |
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
eval_dataset = lm_datasets["validation"]
|
482 |
-
if data_args.max_eval_samples is not None:
|
483 |
-
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
484 |
|
485 |
# Enable tensorboard only on the master node
|
486 |
has_tensorboard = is_tensorboard_available()
|
@@ -507,13 +375,13 @@ def main():
|
|
507 |
# Store some constant
|
508 |
num_epochs = int(training_args.num_train_epochs)
|
509 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
510 |
-
|
511 |
-
steps_per_epoch =
|
512 |
total_train_steps = steps_per_epoch * num_epochs
|
513 |
|
514 |
# Create learning rate schedule
|
515 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
516 |
-
|
517 |
train_batch_size,
|
518 |
training_args.num_train_epochs,
|
519 |
training_args.warmup_steps,
|
@@ -602,26 +470,14 @@ def main():
|
|
602 |
|
603 |
return new_state, metrics
|
604 |
|
605 |
-
#
|
606 |
-
def eval_step(params, rng, batch):
|
607 |
-
labels = batch.pop("labels")
|
608 |
-
logits, latent_codes = model(**batch, params=params, train=False)[:2]
|
609 |
-
loss = loss_fn(logits, labels, latent_codes, rng)
|
610 |
-
|
611 |
-
# summarize metrics
|
612 |
-
metrics = {"loss": loss}
|
613 |
-
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
614 |
-
return metrics
|
615 |
-
|
616 |
-
# Create parallel version of the train and eval step
|
617 |
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
618 |
-
p_eval_step = jax.pmap(eval_step, "batch")
|
619 |
|
620 |
# Replicate the train state on each device
|
621 |
state = state.replicate()
|
622 |
|
623 |
logger.info("***** Running training *****")
|
624 |
-
logger.info(f" Num examples = {
|
625 |
logger.info(f" Num Epochs = {num_epochs}")
|
626 |
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
627 |
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
@@ -638,15 +494,15 @@ def main():
|
|
638 |
rng, input_rng = jax.random.split(rng)
|
639 |
|
640 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
641 |
-
train_loader = data_loader(input_rng, train_dataset, train_batch_size
|
642 |
-
steps_per_epoch =
|
643 |
# train
|
644 |
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
645 |
batch = next(train_loader)
|
646 |
state, train_metric = p_train_step(state, batch)
|
647 |
train_metrics.append(train_metric)
|
648 |
|
649 |
-
cur_step = epoch * (
|
650 |
|
651 |
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
652 |
# Save metrics
|
@@ -661,36 +517,6 @@ def main():
|
|
661 |
|
662 |
train_metrics = []
|
663 |
|
664 |
-
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
665 |
-
# ======================== Evaluating ==============================
|
666 |
-
eval_metrics = []
|
667 |
-
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
668 |
-
eval_steps = len(eval_dataset) // eval_batch_size
|
669 |
-
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
670 |
-
# Model forward
|
671 |
-
batch = next(eval_loader)
|
672 |
-
metrics = p_eval_step(state.params, state.dropout_rng, batch)
|
673 |
-
eval_metrics.append(metrics)
|
674 |
-
|
675 |
-
# normalize eval metrics
|
676 |
-
eval_metrics = get_metrics(eval_metrics)
|
677 |
-
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
678 |
-
|
679 |
-
try:
|
680 |
-
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
681 |
-
except OverflowError:
|
682 |
-
eval_metrics["perplexity"] = float("inf")
|
683 |
-
|
684 |
-
# Print metrics and update progress bar
|
685 |
-
desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
|
686 |
-
epochs.write(desc)
|
687 |
-
epochs.desc = desc
|
688 |
-
|
689 |
-
# Save metrics
|
690 |
-
if has_tensorboard and jax.process_index() == 0:
|
691 |
-
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
692 |
-
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
693 |
-
|
694 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
695 |
# save checkpoint after each epoch and push checkpoint to the hub
|
696 |
if jax.process_index() == 0:
|
|
|
17 |
- [ ] Send the schedule time to the compute_loss method and calculate a coefficient based on that.
|
18 |
'''
|
19 |
import logging
|
|
|
20 |
import os
|
21 |
import sys
|
22 |
import time
|
|
|
30 |
|
31 |
import jax
|
32 |
import jax.numpy as jnp
|
33 |
+
import numpy as onp
|
34 |
import optax
|
35 |
import transformers
|
36 |
from flax import jax_utils, traverse_util
|
|
|
44 |
is_tensorboard_available,
|
45 |
)
|
46 |
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
|
|
|
47 |
|
48 |
from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
|
49 |
from t5_vae_flax.src.config import T5VaeConfig
|
|
|
112 |
@dataclass
|
113 |
class DataTrainingArguments:
|
114 |
"""
|
115 |
+
Arguments pertaining to what data we are going to input our model for training.
|
116 |
"""
|
117 |
|
118 |
dataset_name: Optional[str] = field(
|
|
|
122 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
123 |
)
|
124 |
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
|
|
|
|
|
|
|
|
125 |
max_train_samples: Optional[int] = field(
|
126 |
default=None,
|
127 |
metadata={
|
|
|
129 |
"value if set."
|
130 |
},
|
131 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
overwrite_cache: bool = field(
|
133 |
+
default=False, metadata={"help": "Overwrite the cached training sets"}
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
)
|
135 |
block_size: Optional[int] = field(
|
136 |
default=None,
|
|
|
144 |
default=False, metadata={"help": "Stream the dataset."}
|
145 |
)
|
146 |
overwrite_cache: bool = field(
|
147 |
+
default=False, metadata={"help": "Overwrite the cached training sets"}
|
148 |
)
|
149 |
preprocessing_num_workers: Optional[int] = field(
|
150 |
default=None,
|
|
|
152 |
)
|
153 |
|
154 |
def __post_init__(self):
|
155 |
+
if self.dataset_name is None and self.train_file is None:
|
156 |
+
raise ValueError("Need either a dataset name or a training file.")
|
157 |
else:
|
158 |
if self.train_file is not None:
|
159 |
extension = self.train_file.split(".")[-1]
|
160 |
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
|
|
|
|
|
|
161 |
|
162 |
|
163 |
class TrainState(train_state.TrainState):
|
|
|
167 |
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
168 |
|
169 |
|
170 |
+
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int):
|
171 |
"""
|
172 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
173 |
Shuffle batches if `shuffle` is `True`.
|
174 |
"""
|
175 |
+
batch = []
|
176 |
+
for row in dataset:
|
177 |
+
batch.append(row)
|
178 |
+
if len(batch) >= batch_size:
|
179 |
+
batch = {k: jnp.stack([row[k] for row in batch]) for k in batch[0].keys()}
|
180 |
+
batch = shard(batch)
|
181 |
+
yield batch
|
182 |
+
batch = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
|
185 |
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
|
|
192 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
193 |
|
194 |
|
|
|
|
|
|
|
|
|
|
|
195 |
def create_learning_rate_fn(
|
196 |
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
197 |
) -> Callable[[int], jnp.array]:
|
|
|
249 |
transformers.utils.logging.set_verbosity_error()
|
250 |
|
251 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
252 |
+
logger.info(f"Training parameters {training_args}")
|
253 |
|
254 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training files (see below)
|
255 |
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
256 |
# (the dataset will be downloaded automatically from the datasets Hub).
|
257 |
#
|
|
|
260 |
#
|
261 |
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
262 |
# download the dataset.
|
263 |
+
dataset = load_dataset('text', data_files=[f'wikipedia/{i}.txt' for i in range(298)], cache_dir=model_args.cache_dir, streaming=True)['train']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
265 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
266 |
|
|
|
318 |
assert tokenizer.pad_token == '<PAD>'
|
319 |
|
320 |
# Preprocessing the datasets.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
|
322 |
if data_args.block_size > tokenizer.model_max_length:
|
323 |
logger.warning(
|
|
|
328 |
|
329 |
pad_token_id, start_token_id = tokenizer.pad_token_id, config.decoder_start_token_id
|
330 |
|
331 |
+
def tokenize_function(examples):
|
332 |
+
output = tokenizer(examples["text"], return_tensors='jax', padding='max_length', max_length=block_size, truncation=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
|
334 |
+
output['labels'] = onp.array(output['input_ids'].copy())
|
335 |
+
output['labels'][output['labels'] == pad_token_id] = -100
|
336 |
+
output['labels'] = jnp.array(output['labels'])
|
337 |
|
338 |
+
pad = pad_token_id * jnp.ones((output['input_ids'].shape[0], 1), dtype=jnp.int32)
|
339 |
+
arr_pad_input_ids = jnp.concatenate((output['input_ids'], pad), axis=1)
|
340 |
+
output['decoder_input_ids'] = shift_tokens_right(arr_pad_input_ids, pad_token_id, start_token_id)
|
341 |
|
342 |
+
ones = jnp.ones((output['attention_mask'].shape[0], 1), dtype=jnp.int32)
|
343 |
+
output['decoder_attention_mask'] = jnp.concatenate((ones, output['attention_mask']), axis=1)
|
344 |
|
345 |
+
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
|
347 |
+
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
|
|
|
|
|
|
|
|
|
|
348 |
|
349 |
+
train_dataset = tokenized_datasets
|
350 |
+
if data_args.max_train_samples is not None:
|
351 |
+
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
|
|
|
|
|
|
352 |
|
353 |
# Enable tensorboard only on the master node
|
354 |
has_tensorboard = is_tensorboard_available()
|
|
|
375 |
# Store some constant
|
376 |
num_epochs = int(training_args.num_train_epochs)
|
377 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
378 |
+
train_dataset_len = 97876602
|
379 |
+
steps_per_epoch = train_dataset_len // train_batch_size
|
380 |
total_train_steps = steps_per_epoch * num_epochs
|
381 |
|
382 |
# Create learning rate schedule
|
383 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
384 |
+
train_dataset_len,
|
385 |
train_batch_size,
|
386 |
training_args.num_train_epochs,
|
387 |
training_args.warmup_steps,
|
|
|
470 |
|
471 |
return new_state, metrics
|
472 |
|
473 |
+
# Create parallel version of the train step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
|
|
475 |
|
476 |
# Replicate the train state on each device
|
477 |
state = state.replicate()
|
478 |
|
479 |
logger.info("***** Running training *****")
|
480 |
+
logger.info(f" Num examples = {train_dataset_len}")
|
481 |
logger.info(f" Num Epochs = {num_epochs}")
|
482 |
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
483 |
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
|
|
494 |
rng, input_rng = jax.random.split(rng)
|
495 |
|
496 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
497 |
+
train_loader = data_loader(input_rng, train_dataset, train_batch_size)
|
498 |
+
steps_per_epoch = train_dataset_len // train_batch_size
|
499 |
# train
|
500 |
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
501 |
batch = next(train_loader)
|
502 |
state, train_metric = p_train_step(state, batch)
|
503 |
train_metrics.append(train_metric)
|
504 |
|
505 |
+
cur_step = epoch * (train_dataset_len // train_batch_size) + step
|
506 |
|
507 |
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
508 |
# Save metrics
|
|
|
517 |
|
518 |
train_metrics = []
|
519 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
520 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
521 |
# save checkpoint after each epoch and push checkpoint to the hub
|
522 |
if jax.process_index() == 0:
|
train.sh
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export RUN_NAME=single_latent
|
2 |
+
|
3 |
+
# TODO update to not use tokenizer, instead use gpt2 one
|
4 |
+
./venv/bin/python train.py \
|
5 |
+
--t5_model_name_or_path="t5-base" \
|
6 |
+
--output_dir="output/${RUN_NAME}" \
|
7 |
+
--overwrite_output_dir \
|
8 |
+
--do_train \
|
9 |
+
--n_latent_tokens 1 \
|
10 |
+
--latent_token_size 32 \
|
11 |
+
--save_steps="2000" \
|
12 |
+
--block_size="128" \
|
13 |
+
--per_device_train_batch_size="100" \
|
14 |
+
--train_file="INVALID.txt" \
|
15 |
+
--overwrite_output_dir \
|
16 |
+
--num_train_epochs="1" \
|
17 |
+
|
18 |
+
# 200 batch size, 128 sequence len: ? (breaks)
|
19 |
+
# 100 batch size, 128 sequence len: 252:38:58
|
20 |
+
# 10 batch size, 128 sequence len: 281:32:53
|
21 |
+
|
22 |
+
# Got ~12 hours to train, want 3 saves so one save every 4 hours
|
wiki_sentences.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# unused
|
2 |
+
"""Wikipedia Sentences"""
|
3 |
+
|
4 |
+
from __future__ import absolute_import, division, print_function
|
5 |
+
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
|
9 |
+
import datasets
|
10 |
+
|
11 |
+
|
12 |
+
_DESCRIPTION = """\
|
13 |
+
Dataset of sentences from Wikipedia (from the [Optimus paper](https://arxiv.org/abs/2004.04092)).
|
14 |
+
Each is of mex 64 words & <=256 GPT2 tokens.
|
15 |
+
Each row is a tokenised sentence.
|
16 |
+
{'token_ids': '{gpt2 token ids}'}
|
17 |
+
This is to test the semantics of a Transformer-VAEs latent space by interpolating on sentences.
|
18 |
+
"""
|
19 |
+
|
20 |
+
NUM_SEGMENTS = 5
|
21 |
+
DOWNLOAD_URLS = 'https://drive.google.com/file/d/13NnkYAhwszQxc1C5HHfThnF7c1cjzjAD/view?usp=sharing, https://drive.google.com/file/d/14p6FHip_hGTXC-_7SYaK32BpEhZRDJI4/view?usp=sharing, https://drive.google.com/file/d/1IaRfTFh51Wf_zPtK6tjE6xw-up_Z6EyN/view?usp=sharing, https://drive.google.com/file/d/1KGhV397Xfej56uJ9H10xD7tfLdhWlg4q/view?usp=sharing, https://drive.google.com/file/d/1LfsQ1s9wr1mBG3I1bbvnbyrYmnsrXxZt/view?usp=sharing, https://drive.google.com/file/d/1OctFe_JPR0Ajh77FzWdfeYnWZinKl2sW/view?usp=sharing, https://drive.google.com/file/d/1W-Yi8gHCcT8O5F4TcDHScH7pOb0GQZdu/view?usp=sharing, https://drive.google.com/file/d/1jgHjnpe7Vk1pvRgfnH4S4KiRrpUQyqyp/view?usp=sharing, https://drive.google.com/file/d/1oVst8RG8G2d21DL6q4DwO7aJxE1vA2fc/view?usp=sharing, https://drive.google.com/file/d/1qwckIM8YBbU9bnArB6bAoStY3e9I1kqU/view?usp=sharing'.split(', ')
|
22 |
+
|
23 |
+
|
24 |
+
class WikiSentences(datasets.GeneratorBasedBuilder):
|
25 |
+
"""Sentences from Wikipedia."""
|
26 |
+
|
27 |
+
BUILDER_CONFIGS = [datasets.BuilderConfig(name="main", description="Run through json files one by one.",)]
|
28 |
+
|
29 |
+
def _info(self):
|
30 |
+
return datasets.DatasetInfo(
|
31 |
+
description=_DESCRIPTION,
|
32 |
+
features=datasets.Features(
|
33 |
+
{
|
34 |
+
'token_ids': [datasets.Value("int32")],
|
35 |
+
}
|
36 |
+
),
|
37 |
+
homepage="https://github.com/Fraser-Greenlee/transformer-vae",
|
38 |
+
)
|
39 |
+
|
40 |
+
def _generate_examples(self, filepath):
|
41 |
+
"""Generate examples."""
|
42 |
+
with open(filepath, encoding="utf-8") as json_lines_file:
|
43 |
+
for id_, line in enumerate(json_lines_file):
|
44 |
+
yield id_, json.loads(line)
|
45 |
+
if id_ >= self.config.max_num_samples:
|
46 |
+
break
|