|
import dataclasses |
|
import pprint |
|
import time |
|
from functools import partial |
|
import json |
|
import base64 |
|
from multiprocessing import Pool |
|
|
|
import h5py |
|
import mlxu |
|
from ml_collections.config_dict import config_dict |
|
from ml_collections import ConfigDict |
|
from tqdm import tqdm, trange |
|
import numpy as np |
|
|
|
from datasets import load_dataset, load_from_disk |
|
|
|
|
|
class DatasetFactory(object):
    """ Dataset builder class.

    Static factory that turns a config into one of the dataset classes of this
    module. It is never instantiated.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default factory config, optionally overridden by `updates`.

        `type` selects the dataset implementation ('huggingface' or 'json');
        the remaining entries nest the default configs of the text processor
        and of both dataset implementations.
        """
        config = ConfigDict()
        config.type = 'huggingface'
        config.text_processor = TextProcessor.get_default_config()
        config.huggingface_dataset = HuggingfaceDataset.get_default_config()
        config.json_dataset = JsonDataset.get_default_config()

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def load_dataset(cls, config, tokenizer, **kwargs):
        """Build the dataset selected by `config.type`.

        Args:
            config: partial or full factory config; missing keys take the
                defaults from `get_default_config`.
            tokenizer: tokenizer shared by the text processor and the dataset.
            **kwargs: forwarded unchanged to the dataset constructor
                (e.g. `eval_dataset` for `HuggingfaceDataset`).

        Returns:
            A `HuggingfaceDataset` or `JsonDataset` instance.

        Raises:
            ValueError: if `config.type` is not a known dataset type.
        """
        config = cls.get_default_config(config)
        # The text processor is built first, so an invalid text_processor
        # config fails before the dataset type is even checked.
        text_processor = TextProcessor(config.text_processor, tokenizer)
        if config.type == 'huggingface':
            return HuggingfaceDataset(
                config.huggingface_dataset, tokenizer, text_processor, **kwargs
            )
        elif config.type == 'json':
            return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
        else:
            raise ValueError(f'Unknown dataset type: {config.type}')

    def __init__(self):
        raise ValueError('DatasetFactory is a static class and should not be instantiated.')
|
|
|
|
|
class TextProcessor(object):
    """ Example processor that converts a dictionary of texts into tokens.

    A comma-separated field spec (`config.fields`, or the string stored in each
    example under the key `config.fields_from_example`) controls how an example
    becomes tokens. Each comma-separated item is one of:

    * `name` or `name1+name2`: text field(s) joined with `subfield_separator`
      and encoded with the tokenizer (`prepend_text` is added only when this
      is the first item);
    * `<|bos|>`, `<|eos|>` or `<|123|>`: a single BOS/EOS or literal token id;
    * `{name}`: base64-encoded raw token ids, decoded as `base64_token_dtype`.

    Wrapping an item in square brackets (e.g. `[prompt]`) gives its tokens a
    loss mask of 0.0 instead of 1.0.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, optionally overridden by `updates`."""
        config = ConfigDict()
        config.fields_from_example = ''
        config.fields = ''
        config.subfield_separator = ' '
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ''
        config.base64_token_dtype = 'i4'
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer):
        """Args:
            config: partial or full processor config.
            tokenizer: object providing `bos_token_id`, `eos_token_id` and
                `encode(text)`.
        """
        self.config = self.get_default_config(config)
        assert self.config.fields != '' or self.config.fields_from_example != '', (
            'Either fields or fields_from_example must be specified.'
        )
        self.tokenizer = tokenizer

    def __call__(self, example, has_aux=False):
        """Tokenize one example.

        Args:
            example: dict of fields; when `has_aux` is True, a tuple whose first
                element is that dict and whose remaining elements are passed
                through untouched.
            has_aux: whether `example` carries auxiliary values.

        Returns:
            `(tokens, loss_masks, *aux)`: a list of token ids and a list of
            float loss masks of the same length, followed by any aux values.
        """
        if has_aux:
            example, *aux = example
        else:
            aux = tuple()
        token_buffer = []
        loss_mask_buffer = []

        if self.config.add_bos_token:
            token_buffer.append(self.tokenizer.bos_token_id)
            # The leading BOS is never a prediction target.
            loss_mask_buffer.append(0.0)

        if self.config.fields_from_example != '':
            fields = example[self.config.fields_from_example].split(',')
        else:
            fields = self.config.fields.split(',')

        for i, field in enumerate(fields):
            if field.startswith('[') and field.endswith(']'):
                # Bracketed items are context only and excluded from the loss.
                field = field[1:-1]
                mask = 0.0
            else:
                mask = 1.0

            if field.startswith('<|') and field.endswith('|>'):
                # Single token: BOS/EOS by name, otherwise a literal token id.
                field = field[2:-2]
                if field == 'bos':
                    token_buffer.append(self.tokenizer.bos_token_id)
                elif field == 'eos':
                    token_buffer.append(self.tokenizer.eos_token_id)
                else:
                    token_buffer.append(int(field))
                loss_mask_buffer.append(mask)
            elif field.startswith('{') and field.endswith('}'):
                # Pre-tokenized field stored as base64-encoded raw integers.
                field = field[1:-1]
                tokens = np.frombuffer(
                    base64.b64decode(example[field]),
                    dtype=self.config.base64_token_dtype
                ).tolist()
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask] * len(tokens))
            else:
                subfields = field.split('+')
                text = self.config.subfield_separator.join(
                    [example[subfield] for subfield in subfields]
                )
                if i == 0:
                    text = self.config.prepend_text + text
                tokens = self.tokenizer.encode(text)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask] * len(tokens))

        if self.config.add_eos_token:
            token_buffer.append(self.tokenizer.eos_token_id)
            # The trailing EOS is a prediction target.
            loss_mask_buffer.append(1.0)

        return token_buffer, loss_mask_buffer, *aux
