versae commited on
Commit
ea0132b
1 Parent(s): 346a10a

Preparing code for final runs

Browse files
Files changed (3) hide show
  1. config.py +2 -2
  2. run_mlm_flax_stream.py +69 -23
  3. run_stream.sh +5 -6
config.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python
2
  from transformers import RobertaConfig
3
  config = RobertaConfig.from_pretrained("roberta-large")
4
- config.save_pretrained("./")
5
 
6
  config = RobertaConfig.from_pretrained("roberta-base")
7
- config.save_pretrained("./config-base.json")
 
1
  #!/usr/bin/env python
2
  from transformers import RobertaConfig
3
  config = RobertaConfig.from_pretrained("roberta-large")
4
+ config.save_pretrained("./configs/large")
5
 
6
  config = RobertaConfig.from_pretrained("roberta-base")
7
+ config.save_pretrained("./configs/base")
run_mlm_flax_stream.py CHANGED
@@ -21,13 +21,16 @@ Here is the full list of checkpoints on the hub that can be fine-tuned by this s
21
  https://huggingface.co/models?filter=masked-lm
22
  """
23
  import logging
 
24
  import os
 
25
  import sys
26
  import time
27
  from collections import defaultdict
28
  from dataclasses import dataclass, field
29
 
30
  # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 
31
  from pathlib import Path
32
  from typing import Dict, List, Optional, Tuple
33
 
@@ -39,9 +42,10 @@ from tqdm import tqdm
39
  import flax
40
  import jax
41
  import jax.numpy as jnp
42
- import kenlm
43
  import optax
44
  from flax import jax_utils, traverse_util
 
45
  from flax.training import train_state
46
  from flax.training.common_utils import get_metrics, onehot, shard
47
  from transformers import (
@@ -334,6 +338,26 @@ def write_eval_metric(summary_writer, eval_metrics, step):
334
  summary_writer.scalar(f"eval_{metric_name}", value, step)
335
 
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  if __name__ == "__main__":
338
  # See all possible arguments in src/transformers/training_args.py
339
  # or by passing the --help flag to this script.
@@ -391,19 +415,31 @@ if __name__ == "__main__":
391
  filepaths["train"] = data_args.train_file
392
  if data_args.validation_file:
393
  filepaths["validation"] = data_args.validation_file
394
- dataset = load_dataset(
395
- data_args.dataset_name,
396
- data_args.dataset_config_name,
397
- cache_dir=model_args.cache_dir,
398
- streaming=True,
399
- split="train",
400
- sampling_method=sampling_args.sampling_method,
401
- sampling_factor=sampling_args.sampling_factor,
402
- boundaries=sampling_args.boundaries,
403
- perplexity_model=sampling_args.perplexity_model,
404
- seed=training_args.seed,
405
- data_files=filepaths,
406
- )
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
  if model_args.config_name:
409
  config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
@@ -662,15 +698,25 @@ if __name__ == "__main__":
662
  write_eval_metric(summary_writer, eval_metrics, step)
663
  eval_metrics = []
664
 
665
- # save checkpoint after each epoch and push checkpoint to the hub
666
- if jax.process_index() == 0:
667
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
668
- model.save_pretrained(
669
- training_args.output_dir,
670
- params=params,
671
- push_to_hub=training_args.push_to_hub,
672
- commit_message=f"Saving weights and logs of step {step+1}",
673
- )
 
 
 
 
 
 
 
 
 
 
674
 
675
  # update tqdm bar
676
  steps.update(1)
 
21
  https://huggingface.co/models?filter=masked-lm
22
  """
23
  import logging
24
+ import json
25
  import os
26
+ import shutil
27
  import sys
28
  import time
29
  from collections import defaultdict
30
  from dataclasses import dataclass, field
31
 
32
  # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
33
+ import joblib
34
  from pathlib import Path
35
  from typing import Dict, List, Optional, Tuple
36
 
 
42
  import flax
43
  import jax
44
  import jax.numpy as jnp
45
+ import kenlm # pip install https://github.com/kpu/kenlm/archive/master.zip
46
  import optax
47
  from flax import jax_utils, traverse_util
48
+ from flax.serialization import from_bytes, to_bytes
49
  from flax.training import train_state
50
  from flax.training.common_utils import get_metrics, onehot, shard
