versae committed on
Commit
853cd83
1 Parent(s): a1f93c9

Scripts for perplexity sampling and fixes

Files changed (7)
  1. config.py +3 -0
  2. convert.py +10 -6
  3. run_mlm_flax.py +60 -60
  4. run_mlm_flax_stream.py +719 -0
  5. run_stream.sh +27 -0
  6. test_script.py +0 -45
  7. tokens.py +2 -2
config.py CHANGED
@@ -2,3 +2,6 @@
 from transformers import RobertaConfig
 config = RobertaConfig.from_pretrained("roberta-large")
 config.save_pretrained("./")
+
+config = RobertaConfig.from_pretrained("roberta-base")
+config.save_pretrained("./config-base.json")
convert.py CHANGED
@@ -1,8 +1,13 @@
-from transformers.modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
-from transformers import RobertaConfig, RobertaModel
+import jax
+from jax import numpy as jnp
+from transformers import FlaxRobertaForMaskedLM, RobertaForMaskedLM

+def to_f32(t):
+    return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)

-config = RobertaConfig.from_pretrained("./")
-model = RobertaModel(config)
-load_flax_checkpoint_in_pytorch_model(model, "./flax_model.msgpack")
-model.save_pretrained("./")
+flax_model = FlaxRobertaForMaskedLM.from_pretrained("./")
+flax_model.params = to_f32(flax_model.params)
+flax_model.save_pretrained("./")
+
+model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
+model.save_pretrained("./", save_config=False)
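A quick sanity check after running convert.py is to reload the upcast checkpoint and confirm that no bfloat16 leaves remain. This is an illustrative sketch, not part of the commit, and it assumes the converted files live in the current directory:

# Illustrative only: verify the float32 upcast performed by convert.py.
import jax
import jax.numpy as jnp
from transformers import FlaxRobertaForMaskedLM

reloaded = FlaxRobertaForMaskedLM.from_pretrained("./")
leaves = jax.tree_util.tree_leaves(reloaded.params)
assert not any(leaf.dtype == jnp.bfloat16 for leaf in leaves)
print(f"{len(leaves)} parameter arrays checked, none stored in bfloat16")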
run_mlm_flax.py CHANGED
@@ -110,9 +110,6 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
-    dataset_streaming: bool = field(
-        default=False, metadata={"help": "Whether dataset_name should be retrieved using streaming if available."}
-    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -322,7 +319,7 @@ if __name__ == "__main__":
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.dataset_streaming)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
@@ -330,14 +327,12 @@
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
-                streaming=data_args.dataset_streaming,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
-                streaming=data_args.dataset_streaming,
             )
     else:
         data_files = {}
@@ -456,6 +451,7 @@
         num_proc=data_args.preprocessing_num_workers,
         load_from_cache_file=not data_args.overwrite_cache,
     )
+
     # Enable tensorboard only on the master node
     has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
@@ -483,6 +479,7 @@
             "Please run pip install tensorboard to enable."
         )

+
     # Data collator
     # This one will take care of randomly masking the tokens.
     data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
@@ -491,7 +488,14 @@
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
@@ -526,17 +530,24 @@
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )

     # Setup train state
-    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
@@ -634,54 +645,43 @@

                 train_metrics = []

-            if training_args.save_strategy == "steps" and cur_step and cur_step % training_args.save_steps == 0:
+            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+                # ======================== Evaluating ==============================
+                num_eval_samples = len(tokenized_datasets["validation"])
+                eval_samples_idx = jnp.arange(num_eval_samples)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+                eval_metrics = []
+                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+                    model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+                    # Model forward
+                    model_inputs = shard(model_inputs.data)
+                    metrics = p_eval_step(state.params, model_inputs)
+                    eval_metrics.append(metrics)
+
+                # normalize eval metrics
+                eval_metrics = get_metrics(eval_metrics)
+                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                eval_normalizer = eval_metrics.pop("normalizer")
+                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+                # Update progress bar
+                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+                # Save metrics
+                if has_tensorboard and jax.process_index() == 0:
+                    cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
+                    write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+            if cur_step % training_args.save_steps == 0 and cur_step > 0:
+                # save checkpoint after each epoch and push checkpoint to the hub
                 if jax.process_index() == 0:
                     params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
                     model.save_pretrained(
-                        Path(str(training_args.output_dir)) / "checkpoints" / f"checkpoint-{cur_step}",
+                        training_args.output_dir,
                         params=params,
                         push_to_hub=training_args.push_to_hub,
-                        temp_dir=True,
                         commit_message=f"Saving weights and logs of step {cur_step}",
                     )
-
-        # ======================== Evaluating ==============================
-        num_eval_samples = len(tokenized_datasets["validation"])
-        eval_samples_idx = jnp.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-
-        eval_metrics = []
-        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
-            model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
-            # Model forward
-            model_inputs = shard(model_inputs.data)
-            metrics = p_eval_step(state.params, model_inputs)
-            eval_metrics.append(metrics)
-
-        # normalize eval metrics
-        eval_metrics = get_metrics(eval_metrics)
-        eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
-        eval_normalizer = eval_metrics.pop("normalizer")
-        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
-
-        # Update progress bar
-        epochs.desc = (
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
-        )
-
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_eval_metric(summary_writer, eval_metrics, cur_step)
-
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs of epoch {epoch+1}",
-            )
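The single `tx=optimizer` argument above works for both branches because `optax.adafactor` and `optax.adamw` each return an `optax.GradientTransformation`. A minimal standalone sketch of the same pattern, not taken from the repo (the helper name is hypothetical; the hyperparameter values mirror run_stream.sh):

# Illustrative only: either optimizer plugs into train_state.TrainState.create(..., tx=...).
import optax

def make_tx(use_adafactor: bool, schedule) -> optax.GradientTransformation:
    if use_adafactor:
        # Adafactor keeps factored second-moment statistics, which saves accelerator memory.
        return optax.adafactor(learning_rate=schedule)
    return optax.adamw(learning_rate=schedule, b1=0.9, b2=0.98, eps=1e-6, weight_decay=0.01)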
run_mlm_flax_stream.py ADDED
@@ -0,0 +1,719 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
+text file or a dataset.
+
+Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
+https://huggingface.co/models?filter=masked-lm
+"""
+import logging
+import os
+import sys
+import time
+from collections import defaultdict
+from dataclasses import dataclass, field
+
+# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import datasets
+import numpy as np
+from datasets import load_dataset
+from tqdm import tqdm
+
+import flax
+import jax
+import jax.numpy as jnp
+import kenlm
+import optax
+from flax import jax_utils, traverse_util
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, onehot, shard
+from transformers import (
+    CONFIG_MAPPING,
+    FLAX_MODEL_FOR_MASKED_LM_MAPPING,
+    AutoConfig,
+    AutoTokenizer,
+    FlaxAutoModelForMaskedLM,
+    HfArgumentParser,
+    PreTrainedTokenizerBase,
+    TensorType,
+    TrainingArguments,
+    is_tensorboard_available,
+    set_seed,
+)
+
+
+if datasets.__version__ <= "1.8.0":
+    raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
+
+
+MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The model checkpoint for weights initialization."
+            "Don't set if you want to train a model from scratch."
+        },
+    )
+    model_type: Optional[str] = field(
+        default=None,
+        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+    )
+    dtype: Optional[str] = field(
+        default="float32",
+        metadata={
+            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+        },
+    )
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+    validation_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+    )
+    train_ref_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
+    )
+    validation_ref_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "The percentage of the train set used as validation set in case there's no validation split"
+        },
+    )
+    max_seq_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated. Default to the max input length of the model."
+        },
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+    mlm_probability: float = field(
+        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
+    )
+    pad_to_max_length: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to pad all samples to `max_seq_length`. "
+            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+        },
+    )
+    line_by_line: bool = field(
+        default=False,
+        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
+    )
+    text_column_name: str = field(
+        default="text", metadata={"help": "The name of the column to retrieve the training text."}
+    )
+    shuffle_buffer_size: int = field(
+        default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
+    )
+    num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
+    num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
+
+    def __post_init__(self):
+        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+
+
+@flax.struct.dataclass
+class FlaxDataCollatorForLanguageModeling:
+    """
+    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
+    are not all of the same length.
+
+    Args:
+        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+            The tokenizer used for encoding the data.
+        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
+            The probability with which to (randomly) mask tokens in the input.
+
+    .. note::
+
+        For best performance, this data collator should be used with a dataset having items that are dictionaries or
+        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
+        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
+        argument :obj:`return_special_tokens_mask=True`.
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    mlm_probability: float = 0.15
+
+    def __post_init__(self):
+        if self.tokenizer.mask_token is None:
+            raise ValueError(
+                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
+                "You should pass `mlm=False` to train on causal language modeling instead."
+            )
+
+    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
+        # Handle dict or lists with proper padding and conversion to tensor.
+        batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
+
+        # If special token mask has been preprocessed, pop it from the dict.
+        special_tokens_mask = batch.pop("special_tokens_mask", None)
+
+        batch["input_ids"], batch["labels"] = self.mask_tokens(
+            batch["input_ids"], special_tokens_mask=special_tokens_mask
+        )
+        return batch
+
+    def mask_tokens(
+        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        """
+        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
+        """
+        labels = inputs.copy()
+        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
+        probability_matrix = np.full(labels.shape, self.mlm_probability)
+        special_tokens_mask = special_tokens_mask.astype("bool")
+
+        probability_matrix[special_tokens_mask] = 0.0
+        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
+        labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
+        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
+
+        # 10% of the time, we replace masked input tokens with random word
+        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
+        indices_random &= masked_indices & ~indices_replaced
+
+        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
+        inputs[indices_random] = random_words[indices_random]
+
+        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        return inputs, labels
+
+
+
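To make the 80/10/10 rule in mask_tokens concrete, here is a toy run of the collator on a single sentence. It is an illustrative sketch, not part of the commit; it assumes the FlaxDataCollatorForLanguageModeling class above is in scope, and the roberta-base tokenizer is only an example choice:

# Illustrative only: exercise the masking logic of the collator defined above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # example tokenizer
encoding = tokenizer(
    ["BERTIN is a RoBERTa model for Spanish."],
    return_special_tokens_mask=True,
    return_tensors="np",
)
collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
masked_ids, labels = collator.mask_tokens(
    encoding["input_ids"].copy(), special_tokens_mask=encoding["special_tokens_mask"]
)
# Positions with labels != -100 were selected; roughly 80% of them become <mask>,
# 10% a random token, and 10% keep the original token.
print(tokenizer.decode(masked_ids[0]))
print(int((labels != -100).sum()), "tokens enter the MLM loss")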
+@dataclass
+class SamplingArguments:
+    """
+    Arguments pertaining to how to perform sampling of the dataset.
+    """
+
+    perplexity_model: Optional[str] = field(
+        default="es.arpa.bin", metadata={"help": "kenlm model to use to get perplexity values."}
+    )
+    sampling_method: Optional[str] = field(
+        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document."}
+    )
+    sampling_factor: Optional[int] = field(
+        default=1, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
+    )
+    quartiles: Optional[str] = field(
+        default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
+    )
+
+    def __post_init__(self):
+        self.quartiles = [float(q) for q in self.quartiles.split(",")]
+
+
+def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+    num_samples = len(samples_idx)
+    samples_to_remove = num_samples % batch_size
+
+    if samples_to_remove != 0:
+        samples_idx = samples_idx[:-samples_to_remove]
+    sections_split = num_samples // batch_size
+    batch_idx = np.split(samples_idx, sections_split)
+    return batch_idx
+
+
+def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
+    """
+    The training iterator is advanced so that after groupifying the samples,
+    `num_samples` of length `max_seq_length` are returned.
+    """
+    num_total_tokens = max_seq_length * num_samples
+    samples = defaultdict(list)
+
+    i = 0
+    while i < num_total_tokens:
+        tokenized_samples = next(train_iterator)
+        i += len(tokenized_samples["input_ids"])
+
+        # concatenate tokenized samples to list
+        samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
+
+    # Concatenated tokens are split to lists of length `max_seq_length`.
+    # Note that the remainder of % max_seq_length is thrown away.
+    def group_texts(examples):
+        result = {
+            k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
+            for k, t in examples.items()
+        }
+        return result
+
+    grouped_samples = group_texts(samples)
+    return grouped_samples
+
+
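As an illustration of advance_iter_and_group_samples: it keeps drawing already-tokenized examples from the stream and re-chunks the concatenated token lists into sequences of exactly max_seq_length, dropping the remainder. A toy run with a fake stream, not part of the commit and assuming the function above is in scope:

# Illustrative only: regroup a fake token stream into fixed-length sequences.
fake_example = {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1]}
fake_stream = iter([dict(fake_example) for _ in range(100)])

grouped = advance_iter_and_group_samples(fake_stream, num_samples=4, max_seq_length=8)
print(len(grouped["input_ids"]))     # -> 4 sequences
print(len(grouped["input_ids"][0]))  # -> each exactly 8 tokens long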
+def write_train_metric(summary_writer, train_metrics, train_time, step):
+    summary_writer.scalar("train_time", train_time, step)
+
+    train_metrics = get_metrics(train_metrics)
+    for key, vals in train_metrics.items():
+        tag = f"train_{key}"
+        for i, val in enumerate(vals):
+            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+
+def write_eval_metric(summary_writer, eval_metrics, step):
+    for metric_name, value in eval_metrics.items():
+        summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+if __name__ == "__main__":
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, SamplingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args, sampling_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args, sampling_args = parser.parse_args_into_dataclasses()
+
+    if (
+        os.path.exists(training_args.output_dir)
+        and os.listdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        raise ValueError(
+            f"Output directory ({training_args.output_dir}) already exists and is not empty."
+            "Use --overwrite_output_dir to overcome."
+        )
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        level="INFO",
+        datefmt="[%X]",
+    )
+
+    # Log on each process the small summary:
+    logger = logging.getLogger(__name__)
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+
+    # Set the verbosity to info of the Transformers logger (on main process only):
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+    # 'text' is found. You can easily tweak this behavior (see below).
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        dataset = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
+            streaming=True,
+            split="train",
+        )
+
+    if model_args.config_name:
+        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+    elif model_args.model_name_or_path:
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+    else:
+        config = CONFIG_MAPPING[model_args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+
+    if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    elif model_args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    # Loading 5-gram model
+    # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
+    if sampling_args.sampling_method:
+        pp_model = kenlm.Model(sampling_args.perplexity_model)
+
+        def get_perplexity(doc):
+            doc_log_score, doc_length = 0, 0
+            for line in doc.split("\n"):
+                log_score = pp_model.score(line)
+                length = len(line.split()) + 1
+                doc_log_score += log_score
+                doc_length += length
+            return 10.0 ** (-doc_log_score / doc_length)
+
+        def should_keep_doc_step(doc, factor=1, boundaries=None):
+            perplexity = get_perplexity(doc)
+            if boundaries is None:
+                boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
+            if perplexity <= boundaries[0]:
+                quartile_range = boundaries[0]
+            elif boundaries[0] < perplexity < boundaries[1]:
+                quartile_range = boundaries[1] - boundaries[0]
+            elif boundaries[1] < perplexity < boundaries[2]:
+                quartile_range = boundaries[2] - boundaries[1]
+            elif perplexity >= boundaries[2]:
+                quartile_range = 100 * boundaries[2]
+            probability = factor / quartile_range
+            return np.random.uniform() < probability
+
+        def should_keep_doc_gaussian(doc, factor=0.4, boundaries=None):
+            perplexity = get_perplexity(doc)
+            if boundaries is not None:
+                m = boundaries[1]
+            else:
+                m = 662247.50212365
+            weighted_perplexity = factor * np.exp(-9 / 2 * ((perplexity - m) / m) ** 2)
+            return np.random.uniform() < weighted_perplexity
+
+        if sampling_args.sampling_method == "gaussian":
+            should_keep_doc = should_keep_doc_gaussian
+        else:
+            should_keep_doc = should_keep_doc_step
+
+        def tokenize_function(examples):
+            return tokenizer([
+                example for example in examples[data_args.text_column_name]
+                if should_keep_doc(
+                    example,
+                    factor=sampling_args.sampling_factor,
+                    boundaries=sampling_args.quartiles,
+                )
+            ], return_special_tokens_mask=True)
+    else:
+        # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
+        # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
+        # efficient when it receives the `special_tokens_mask`.
+        def tokenize_function(examples):
+            return tokenizer(
+                examples[data_args.text_column_name],
+                return_special_tokens_mask=True
+            )
+
+    tokenized_datasets = dataset.map(
+        tokenize_function,
+        batched=True,
+    )
+
+    shuffle_seed = training_args.seed
+    tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
+
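To see what the two sampling rules above do, here is a standalone sketch, not part of the commit: the step rule keeps a document with a constant probability per perplexity quartile (scaled by sampling_factor), while the Gaussian rule concentrates the keep-probability around the median perplexity. The function names below are hypothetical and the quartile boundaries are the defaults from SamplingArguments:

# Illustrative only: acceptance probabilities of the step and Gaussian sampling rules.
import numpy as np

QUARTILES = [536394.99320948, 662247.50212365, 919250.87225178]  # defaults from SamplingArguments

def step_keep_probability(perplexity, factor=1, boundaries=QUARTILES):
    # Constant probability inside each quartile: factor divided by the quartile's width.
    if perplexity <= boundaries[0]:
        quartile_range = boundaries[0]
    elif perplexity < boundaries[1]:
        quartile_range = boundaries[1] - boundaries[0]
    elif perplexity < boundaries[2]:
        quartile_range = boundaries[2] - boundaries[1]
    else:
        quartile_range = 100 * boundaries[2]
    return factor / quartile_range

def gaussian_keep_probability(perplexity, factor=0.4, median=QUARTILES[1]):
    # Bell-shaped weighting centred on the median perplexity.
    return factor * np.exp(-9 / 2 * ((perplexity - median) / median) ** 2)

for pp in (4e5, 6.6e5, 9e5, 2e6):
    print(f"perplexity={pp:9.0f}  step={step_keep_probability(pp):.2e}  gaussian={gaussian_keep_probability(pp):.2e}")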
+    # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
+    if has_tensorboard and jax.process_index() == 0:
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+            # Enable Weight&Biases
+            import wandb
+            wandb.init(
+                entity='wandb',
+                project='hf-flax-bertin-roberta-es',
+                sync_tensorboard=True,
+            )
+            wandb.config.update(training_args)
+            wandb.config.update(model_args)
+            wandb.config.update(data_args)
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )
+
+    # Data collator
+    # This one will take care of randomly masking the tokens.
+    data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+
+    # Initialize our training
+    rng = jax.random.PRNGKey(training_args.seed)
+    dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+
+    # Store some constant
+    num_epochs = int(training_args.num_train_epochs)
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+
+    # define number steps per stream epoch
+    num_train_steps = data_args.num_train_steps
+
+    # Create learning rate schedule
+    warmup_fn = optax.linear_schedule(
+        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
+    )
+    decay_fn = optax.linear_schedule(
+        init_value=training_args.learning_rate,
+        end_value=0,
+        transition_steps=num_train_steps - training_args.warmup_steps,
+    )
+    linear_decay_lr_schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+    )
+
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
+    # Note that this mask is specifically adapted for FlaxBERT-like models.
+    # For other models, one should correct the layer norm parameter naming
+    # accordingly.
+    def decay_mask_fn(params):
+        flat_params = traverse_util.flatten_dict(params)
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        return traverse_util.unflatten_dict(flat_mask)
+
+    # create adam optimizer
+    adamw = optax.adamw(
+        learning_rate=linear_decay_lr_schedule_fn,
+        b1=training_args.adam_beta1,
+        b2=training_args.adam_beta2,
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        mask=decay_mask_fn,
+    )
+
+    # Setup train state
+    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+
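For reference, the joined schedule built above rises linearly from 0 to the peak learning rate over warmup_steps and then decays linearly back to 0 over the remaining steps. A standalone sketch, not part of the commit, evaluated with the values run_stream.sh passes (6e-4 peak, 24k warm-up, 500k total steps):

# Illustrative only: sample the warmup + linear-decay schedule at a few steps.
import optax

peak_lr, warmup_steps, total_steps = 6e-4, 24_000, 500_000
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=peak_lr, end_value=0.0, transition_steps=total_steps - warmup_steps)
schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

for step in (0, 12_000, 24_000, 262_000, 500_000):
    print(step, float(schedule(step)))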
+    # Define gradient update step fn
+    def train_step(state, batch, dropout_rng):
+        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+
+        def loss_fn(params):
+            labels = batch.pop("labels")
+
+            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+
+            # compute loss, ignore padded input tokens
+            label_mask = jnp.where(labels > 0, 1.0, 0.0)
+            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+            # take average
+            loss = loss.sum() / label_mask.sum()
+
+            return loss
+
+        grad_fn = jax.value_and_grad(loss_fn)
+        loss, grad = grad_fn(state.params)
+        grad = jax.lax.pmean(grad, "batch")
+        new_state = state.apply_gradients(grads=grad)
+
+        metrics = jax.lax.pmean(
+            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
+        )
+
+        return new_state, metrics, new_dropout_rng
+
+    # Create parallel version of the train step
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+    # Define eval fn
+    def eval_step(params, batch):
+        labels = batch.pop("labels")
+
+        logits = model(**batch, params=params, train=False)[0]
+
+        # compute loss, ignore padded input tokens
+        label_mask = jnp.where(labels > 0, 1.0, 0.0)
+        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+        # compute accuracy
+        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
+
+        # summarize metrics
+        metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
+        metrics = jax.lax.psum(metrics, axis_name="batch")
+
+        return metrics
+
+    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
+
+    # Replicate the train state on each device
+    state = jax_utils.replicate(state)
+
+    train_time = 0
+    train_start = time.time()
+    train_metrics = []
+    eval_metrics = []
+
+    training_iter = iter(tokenized_datasets)
+
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+    eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
+
+    steps = tqdm(range(num_train_steps), desc="Training...", position=0)
+    for step in range(num_train_steps):
+        # ======================== Training ================================
+        try:
+            samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
+        except StopIteration:
+            # Once the end of the dataset stream is reached, the training iterator
+            # is reinitialized and reshuffled and a new eval dataset is randomly chosen.
+            shuffle_seed += 1
+            tokenized_datasets.set_epoch(shuffle_seed)
+
+            training_iter = iter(tokenized_datasets)
+
+            eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
+            samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
+
+        # process input samples
+        model_inputs = data_collator(samples)
+
+        # Model forward
+        model_inputs = shard(model_inputs.data)
+        state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
+
+        train_metrics.append(train_metric)
+
+        if step % training_args.logging_steps == 0 and step > 0:
+            steps.write(
+                f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+            )
+            train_time += time.time() - train_start
+            if has_tensorboard and jax.process_index() == 0:
+                write_train_metric(summary_writer, train_metrics, train_time, step)
+            train_metrics = []
+
+        # ======================== Evaluating ==============================
+        if step % training_args.eval_steps == 0 and step > 0:
+            eval_samples_idx = jnp.arange(data_args.num_eval_samples)
+            eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+            for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
+                # process input samples
+                batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
+                model_inputs = data_collator(batch_eval_samples)
+
+                # Model forward
+                model_inputs = shard(model_inputs.data)
+                metrics = p_eval_step(state.params, model_inputs)
+                eval_metrics.append(metrics)
+
+            # normalize eval metrics
+            eval_metrics = get_metrics(eval_metrics)
+            eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+            eval_normalizer = eval_metrics.pop("normalizer")
+            eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+            # Update progress bar
+            steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+            if has_tensorboard and jax.process_index() == 0:
+                write_eval_metric(summary_writer, eval_metrics, step)
+            eval_metrics = []
+
+            # save checkpoint after each epoch and push checkpoint to the hub
+            if jax.process_index() == 0:
+                params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of step {step+1}",
+                )
+
+        # update tqdm bar
+        steps.update(1)
run_stream.sh ADDED
@@ -0,0 +1,27 @@
+# From https://arxiv.org/pdf/1907.11692.pdf for base model
+python -c "import jax; print('TPUs', jax.device_count())"
+./run_mlm_flax_stream.py \
+    --output_dir="./" \
+    --model_type="roberta" \
+    --config_name="./config-base.json" \
+    --tokenizer_name="./" \
+    --dataset_name="mc4" \
+    --dataset_config_name="es" \
+    --max_seq_length="128" \
+    --pad_to_max_length \
+    --per_device_train_batch_size="256" \
+    --per_device_eval_batch_size="256" \
+    --adam_beta1="0.9" \
+    --adam_beta2="0.98" \
+    --adam_epsilon="1e-6" \
+    --learning_rate="6e-4" \
+    --weight_decay="0.01" \
+    --save_strategy="steps" \
+    --save_steps="1000" \
+    --save_total_limit="5" \
+    --warmup_steps="24000" \
+    --overwrite_output_dir \
+    --num_train_steps="500000" \
+    --eval_steps="1000" \
+    --dtype="bfloat16" \
+    --logging_steps="500" 2>&1 | tee run_stream.log
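Rough arithmetic for the flags above, as a sanity check rather than part of the commit: with per_device_train_batch_size=256 and max_seq_length=128, the global batch is 256 × device_count sequences per step. The device count of 8 below is only an assumption (e.g. a TPU v3-8); the jax.device_count() printout in the script reports the real value:

# Illustrative only: effective batch size implied by the flags in run_stream.sh.
devices = 8                      # assumption; see the jax.device_count() printout above
per_device_batch, seq_len = 256, 128
print("sequences per step:", per_device_batch * devices)          # 2048
print("tokens per step:", per_device_batch * devices * seq_len)   # 262144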
test_script.py DELETED
@@ -1,45 +0,0 @@
-"""CONFIG"""
-#!/usr/bin/env python3
-from transformers import RobertaConfig
-config = RobertaConfig.from_pretrained("roberta-large")
-config.save_pretrained("./")
-
-"""TOKENIZER"""
-#!/usr/bin/env python3
-from datasets import load_dataset
-from tokenizers import ByteLevelBPETokenizer
-# load dataset
-dataset = load_dataset("large_spanish_corpus")
-# Instantiate tokenizer
-tokenizer = ByteLevelBPETokenizer()
-def batch_iterator(batch_size=1000):
-    for i in range(0, len(dataset), batch_size):
-        yield dataset[i: i + batch_size]["text"]
-# Customized training
-tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
-    "<s>",
-    "<pad>",
-    "</s>",
-    "<unk>",
-    "<mask>",
-])
-# Save files to disk
-tokenizer.save("./tokenizer.json")
-
-"""TOKENIZER"""
-#!/usr/bin/env bash
-./run_mlm_flax.py \
-    --output_dir="./" \
-    --model_type="roberta" \
-    --config_name="./" \
-    --tokenizer_name="./" \
-    --dataset_name="large_spanish_corpus" \
-    --dataset_config_name \ # I think this would be empty
-    --max_seq_length="128" \
-    --per_device_train_batch_size="4" \
-    --per_device_eval_batch_size="4" \
-    --learning_rate="3e-4" \
-    --warmup_steps="1000" \
-    --overwrite_output_dir \
-    --num_train_epochs="8" \
-    --push_to_hub
tokens.py CHANGED
@@ -3,11 +3,11 @@ from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer

 # Load dataset
-dataset = load_dataset("oscar", "unshuffled_deduplicated_es")
+dataset = load_dataset("oscar", "unshuffled_deduplicated_es", split="train[:5000000]")

 # Instantiate tokenizer
 tokenizer = ByteLevelBPETokenizer()
-def batch_iterator(batch_size=100_000_000):
+def batch_iterator(batch_size=100_000):
     for i in range(0, len(dataset), batch_size):
         yield dataset["text"][i: i + batch_size]