|
|
|
|
|
class HuggingfaceDataset(object):
    """ Huggingface dataset, where the dataset is loaded from a local directory
    using the huggingface datasets.load_from_disk() function.

    Examples are tokenized by the text processor, packed into one running token
    stream and cut into `(batch_size, seq_length)` batches. In training mode
    the data is reshuffled and replayed forever; with `eval_dataset=True` each
    iteration makes a single pass.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, optionally overridden by `updates`."""
        config = ConfigDict()
        config.path = 'c4'
        # `name` and `streaming` are not read by this class.
        config.name = 'en'
        # An empty split means the object saved at `path` is used as-is.
        config.split = 'train'
        config.streaming = False
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.tokens_count_at_start = 0
        config.batch_token_dtype = 'i4'
        # Upper bound on the number of shards of the iterable stream; it is
        # further capped by the number of examples.
        config.max_num_shards = 128
        # Buffer size of the streaming shuffle applied at every training epoch.
        config.shuffle_buffer_size = 100

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
        """Args:
            config: partial or full dataset config.
            tokenizer: tokenizer used for `always_start_with_bos` and
                `vocab_size`.
            text_processor: callable mapping an example to
                `(tokens, loss_masks)`.
            eval_dataset: if True, iterate once in order without shuffling or
                resuming from `start_seek_loc`.
        """
        self.config = self.get_default_config(config)
        split = self.config.split if self.config.split != '' else None
        self._tokenizer = tokenizer
        self._text_processor = text_processor

        dataset = load_from_disk(self.config.path)
        if split is not None:
            dataset = dataset[split]
        num_shards = min(self.config.max_num_shards, len(dataset))
        # Keep the unshuffled stream: IterableDataset.shuffle() wraps the stream
        # it is called on, so reshuffling the previous epoch's shuffled stream
        # would stack one more shuffle buffer every epoch.
        self._base_dataset = dataset.to_iterable_dataset(num_shards=num_shards)
        self._dataset = self._base_dataset

        self._eval_dataset = eval_dataset
        self._train_epochs = 0
        self._dataset_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start
        self._index = 0

    def __iter__(self):
        """Yield `(batch, metrics)` pairs.

        `batch` holds `input_tokens`, `target_tokens` (inputs shifted by one)
        and `loss_masks`, each reshaped to `(batch_size, -1)`.
        """
        chunk_size = self.config.batch_size * self.config.seq_length
        dtype = self.config.batch_token_dtype
        # The buffers live across training epochs, so the tail of one epoch is
        # packed with the start of the next instead of being dropped; a
        # dataset smaller than one batch would otherwise never yield.
        token_buffer = []
        loss_mask_buffer = []
        while True:
            if not self._eval_dataset:
                self._shuffle()
            for index, example in enumerate(self._dataset):
                self._index = index
                if not self._eval_dataset and self._dataset_loc > index:
                    # Resuming: skip examples before the checkpointed position.
                    # NOTE(review): shuffle() is called without a seed, so the
                    # skipped examples are presumably not the ones consumed
                    # before the checkpoint — confirm if exact resume matters.
                    continue
                tokens, loss_masks = self.text_processor(example)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend(loss_masks)
                # One extra token is needed for the shifted targets.
                while len(token_buffer) > chunk_size + 1:
                    self._total_tokens += chunk_size
                    metrics = {
                        'dataset_example_index': index,
                        'dataset_total_tokens': self._total_tokens,
                        'epoch': self._train_epochs,
                    }
                    batch = {
                        'input_tokens': np.array(token_buffer[:chunk_size], dtype=dtype).reshape(
                            self.config.batch_size, -1
                        ),
                        'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=dtype).reshape(
                            self.config.batch_size, -1
                        ),
                        'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                            self.config.batch_size, -1
                        ),
                    }
                    if self.config.always_start_with_bos:
                        batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                    yield batch, metrics
                    token_buffer = token_buffer[chunk_size:]
                    loss_mask_buffer = loss_mask_buffer[chunk_size:]

            if self._eval_dataset:
                break
            self._dataset_loc = 0
            self._train_epochs += 1
            print(f"TRAIN {self._train_epochs} EPOCH DONE")

    def _shuffle(self):
        """Replace the iterated stream with a fresh shuffle of the base stream."""
        self._dataset = self._base_dataset.shuffle(
            buffer_size=self.config.shuffle_buffer_size
        )

    def get_state_dict(self):
        """Return the config and the position needed to resume iteration."""
        return dict(
            config=self.config,
            dataset_loc=self._index,
            total_tokens=self._total_tokens,
            epochs=self._train_epochs,
        )

    def load_state_dict(self, state_dict):
        """Restore state saved by `get_state_dict`; missing keys use defaults."""
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
        self._train_epochs = state_dict.get('epochs', 0)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def dataset(self):
        return self._dataset

    @property
    def vocab_size(self):
        return len(self._tokenizer)