51
  from transformers import (
 
338
  summary_writer.scalar(f"eval_{metric_name}", value, step)
339
 
340
 
341
+ def save_checkpoint_files(state, data_collator, training_args, save_dir):
342
+ unreplicated_state = jax_utils.unreplicate(state)
343
+ with open(os.path.join(save_dir, "optimizer_state.msgpack"), "wb") as f:
344
+ f.write(to_bytes(unreplicated_state.opt_state))
345
+ joblib.dump(training_args, os.path.join(save_dir, "training_args.joblib"))
346
+ joblib.dump(data_collator, os.path.join(save_dir, "data_collator.joblib"))
347
+ with open(os.path.join(save_dir, "training_state.json"), "w") as f:
348
+ json.dump({"step": unreplicated_state.step.item()}, f)
349
+
350
+
351
+ def rotate_checkpoints(path, max_checkpoints=5):
352
+ paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
353
+ if len(paths) > max_checkpoints:
354
+ for path_to_delete in paths[max_checkpoints:]:
355
+ try:
356
+ shutil.rmtree(path_to_delete)
357
+ except OSError:
358
+ os.remove(path_to_delete)
359
+
360
+
361
  if __name__ == "__main__":
362
  # See all possible arguments in src/transformers/training_args.py
363
  # or by passing the --help flag to this script.
 
415
  filepaths["train"] = data_args.train_file
416
  if data_args.validation_file:
417
  filepaths["validation"] = data_args.validation_file
418
+ try:
419
+ dataset = load_dataset(
420
+ data_args.dataset_name,
421
+ data_args.dataset_config_name,
422
+ cache_dir=model_args.cache_dir,
423
+ streaming=True,
424
+ split="train",
425
+ sampling_method=sampling_args.sampling_method,
426
+ sampling_factor=sampling_args.sampling_factor,
427
+ boundaries=sampling_args.boundaries,
428
+ perplexity_model=sampling_args.perplexity_model,
429
+ seed=training_args.seed,
430
+ data_files=filepaths,
431
+ )
432
+ except Exception as exc:
433
+ logger.warning(
434
+ f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
435
+ )
436
+ dataset = load_dataset(
437
+ data_args.dataset_name,
438
+ data_args.dataset_config_name,
439
+ cache_dir=model_args.cache_dir,
440
+ streaming=True,
441
+ split="train",
442
+ )
443
 
444
  if model_args.config_name:
445
  config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
 
698
  write_eval_metric(summary_writer, eval_metrics, step)
699
  eval_metrics = []
700
 
701
+ # save checkpoint after eval_steps
702
+ if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
703
+ print(f"Saving checkpoint at {step + 1} steps")
704
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
705
+ model.save_pretrained(
706
+ training_args.output_dir,
707
+ params=params,
708
+ push_to_hub=training_args.push_to_hub,
709
+ commit_message=f"Saving weights and logs of step {step + 1}",
710
+ )
711
+ save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
712
+ checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step + 1}"
713
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
714
+ model.save_pretrained(checkpoints_dir, params=params,)
715
+ save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
716
+ rotate_checkpoints(
717
+ Path(training_args.output_dir) / "checkpoints",
718
+ max_checkpoints=training_args.save_total_limit
719
+ )
720
 
721
  # update tqdm bar
722
  steps.update(1)
run_stream.sh CHANGED
@@ -4,9 +4,10 @@ python ./run_mlm_flax_stream.py \
4
  --output_dir="./outputs" \
5
  --model_type="roberta" \
6
  --config_name="./configs/base" \
7
- --tokenizer_name="./" \
8
  --dataset_name="./mc4" \
9
  --dataset_config_name="es" \
 
10
  --max_seq_length="128" \
11
  --pad_to_max_length \
12
  --per_device_train_batch_size="256" \
@@ -16,13 +17,11 @@ python ./run_mlm_flax_stream.py \
16
  --adam_epsilon="1e-6" \
17
  --learning_rate="6e-4" \
18
  --weight_decay="0.01" \
19
- --save_strategy="steps" \
20
- --save_steps="1000" \
21
  --save_total_limit="5" \
22
  --warmup_steps="24000" \
23
  --overwrite_output_dir \
24
- --num_train_steps="500000" \
25
- --eval_steps="1000" \
26
  --dtype="bfloat16" \
27
- --sampling_method="steps" \
28
  --logging_steps="500" 2>&1 | tee run_stream.log
 
4
  --output_dir="./outputs" \
5
  --model_type="roberta" \
6
  --config_name="./configs/base" \
7
+ --tokenizer_name="./configs/base" \
8
  --dataset_name="./mc4" \
9
  --dataset_config_name="es" \
10
+ --train_file="path/to/mc4-es-train-50M-XXX.jsonl" \
11
  --max_seq_length="128" \
12
  --pad_to_max_length \
13
  --per_device_train_batch_size="256" \
 
17
  --adam_epsilon="1e-6" \
18
  --learning_rate="6e-4" \
19
  --weight_decay="0.01" \
20
+ --save_steps="10000" \
 
21
  --save_total_limit="5" \
22
  --warmup_steps="24000" \
23
  --overwrite_output_dir \
24
+ --num_train_steps="250000" \
25
+ --eval_steps="10000" \
26
  --dtype="bfloat16" \
 
27
  --logging_steps="500" 2>&1 | tee run_stream.log