Preparing code for final runs
Browse files- config.py +2 -2
- run_mlm_flax_stream.py +69 -23
- run_stream.sh +5 -6
config.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
#!/usr/bin/env python
|
2 |
from transformers import RobertaConfig
|
3 |
config = RobertaConfig.from_pretrained("roberta-large")
|
4 |
-
config.save_pretrained("./")
|
5 |
|
6 |
config = RobertaConfig.from_pretrained("roberta-base")
|
7 |
-
config.save_pretrained("./
|
|
|
1 |
#!/usr/bin/env python
|
2 |
from transformers import RobertaConfig
|
3 |
config = RobertaConfig.from_pretrained("roberta-large")
|
4 |
+
config.save_pretrained("./configs/large")
|
5 |
|
6 |
config = RobertaConfig.from_pretrained("roberta-base")
|
7 |
+
config.save_pretrained("./configs/base")
|
run_mlm_flax_stream.py
CHANGED
@@ -21,13 +21,16 @@ Here is the full list of checkpoints on the hub that can be fine-tuned by this s
|
|
21 |
https://huggingface.co/models?filter=masked-lm
|
22 |
"""
|
23 |
import logging
|
|
|
24 |
import os
|
|
|
25 |
import sys
|
26 |
import time
|
27 |
from collections import defaultdict
|
28 |
from dataclasses import dataclass, field
|
29 |
|
30 |
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
|
|
31 |
from pathlib import Path
|
32 |
from typing import Dict, List, Optional, Tuple
|
33 |
|
@@ -39,9 +42,10 @@ from tqdm import tqdm
|
|
39 |
import flax
|
40 |
import jax
|
41 |
import jax.numpy as jnp
|
42 |
-
import kenlm
|
43 |
import optax
|
44 |
from flax import jax_utils, traverse_util
|
|
|
45 |
from flax.training import train_state
|
46 |
from flax.training.common_utils import get_metrics, onehot, shard
|
47 |
from transformers import (
|
@@ -334,6 +338,26 @@ def write_eval_metric(summary_writer, eval_metrics, step):
|
|
334 |
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
335 |
|
336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
if __name__ == "__main__":
|
338 |
# See all possible arguments in src/transformers/training_args.py
|
339 |
# or by passing the --help flag to this script.
|
@@ -391,19 +415,31 @@ if __name__ == "__main__":
|
|
391 |
filepaths["train"] = data_args.train_file
|
392 |
if data_args.validation_file:
|
393 |
filepaths["validation"] = data_args.validation_file
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
|
408 |
if model_args.config_name:
|
409 |
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
@@ -662,15 +698,25 @@ if __name__ == "__main__":
|
|
662 |
write_eval_metric(summary_writer, eval_metrics, step)
|
663 |
eval_metrics = []
|
664 |
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
674 |
|
675 |
# update tqdm bar
|
676 |
steps.update(1)
|
|
|
21 |
https://huggingface.co/models?filter=masked-lm
|
22 |
"""
|
23 |
import logging
|
24 |
+
import json
|
25 |
import os
|
26 |
+
import shutil
|
27 |
import sys
|
28 |
import time
|
29 |
from collections import defaultdict
|
30 |
from dataclasses import dataclass, field
|
31 |
|
32 |
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
33 |
+
import joblib
|
34 |
from pathlib import Path
|
35 |
from typing import Dict, List, Optional, Tuple
|
36 |
|
|
|
42 |
import flax
|
43 |
import jax
|
44 |
import jax.numpy as jnp
|
45 |
+
import kenlm # pip install https://github.com/kpu/kenlm/archive/master.zip
|
46 |
import optax
|
47 |
from flax import jax_utils, traverse_util
|
48 |
+
from flax.serialization import from_bytes, to_bytes
|
49 |
from flax.training import train_state
|
50 |
from flax.training.common_utils import get_metrics, onehot, shard
|
51 |
from transformers import (
|
|
|
338 |
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
339 |
|
340 |
|
341 |
+
def save_checkpoint_files(state, data_collator, training_args, save_dir):
|
342 |
+
unreplicated_state = jax_utils.unreplicate(state)
|
343 |
+
with open(os.path.join(save_dir, "optimizer_state.msgpack"), "wb") as f:
|
344 |
+
f.write(to_bytes(unreplicated_state.opt_state))
|
345 |
+
joblib.dump(training_args, os.path.join(save_dir, "training_args.joblib"))
|
346 |
+
joblib.dump(data_collator, os.path.join(save_dir, "data_collator.joblib"))
|
347 |
+
with open(os.path.join(save_dir, "training_state.json"), "w") as f:
|
348 |
+
json.dump({"step": unreplicated_state.step.item()}, f)
|
349 |
+
|
350 |
+
|
351 |
+
def rotate_checkpoints(path, max_checkpoints=5):
|
352 |
+
paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
|
353 |
+
if len(paths) > max_checkpoints:
|
354 |
+
for path_to_delete in paths[max_checkpoints:]:
|
355 |
+
try:
|
356 |
+
shutil.rmtree(path_to_delete)
|
357 |
+
except OSError:
|
358 |
+
os.remove(path_to_delete)
|
359 |
+
|
360 |
+
|
361 |
if __name__ == "__main__":
|
362 |
# See all possible arguments in src/transformers/training_args.py
|
363 |
# or by passing the --help flag to this script.
|
|
|
415 |
filepaths["train"] = data_args.train_file
|
416 |
if data_args.validation_file:
|
417 |
filepaths["validation"] = data_args.validation_file
|
418 |
+
try:
|
419 |
+
dataset = load_dataset(
|
420 |
+
data_args.dataset_name,
|
421 |
+
data_args.dataset_config_name,
|
422 |
+
cache_dir=model_args.cache_dir,
|
423 |
+
streaming=True,
|
424 |
+
split="train",
|
425 |
+
sampling_method=sampling_args.sampling_method,
|
426 |
+
sampling_factor=sampling_args.sampling_factor,
|
427 |
+
boundaries=sampling_args.boundaries,
|
428 |
+
perplexity_model=sampling_args.perplexity_model,
|
429 |
+
seed=training_args.seed,
|
430 |
+
data_files=filepaths,
|
431 |
+
)
|
432 |
+
except Exception as exc:
|
433 |
+
logger.warning(
|
434 |
+
f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
|
435 |
+
)
|
436 |
+
dataset = load_dataset(
|
437 |
+
data_args.dataset_name,
|
438 |
+
data_args.dataset_config_name,
|
439 |
+
cache_dir=model_args.cache_dir,
|
440 |
+
streaming=True,
|
441 |
+
split="train",
|
442 |
+
)
|
443 |
|
444 |
if model_args.config_name:
|
445 |
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
|
|
698 |
write_eval_metric(summary_writer, eval_metrics, step)
|
699 |
eval_metrics = []
|
700 |
|
701 |
+
# save checkpoint after eval_steps
|
702 |
+
if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
|
703 |
+
print(f"Saving checkpoint at {step + 1} steps")
|
704 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
705 |
+
model.save_pretrained(
|
706 |
+
training_args.output_dir,
|
707 |
+
params=params,
|
708 |
+
push_to_hub=training_args.push_to_hub,
|
709 |
+
commit_message=f"Saving weights and logs of step {step + 1}",
|
710 |
+
)
|
711 |
+
save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
|
712 |
+
checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step + 1}"
|
713 |
+
checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
714 |
+
model.save_pretrained(checkpoints_dir, params=params,)
|
715 |
+
save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
|
716 |
+
rotate_checkpoints(
|
717 |
+
Path(training_args.output_dir) / "checkpoints",
|
718 |
+
max_checkpoints=training_args.save_total_limit
|
719 |
+
)
|
720 |
|
721 |
# update tqdm bar
|
722 |
steps.update(1)
|
run_stream.sh
CHANGED
@@ -4,9 +4,10 @@ python ./run_mlm_flax_stream.py \
|
|
4 |
--output_dir="./outputs" \
|
5 |
--model_type="roberta" \
|
6 |
--config_name="./configs/base" \
|
7 |
-
--tokenizer_name="./" \
|
8 |
--dataset_name="./mc4" \
|
9 |
--dataset_config_name="es" \
|
|
|
10 |
--max_seq_length="128" \
|
11 |
--pad_to_max_length \
|
12 |
--per_device_train_batch_size="256" \
|
@@ -16,13 +17,11 @@ python ./run_mlm_flax_stream.py \
|
|
16 |
--adam_epsilon="1e-6" \
|
17 |
--learning_rate="6e-4" \
|
18 |
--weight_decay="0.01" \
|
19 |
-
--
|
20 |
-
--save_steps="1000" \
|
21 |
--save_total_limit="5" \
|
22 |
--warmup_steps="24000" \
|
23 |
--overwrite_output_dir \
|
24 |
-
--num_train_steps="
|
25 |
-
--eval_steps="
|
26 |
--dtype="bfloat16" \
|
27 |
-
--sampling_method="steps" \
|
28 |
--logging_steps="500" 2>&1 | tee run_stream.log
|
|
|
4 |
--output_dir="./outputs" \
|
5 |
--model_type="roberta" \
|
6 |
--config_name="./configs/base" \
|
7 |
+
--tokenizer_name="./configs/base" \
|
8 |
--dataset_name="./mc4" \
|
9 |
--dataset_config_name="es" \
|
10 |
+
--train_file="path/to/mc4-es-train-50M-XXX.jsonl" \
|
11 |
--max_seq_length="128" \
|
12 |
--pad_to_max_length \
|
13 |
--per_device_train_batch_size="256" \
|
|
|
17 |
--adam_epsilon="1e-6" \
|
18 |
--learning_rate="6e-4" \
|
19 |
--weight_decay="0.01" \
|
20 |
+
--save_steps="10000" \
|
|
|
21 |
--save_total_limit="5" \
|
22 |
--warmup_steps="24000" \
|
23 |
--overwrite_output_dir \
|
24 |
+
--num_train_steps="250000" \
|
25 |
+
--eval_steps="10000" \
|
26 |
--dtype="bfloat16" \
|
|
|
27 |
--logging_steps="500" 2>&1 | tee run_stream.log
|