|
|
|
|
|
class JsonDataset(object):
    """ JSON dataset, where each line of the data file contains a JSON
    dictionary with text fields.

    In training mode the file is read in an endless loop, wrapping around to
    the beginning at end of file. With `eval_dataset=True` every iteration
    makes exactly one pass over the file from its beginning.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, optionally overridden by `updates`."""
        config = ConfigDict()
        config.path = ''
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200
        # dtype of input/target token arrays ('i4' == np.int32).
        config.batch_token_dtype = 'i4'

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
        """Args:
            config: partial or full dataset config; `path` is required.
            tokenizer: tokenizer used for `always_start_with_bos` and
                `vocab_size`.
            text_processor: callable used with `has_aux=True` to tokenize
                `(example, file_loc, index)` tuples.
            eval_dataset: if True, each iteration is a single pass over the
                file. Accepted for parity with `HuggingfaceDataset`, since
                `DatasetFactory.load_dataset` forwards the same kwargs to both.
        """
        self.config = self.get_default_config(config)
        assert self.config.path != ''
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._eval_dataset = eval_dataset
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start

    def parse_json(self, line):
        """Parse one JSON line; return None for blank or malformed lines."""
        if not line or line.isspace():
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f'Error parsing json line:\n{line}')
            return None
        return data

    def json_iterator(self):
        """Yield `(example, file_loc, index)` for every valid JSON line.

        `file_loc` is the file offset right after the line, so resuming from it
        continues with the next line.

        Raises:
            ValueError: in training mode, if a full pass over the file yields
                no valid example (the endless loop would otherwise spin).
        """
        with mlxu.open_file(self.config.path, 'r') as fin:
            if self._eval_dataset:
                # Evaluation always makes one full pass from the beginning.
                self._file_loc = 0
                self._index = 0
            fin.seek(self._file_loc)
            # Examples yielded since the last rewind; None until the first
            # rewind, so resuming exactly at end of file is not mistaken for
            # an empty file.
            yielded_since_rewind = None
            while True:
                line = fin.readline()
                self._file_loc = fin.tell()
                if not line:
                    if self._eval_dataset:
                        return
                    if yielded_since_rewind == 0:
                        raise ValueError(
                            f'No valid JSON examples in {self.config.path}'
                        )
                    self._index = 0
                    fin.seek(0)
                    yielded_since_rewind = 0
                    continue

                data = self.parse_json(line)
                if data is not None:
                    yield data, self._file_loc, self._index
                    self._index += 1
                    if yielded_since_rewind is not None:
                        yielded_since_rewind += 1

    def batched(self, iterator, batch_size):
        """Group `iterator` into lists of `batch_size`; the last may be shorter."""
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def parallel_example_iterator(self):
        """Yield `(tokens, loss_masks, file_loc, index)` for every example.

        With more than one tokenizer process, examples are tokenized in a
        process pool in batches. The next batch is submitted before the current
        one is consumed, so tokenization overlaps with consumption.
        """
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                yield self.text_processor((example, loc, index), has_aux=True)
            return

        batched_iterator = self.batched(
            self.json_iterator(), self.config.tokenizer_parallel_batch_size
        )
        with Pool(self.config.tokenizer_processes) as pool:
            map_fn = partial(self.text_processor, has_aux=True)
            pending = None
            # A `for` loop (instead of next()) lets a finite source such as an
            # evaluation pass end cleanly; a StopIteration escaping next()
            # inside a generator would surface as RuntimeError (PEP 479).
            for batch in batched_iterator:
                submitted = pool.map_async(
                    map_fn, batch,
                    chunksize=self.config.tokenizer_parallel_chunk_size
                )
                if pending is not None:
                    yield from pending.get()
                pending = submitted
            if pending is not None:
                yield from pending.get()

    def __iter__(self):
        """Yield `(batch, metrics)` pairs.

        `batch` holds `input_tokens`, `target_tokens` (inputs shifted by one)
        and `loss_masks`, each reshaped to `(batch_size, -1)`. `metrics` holds
        the resume position and throughput in tokens per second.
        """
        chunk_size = self.config.batch_size * self.config.seq_length
        dtype = self.config.batch_token_dtype
        window = self.config.throughput_average_window_size
        token_buffer = []
        loss_mask_buffer = []
        step_times = []
        start_time = time.perf_counter()
        # Time the first step from the start of iteration. Starting from 0.0
        # made the first step time the whole clock value, which dominated the
        # moving-average throughput for the entire averaging window.
        last_time = start_time
        start_tokens = self._total_tokens
        for tokens, loss_masks, loc, index in self.parallel_example_iterator():
            token_buffer.extend(tokens)
            loss_mask_buffer.extend(loss_masks)
            # One extra token is needed for the shifted targets.
            while len(token_buffer) > chunk_size + 1:
                self._total_tokens += chunk_size
                now = time.perf_counter()
                step_times.append(now - last_time)
                last_time = now
                if len(step_times) > window:
                    step_times = step_times[-window:]
                # Clamp the denominators so a coarse clock cannot divide by zero.
                average_throughput = chunk_size / max(np.mean(step_times), 1e-9)
                accumulated_throughput = (
                    (self._total_tokens - start_tokens) / max(now - start_time, 1e-9)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }
                batch = {
                    'input_tokens': np.array(token_buffer[:chunk_size], dtype=dtype).reshape(
                        self.config.batch_size, -1
                    ),
                    'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=dtype).reshape(
                        self.config.batch_size, -1
                    ),
                    'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                        self.config.batch_size, -1
                    ),
                }
                if self.config.always_start_with_bos:
                    batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                yield batch, metrics
                token_buffer = token_buffer[chunk_size:]
                loss_mask_buffer = loss_mask_buffer[chunk_size:]

    def get_state_dict(self):
        """Return the config and the file position needed to resume reading.

        NOTE(review): with a tokenizer pool, `file_loc`/`index` may run ahead of
        the batches actually consumed, because batches are prefetched.
        """
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        """Restore state saved by `get_state_dict`; missing keys use defaults."""
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._index = state_dict.get('index', self.config.example_index_at_start)
        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def vocab_size(self):
        return len(self.tokenizer)
|
|