import os
import torch
import numpy as np
import torch.distributed as dist
from transformers.utils import logging
from transformers import AutoTokenizer
from itertools import cycle
from typing import List

logger = logging.get_logger(__name__)

class Memory(torch.nn.Module):
    def __init__(
        self,
        model_config,
        k_seq_dim: int = 2,
        v_seq_dim: int = 2,
    ):
        """Set up necessary attributes."""
        super().__init__()

        self.model_config = model_config

        # the sequence dimension of the cached keys and values
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.num_layers = model_config.num_hidden_layers
        self.max_position_embeddings = model_config.max_position_embeddings
        self.rng = np.random.default_rng(42)

        self.ultragist_window = model_config.ultragist_window
        self.ultragist_stride = model_config.ultragist_stride
        self.ultragist_attn = model_config.ultragist_attn
        self.ultragist_ratio = model_config.ultragist_ratio
        self.ultragist_ratio_mix = model_config.ultragist_ratio_mix
        self.ultragist_param = model_config.ultragist_param
        self.ultragist_sink_size = model_config.ultragist_sink_size
        self.ultragist_attend_prev = model_config.ultragist_attend_prev

        # the ultragist token id sits right after the regular vocabulary
        self.ultragist_tokens = torch.zeros(1, dtype=torch.long) + model_config.vocab_size

        self._post_validation()
        self.reset()

    def _post_validation(self, verbose=True):
        assert self.ultragist_window >= self.ultragist_stride, f"Make sure the ultragist_window {self.ultragist_window} >= ultragist_stride {self.ultragist_stride}!"
        for ratio in self.ultragist_ratio:
            assert ratio >= 0, f"Make sure all ultragist ratios are greater than or equal to 0, found {self.ultragist_ratio}!"
        assert self.ultragist_attn in ["segmentation", "step-expansion", "full-coverage"], f"ultragist_attn {self.ultragist_attn} not implemented!"
        assert self.ultragist_ratio_mix in ["instance-random", "step-random", "sequence", "join"] or "adapt-" in self.ultragist_ratio_mix, f"ultragist_ratio_mix {self.ultragist_ratio_mix} not implemented!"
        if self.ultragist_ratio_mix == "join":
            # nothing extra to validate for the join strategy
            pass

        self._cpu = torch.device("cpu")

        if verbose:
            info = f"applying ultragist on {self.ultragist_param} (the ultragist embedding is initialized from {'bos' if self.model_config.ultragist_embed_init == 'bos' else 'eos'} embedding), with window size {self.ultragist_window}, stride {self.ultragist_stride}, {self.ultragist_attn} attention{' (attending to previous ultragists)' if self.ultragist_attend_prev else ' (not attending to previous ultragists)'}, sink size {self.ultragist_sink_size}, condensing ratio {self.ultragist_ratio} (mixed by {self.ultragist_ratio_mix})..."
            logger.info(info)

    def set(self, verbose=True, **kwargs):
        if "ultragist_ratio_mix" in kwargs and kwargs["ultragist_ratio_mix"] == "join" and self.ultragist_ratio_mix != "join":
            raise ValueError("You cannot switch ultragist_ratio_mix from a non-join strategy to join!")
        if self.ultragist_ratio_mix == "join" and "ultragist_ratio" in kwargs and sorted(kwargs["ultragist_ratio"]) != sorted(self.ultragist_ratio):
            raise ValueError("You cannot change ultragist_ratio given ultragist_ratio_mix=join!")
        for k, v in kwargs.items():
            setattr(self, k, v)
        self._post_validation(verbose=verbose)

    def reset(self):
        """Initialize attributes for a new sequence."""
        # the start index of the next window
        self._start_idx = 0
        # the number of tokens that have already been fed to the model
        self._end_idx = 0
        # the total ultragist size appended at each step
        self._total_ultragist_sizes = []
        # the ultragist size of the main (first) condensing ratio at each step (-1 marks a zero ratio)
        self._main_ultragist_sizes = []
        # accumulated loss and valid-token counts over all windows
        self._batch_loss = None
        self._valid_token_num = None
        # how many times step() has been called for this sequence
        self._step_idx = 0

        # the condensing ratio chosen for this sequence and the cycling iterator
        # (used by some of the ratio mixing strategies)
        self._ratio = None
        self._ultragist_ratio_iter = None

        self.all_input_ids = torch.tensor([], dtype=torch.long)
        self.all_attention_mask = torch.tensor([], dtype=torch.long)
        if hasattr(self, "all_labels"):
            del self.all_labels

        # raw (uncondensed) activations of the most recent tokens, one (key, value) pair per layer
        self.raw_activations = [(None, None) for _ in range(self.num_layers)]
        # activations of the attention sinks at the beginning of the sequence
        self.sink_activations = [(None, None) for _ in range(self.num_layers)]

        # condensed (ultragist) activations: one list of per-layer (key, value) pairs per condensing ratio
        if self.ultragist_ratio_mix == "join":
            self.l1_to_ln_ultragist_activations = [
                [(None, None) for _ in range(self.num_layers)]
                for _ in self.ultragist_ratio
            ]
        else:
            self.l1_to_ln_ultragist_activations = [
                [(None, None) for _ in range(self.num_layers)]
            ]

    def rewind(self, size=None, trim=False):
        """
        Rewind raw activations that have not been condensed yet.

        Args:
            size: the number of raw tokens to rewind; defaults to all uncondensed raw activations.
            trim: if true, the input_ids corresponding to the raw activations are trimmed as well.
        """
        raw_memory_size = self.get_memory_size()[1]
        if size is None:
            size = raw_memory_size
        assert size <= raw_memory_size, f"Make sure the rewind size ({size}) is smaller than or equal to the raw memory size ({raw_memory_size})!"

        if size > 0:
            self._end_idx -= size
            for layer_idx, (key, value) in enumerate(self.raw_activations):
                key = slice_tensor(key, end=-size, dim=self.k_seq_dim)
                value = slice_tensor(value, end=-size, dim=self.v_seq_dim)
                self.raw_activations[layer_idx] = (key, value)

            if trim:
                self.all_input_ids = self.all_input_ids[:, :-size]
                self.all_attention_mask = self.all_attention_mask[:, :-size]
                if hasattr(self, "all_labels"):
                    self.all_labels = self.all_labels[:, :-size]

    @property
    def finish(self):
        """Whether all prepared tokens have been consumed by step()."""
        is_finish = self._end_idx == self.all_sequence_length
        return is_finish

    def get_memory_size(self):
        """Return the sequence lengths of the (ultragist, raw, sink) memories."""
        ultragist_memory_size = 0
        raw_memory_size = 0
        sink_memory_size = 0
        if self.l1_to_ln_ultragist_activations[0][0][0] is not None:
            ultragist_memory_size += self.l1_to_ln_ultragist_activations[0][0][0].shape[self.k_seq_dim]
        if self.raw_activations[0][0] is not None:
            raw_memory_size += self.raw_activations[0][0].shape[self.k_seq_dim]
        if self.sink_activations[0][0] is not None:
            sink_memory_size += self.sink_activations[0][0].shape[self.k_seq_dim]
        return ultragist_memory_size, raw_memory_size, sink_memory_size

    def get_memory(self, ultragist_sizes=None, total_ultragist_size=None, raw_size_to_cache=None, window_size=None):
        """
        Get the compressed kv cache for generating next tokens.
        """
        past_key_values = []
        for layer_idx in range(self.num_layers):
            sink_key, sink_value = self.sink_activations[layer_idx]
            ultragist_key, ultragist_value = self.l1_to_ln_ultragist_activations[0][layer_idx]
            raw_key, raw_value = self.raw_activations[layer_idx]

            key = cat_tensor([
                sink_key, ultragist_key, raw_key,
            ], dim=self.k_seq_dim)
            value = cat_tensor([
                sink_value, ultragist_value, raw_value,
            ], dim=self.v_seq_dim)

            if ultragist_sizes is not None:
                # attach the window metadata that update_memory later unpacks
                layer_past_key_values = (key, value, ultragist_sizes, total_ultragist_size, raw_size_to_cache, window_size)
            else:
                layer_past_key_values = (key, value)

            past_key_values.append(layer_past_key_values)
        return past_key_values
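
    # Added note (illustration, not part of the original logic): for every layer, the cached key/value
    # returned above is the concatenation, along the sequence dimension, of
    #
    #     [ sink activations | ultragist (condensed) activations | raw (uncondensed) activations ]
    #
    # When the window metadata is requested, each layer additionally carries
    # (ultragist_sizes, total_ultragist_size, raw_size_to_cache, window_size), which update_memory
    # uses to decide how to split the activations returned by the next forward pass.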

    def prepare(self, input_ids, attention_mask, labels):
        """
        Prepare inputs for the model. These inputs belong to the same sequence.
        """
        assert input_ids.shape[0] == 1, "Make sure the batch size is 1!"
        assert attention_mask is None or (attention_mask == 1).all(), "Make sure there is no padding!"

        if not hasattr(self, "_device"):
            self._device = input_ids.device

        # accumulate the inputs on CPU
        self.all_input_ids = torch.cat([self.all_input_ids, input_ids.cpu()], dim=1)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        self.all_attention_mask = torch.cat([self.all_attention_mask, attention_mask.cpu()], dim=1)
        self.all_sequence_length = self.all_input_ids.shape[1]

        if labels is not None:
            # shift the labels in advance so the model does not need to shift them again,
            # e.g. [l0, l1, l2, l3] -> [l1, l2, l3, -100]
            labels = torch.cat([labels[:, 1:].cpu(), torch.tensor([-100]).expand(labels.shape[0], 1)], dim=1)
            if not hasattr(self, "all_labels"):
                self.all_labels = labels
            else:
                self.all_labels = torch.cat([self.all_labels, labels], dim=1)
            assert self.all_input_ids.shape[1] == self.all_labels.shape[1], f"Found inconsistent all_input_ids {self.all_input_ids.shape} and all_labels {self.all_labels.shape}!"

    def set_compression_ratio(self, start_idx, end_idx):
        """Choose a condensing ratio from self.ultragist_ratio."""
        def filter_ratio(ratios, stride):
            valid_ratios = []
            for ratio in ratios:
                # a stride shorter than the ratio would yield no ultragist token at all
                if stride < ratio:
                    continue
                # the stride must be divisible by the condensing ratio
                if ratio > 0 and (stride % ratio) != 0:
                    continue
                # during training, a zero ratio (no condensing) is allowed at most once per sequence,
                # and only when at least one full window still follows
                if ratio == 0 and self.training:
                    previous_has_zero = -1 in self._main_ultragist_sizes
                    following_has_nonzero = (start_idx + stride + self.ultragist_window) <= self.all_sequence_length
                    if previous_has_zero or (not following_has_nonzero):
                        continue
                valid_ratios.append(ratio)
            assert len(valid_ratios), f"Cannot find a valid condensing ratio (among {ratios}) for stride {stride}!"
            return valid_ratios

        def get_max_length(ratios):
            # the maximum sequence length each ratio can support within max_position_embeddings
            max_lengths = []
            for condensing_ratio in ratios:
                if condensing_ratio > 0:
                    max_lengths.append((self.max_position_embeddings - self.ultragist_window) * condensing_ratio + self.ultragist_window)
                else:
                    max_lengths.append(self.max_position_embeddings)
            return max_lengths

        if len(self.ultragist_ratio) == 1:
            return [self.ultragist_ratio[0]]

        ratio_mix = self.ultragist_ratio_mix

        ultragist_ratio = filter_ratio(self.ultragist_ratio, self.ultragist_stride)

        if ratio_mix == "instance-random":
            # sample one ratio per sequence and reuse it for every window
            if self._ratio is None:
                ultragist_ratio = self.rng.choice(ultragist_ratio, size=1).tolist()
                self._ratio = ultragist_ratio
            else:
                ultragist_ratio = self._ratio

        elif ratio_mix == "step-random":
            # sample a new ratio at every step
            ultragist_ratio = self.rng.choice(ultragist_ratio, size=1).tolist()

        elif ratio_mix == "sequence":
            # cycle deterministically through the valid ratios, one per step
            if self._ultragist_ratio_iter is None:
                self._ultragist_ratio_iter = cycle(ultragist_ratio)
            ultragist_ratio = [next(self._ultragist_ratio_iter)]

        elif ratio_mix == "join":
            # keep all valid ratios; the window is condensed under each of them
            ultragist_ratio = ultragist_ratio

        elif "adapt" in ratio_mix:
            if self._ratio is None:
                # pick the ratio whose maximum supported length is the smallest one that still
                # covers the expected sequence length (current input plus future_length)
                future_length = int(ratio_mix.split("-")[1])
                sequence_length = self.all_input_ids.shape[1] + future_length
                max_lengths = get_max_length(ultragist_ratio)

                valid_max_lengths_and_indices = [x for x in enumerate(max_lengths) if x[1] >= sequence_length]
                if len(valid_max_lengths_and_indices):
                    minimum_length_index = min(valid_max_lengths_and_indices, key=lambda x: x[1])[0]
                    ultragist_ratio = [ultragist_ratio[minimum_length_index]]
                else:
                    ultragist_ratio = [max(ultragist_ratio)]

                self._ratio = ultragist_ratio
            else:
                ultragist_ratio = self._ratio

        return ultragist_ratio
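
    # Added note (illustrative, with hypothetical numbers): per get_max_length above, a ratio r > 0
    # can support (max_position_embeddings - ultragist_window) * r + ultragist_window tokens.
    # For example, with max_position_embeddings=4096, ultragist_window=1024 and r=8, that is
    # (4096 - 1024) * 8 + 1024 = 25600 tokens. An "adapt-1024" mix would then pick the ratio whose
    # supported length is the smallest one covering the current input length plus 1024 future tokens,
    # falling back to the largest ratio if none suffices.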

    def step(self):
        """
        Yield one window with the following logic:

        The window size is L, the stride is S.
        The window moves over S tokens at a time. The raw activations passed by the window are condensed according to the condensing ratio.
        Ultragist tokens are appended if and only if the raw activations fill the window.
        In the future, we may shrink the window size to decrease the cache size of the raw activations.
        """
        start_idx = self._start_idx
        end_idx = start_idx + self.ultragist_window

        if end_idx > self.all_sequence_length:
            # the sequence has been exhausted, so the window cannot be filled
            end_idx = self.all_sequence_length
            is_full_window = False
        else:
            is_full_window = True

        # during training, the final window of the sequence is treated as not full
        # (its activations are not condensed)
        if self.training and end_idx == self.all_sequence_length:
            is_full_window = False

        window_size = end_idx - start_idx

        if is_full_window:
            ultragist_stride = self.ultragist_stride

            compression_ratios = self.set_compression_ratio(start_idx=start_idx, end_idx=end_idx)

            ultragist_sizes = []
            for condensing_ratio in compression_ratios:
                if condensing_ratio > 0:
                    # the stride is condensed into stride // ratio ultragist tokens
                    ultragist_sizes.append(ultragist_stride // condensing_ratio)
                else:
                    # a zero ratio means no condensing; mark it with -1
                    ultragist_sizes.append(-1)

            next_start_idx = start_idx + ultragist_stride
            # the raw activations beyond the stride are kept for the next window
            raw_size_to_cache = end_idx - next_start_idx

        else:
            # the window is not full: do not move the window and cache all raw activations
            next_start_idx = start_idx
            raw_size_to_cache = window_size
            ultragist_sizes = [0]
            compression_ratios = [0]

        total_ultragist_size = sum(s for s in ultragist_sizes if s >= 0)

        past_key_values = self.get_memory(
            ultragist_sizes,
            total_ultragist_size,
            raw_size_to_cache,
            window_size
        )

        # only feed the tokens that have not been encoded yet
        input_ids = self.all_input_ids[:, self._end_idx: end_idx].to(self._device)
        attention_mask = self.all_attention_mask[:, self._end_idx: end_idx].to(self._device)
        if hasattr(self, "all_labels"):
            labels = self.all_labels[:, self._end_idx: end_idx].to(self._device)
        else:
            labels = None
        batch_size = input_ids.shape[0]

        if is_full_window:
            if total_ultragist_size > 0:
                # append the ultragist tokens at the end of the window
                input_ids = torch.cat([input_ids, self.ultragist_tokens.expand(batch_size, total_ultragist_size).to(input_ids.device, dtype=input_ids.dtype)], dim=1)
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones(batch_size, total_ultragist_size)], dim=1)
                if labels is not None:
                    labels = torch.cat([labels, labels.new_zeros(batch_size, total_ultragist_size) - 100], dim=1)

        # prepend the attention mask over the cached memory
        first_key = past_key_values[0][0]
        memory_size = first_key.shape[self.k_seq_dim] if first_key is not None else 0
        if memory_size > 0:
            attention_mask = torch.cat([attention_mask.new_ones(batch_size, memory_size), attention_mask], dim=1)

        self._total_ultragist_sizes.append(total_ultragist_size)
        self._main_ultragist_sizes.append(ultragist_sizes[0])

        self._start_idx = next_start_idx
        self._end_idx = end_idx
        self._step_idx += 1

        return input_ids, attention_mask, past_key_values, labels
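
    # Worked example (added for illustration; the numbers are hypothetical): with ultragist_window=1024,
    # ultragist_stride=1024 and a chosen condensing ratio of 8, the first full window covers tokens
    # [0, 1024), appends 1024 // 8 = 128 ultragist tokens to the input, and then sets
    # next_start_idx = 0 + 1024 = 1024 and raw_size_to_cache = 1024 - 1024 = 0, i.e. the whole window
    # is condensed into 128 ultragist activations and no raw activations are carried over.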

    def update_memory(self, past_key_values):
        """
        Accumulate ultragist activations and raw activations.
        """
        for layer_idx, (key, value, ultragist_sizes, total_ultragist_size, raw_size_to_cache, window_size) in enumerate(past_key_values):

            previous_raw_key, previous_raw_value = self.raw_activations[layer_idx]

            if self._step_idx == 1:
                # at the first step, store the activations of the attention sinks
                self.sink_activations[layer_idx] = [
                    slice_tensor(key, end=self.ultragist_sink_size, dim=self.k_seq_dim),
                    slice_tensor(value, end=self.ultragist_sink_size, dim=self.v_seq_dim),
                ]

            if ultragist_sizes == [0]:
                # the window is not full: cache all returned activations as raw activations
                assert raw_size_to_cache == window_size
                raw_key = cat_tensor([
                    previous_raw_key,
                    key
                ], dim=self.k_seq_dim)
                raw_value = cat_tensor([
                    previous_raw_value,
                    value
                ], dim=self.v_seq_dim)
                self.raw_activations[layer_idx] = (raw_key, raw_value)

            else:
                # the window is full: split the returned activations into ultragist and raw activations,
                # once per condensing ratio
                for ultragist_size_idx, ultragist_size in enumerate(ultragist_sizes):
                    previous_ultragist_key, previous_ultragist_value = self.l1_to_ln_ultragist_activations[ultragist_size_idx][layer_idx]

                    ultragist_key, ultragist_value, raw_key, raw_value = self._extract_ultragist_and_raw_memory(key, value, previous_ultragist_key, previous_ultragist_value, previous_raw_key, previous_raw_value, raw_size_to_cache, total_ultragist_size, ultragist_sizes, ultragist_size_idx)

                    self.l1_to_ln_ultragist_activations[ultragist_size_idx][layer_idx] = (ultragist_key, ultragist_value)
                    if ultragist_size_idx == 0:
                        # only the main ratio contributes to the raw activations
                        self.raw_activations[layer_idx] = (raw_key, raw_value)

    def update_loss(self, batch_loss, valid_token_num):
        """
        Accumulate the loss over windows for later perplexity computation and the backward pass.
        """
        if self._batch_loss is None:
            # batch_loss is a per-token average within the window, so weight it by the number of valid tokens
            self._batch_loss = batch_loss * valid_token_num
            self._valid_token_num = valid_token_num
        else:
            self._batch_loss = self._batch_loss + batch_loss * valid_token_num
            self._valid_token_num = self._valid_token_num + valid_token_num
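
    # Added worked example (hypothetical numbers): if one window yields an average loss of 2.0 over
    # 10 valid tokens and the next yields 3.0 over 20 valid tokens, the accumulated state becomes
    # _batch_loss = 2.0 * 10 + 3.0 * 20 = 80.0 and _valid_token_num = 30, so output() reports the
    # token-weighted average 80.0 / 30 ≈ 2.67 rather than the unweighted mean of 2.5.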

    def output(self, model_outputs):
        """
        Override loss with the accumulated loss.
        """
        if self._batch_loss is not None:
            # the overall loss is the token-weighted average over all windows
            loss = self._batch_loss.sum() / self._valid_token_num.sum()

            # per-instance loss (kept at zero for instances without any valid token)
            batch_loss = self._batch_loss / self._valid_token_num
            if (self._valid_token_num == 0).any():
                batch_loss = batch_loss.masked_fill(self._valid_token_num == 0, 0.)

            model_outputs["loss"] = loss
            model_outputs["batch_loss"] = batch_loss
            model_outputs["valid_token_num"] = self._valid_token_num

        # remove the logits of the ultragist tokens appended to the last window
        ultragist_size = self._total_ultragist_sizes[-1]
        if ultragist_size > 0:
            model_outputs["logits"] = model_outputs["logits"][:, :-ultragist_size]

        return model_outputs

    def _extract_ultragist_and_raw_memory(self, key, value, previous_ultragist_key, previous_ultragist_value, previous_raw_key, previous_raw_value, raw_size_to_cache, total_ultragist_size, ultragist_sizes, ultragist_size_idx):
        """Extract ultragist and raw memory from the returned key and value. The raw memory is computed only if ultragist_size_idx == 0."""
        ultragist_size = ultragist_sizes[ultragist_size_idx]
        # number of ultragist tokens produced by the ratios that precede this one within the window
        previous_ultragist_size = sum(x for x in ultragist_sizes[:ultragist_size_idx] if x > 0)

        if previous_ultragist_key is not None:
            target_device = previous_ultragist_key.device
        else:
            # the main ratio lives on the model device; auxiliary ratios are offloaded to CPU
            if ultragist_size_idx == 0:
                target_device = self._device
            else:
                target_device = self._cpu

        if ultragist_size == -1:
            # a zero condensing ratio: move the first (window - raw_size_to_cache) raw positions
            # into the ultragist store uncondensed
            actual_ultragist_size = self.ultragist_window - raw_size_to_cache

            concat_raw_key = cat_tensor([
                previous_raw_key,
                key
            ], dim=self.k_seq_dim)
            concat_raw_value = cat_tensor([
                previous_raw_value,
                value
            ], dim=self.v_seq_dim)

            ultragist_key = cat_tensor([
                previous_ultragist_key,
                slice_tensor(concat_raw_key, end=actual_ultragist_size, dim=self.k_seq_dim).to(target_device, non_blocking=True)
            ], dim=self.k_seq_dim)
            ultragist_value = cat_tensor([
                previous_ultragist_value,
                slice_tensor(concat_raw_value, end=actual_ultragist_size, dim=self.v_seq_dim).to(target_device, non_blocking=True)
            ], dim=self.v_seq_dim)

            if ultragist_size_idx == 0:
                raw_key = slice_tensor(concat_raw_key, start=actual_ultragist_size, end=self.ultragist_window, dim=self.k_seq_dim)
                raw_value = slice_tensor(concat_raw_value, start=actual_ultragist_size, end=self.ultragist_window, dim=self.v_seq_dim)

        else:
            # a positive condensing ratio: the ultragist activations of this ratio sit at the end of the
            # returned key/value, after the ultragists of the preceding ratios
            ultragist_start_idx = -total_ultragist_size + previous_ultragist_size
            ultragist_end_idx = ultragist_start_idx + ultragist_size

            # when this is the last ratio, slice all the way to the end of the tensor
            if ultragist_end_idx == 0:
                ultragist_end_idx = None

            ultragist_key = cat_tensor([
                previous_ultragist_key,
                slice_tensor(key, start=ultragist_start_idx, end=ultragist_end_idx, dim=self.k_seq_dim).to(target_device, non_blocking=True)
            ], dim=self.k_seq_dim)
            ultragist_value = cat_tensor([
                previous_ultragist_value,
                slice_tensor(value, start=ultragist_start_idx, end=ultragist_end_idx, dim=self.v_seq_dim).to(target_device, non_blocking=True)
            ], dim=self.v_seq_dim)

            if ultragist_size_idx == 0:
                if key.shape[self.k_seq_dim] < raw_size_to_cache + ultragist_size:
                    # the returned cache alone does not contain enough positions:
                    # slice the concatenation of the previous raw cache and the new cache instead
                    concat_raw_key = cat_tensor([
                        previous_raw_key,
                        key
                    ], dim=self.k_seq_dim)
                    concat_raw_value = cat_tensor([
                        previous_raw_value,
                        value
                    ], dim=self.v_seq_dim)
                    raw_key = slice_tensor(concat_raw_key, start=self.ultragist_window - raw_size_to_cache, end=self.ultragist_window, dim=self.k_seq_dim)
                    raw_value = slice_tensor(concat_raw_value, start=self.ultragist_window - raw_size_to_cache, end=self.ultragist_window, dim=self.v_seq_dim)
                else:
                    # the raw activations to cache immediately precede the ultragist activations
                    raw_key = slice_tensor(key, start=ultragist_start_idx - raw_size_to_cache, end=ultragist_start_idx, dim=self.k_seq_dim)
                    raw_value = slice_tensor(value, start=ultragist_start_idx - raw_size_to_cache, end=ultragist_start_idx, dim=self.v_seq_dim)

        if ultragist_size_idx == 0:
            return ultragist_key, ultragist_value, raw_key, raw_value
        else:
            # auxiliary ratios are detached and offloaded; they contribute no raw activations
            return ultragist_key.detach().to(target_device, non_blocking=True), ultragist_value.detach().to(target_device, non_blocking=True), None, None
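

# A minimal sketch (added for illustration, not part of the original file) of how a model wrapper is
# expected to drive this Memory during training or long-sequence encoding. `model` stands for a
# hypothetical transformer whose forward accepts the tuple-style past_key_values built above, and the
# output field names (batch_loss, valid_token_num) are assumptions:
#
#     memory.reset()
#     memory.prepare(input_ids, attention_mask, labels)
#     while not memory.finish:
#         inputs, mask, past_key_values, window_labels = memory.step()
#         outputs = model(inputs, attention_mask=mask, past_key_values=past_key_values, labels=window_labels)
#         memory.update_memory(outputs.past_key_values)
#         memory.update_loss(outputs.batch_loss, outputs.valid_token_num)
#     outputs = memory.output(outputs)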


def slice_tensor(x, start=None, end=None, dim=2):
    """Slice x along dim; returns None when x is None, when the slice would be empty, or when both start and end are None."""
    if x is None:
        return None
    if end == 0:
        return None
    if start == x.shape[dim]:
        return None
    if start == end:
        return None
    if dim == 2:
        if start is None and end is not None:
            return x[:, :, :end, ...]
        elif start is not None and end is None:
            return x[:, :, start:, ...]
        elif start is not None and end is not None:
            return x[:, :, start:end, ...]
    elif dim == 1:
        if start is None and end is not None:
            return x[:, :end, ...]
        elif start is not None and end is None:
            return x[:, start:, ...]
        elif start is not None and end is not None:
            return x[:, start:end, ...]
    else:
        raise NotImplementedError


def cat_tensor(list_of_tensors, dim=-1):
    """Concatenate tensors along dim, ignoring None entries; returns None if every entry is None."""
    list_of_tensors = [t for t in list_of_tensors if t is not None]
    if len(list_of_tensors) > 1:
        result = torch.cat(list_of_tensors, dim=dim)
    elif len(list_of_tensors) == 1:
        result = list_of_tensors[0]
    else:
        result = None
    return result


def slice_activations(activations, start=None, end=None, k_seq_dim=2, v_seq_dim=2):
    new_activations = []
    for key, value in activations:
        new_key = slice_tensor(key, start=start, end=end, dim=k_seq_dim)
        new_value = slice_tensor(value, start=start, end=end, dim=v_seq_dim)
        new_activations.append([new_key, new_value])
    return new_activations


def cat_activations(list_of_activations, k_seq_dim=2, v_seq_dim=2):
    assert all(len(x) == len(list_of_activations[0]) for x in list_of_activations), f"Make sure all activations have the same number of layers! Found {[len(x) for x in list_of_activations]}."

    new_activations = []
    for layer_idx in range(len(list_of_activations[0])):
        keys = [x[layer_idx][0] for x in list_of_activations]
        values = [x[layer_idx][1] for x in list_of_activations]

        new_key = cat_tensor(keys, dim=k_seq_dim)
        new_value = cat_tensor(values, dim=v_seq_dim)
        new_activations.append([new_key, new_value])
    return new_activations


def interleave_activations(main_activations, augment_activations, main_spans, augment_spans, k_seq_dim=2, v_seq_dim=2, device=torch.device("cuda")):
    """Interleave main_activations and augment_activations according to main_spans and augment_spans.

    Args:
        main_spans: a list of (start_idx, end_idx) tuples. When both start_idx and end_idx are None, the corresponding augment_activations are plugged in.
        augment_spans: a list of (start_idx, end_idx) tuples.
    """
    assert len(main_activations) == len(augment_activations), f"Make sure main and augment activations have the same number of layers! Found {len(main_activations)} and {len(augment_activations)}!"
    assert sum(x[0] is None and x[1] is None for x in main_spans) == len(augment_spans), f"Make sure the number of slots for augmentation (start_idx=None and end_idx=None in main_spans) matches the number of augmentations. Found {sum(x[0] is None and x[1] is None for x in main_spans)} slots but {len(augment_spans)} augmentations!"

    new_activations = []
    for layer_idx in range(len(main_activations)):
        main_key, main_value = main_activations[layer_idx]
        augment_key, augment_value = augment_activations[layer_idx]

        sliced_keys = []
        sliced_values = []

        augment_idx = 0
        for start, end in main_spans:
            if start is None and end is None:
                # plug in the next augment span
                augment_start, augment_end = augment_spans[augment_idx]
                sliced_key = slice_tensor(
                    augment_key,
                    start=augment_start,
                    end=augment_end,
                    dim=k_seq_dim
                ).to(device)
                sliced_value = slice_tensor(
                    augment_value,
                    start=augment_start,
                    end=augment_end,
                    dim=v_seq_dim
                ).to(device)
                augment_idx += 1
            else:
                sliced_key = slice_tensor(
                    main_key,
                    start=start,
                    end=end,
                    dim=k_seq_dim
                )
                sliced_value = slice_tensor(
                    main_value,
                    start=start,
                    end=end,
                    dim=v_seq_dim
                )

            sliced_keys.append(sliced_key)
            sliced_values.append(sliced_value)

        new_key = cat_tensor(sliced_keys, dim=k_seq_dim)
        new_value = cat_tensor(sliced_values, dim=v_seq_dim)
        new_activations.append([new_key, new_value])

    return new_activations


def softmax(x: np.ndarray, axis=-1, temperature=1):
    if isinstance(x, list):
        x = np.array(x)
    x = x / temperature
    # subtract the max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    y = np.exp(x)
    return y / y.sum(axis=axis, keepdims=True)


def l1_norm(x):
    sum_x = sum(x)
    x = [y / sum_x for y in x]
    return x
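

if __name__ == "__main__":
    # Minimal sanity check added for illustration; it only exercises the standalone helpers above and
    # is not part of the original module's API.
    dummy = torch.arange(24).reshape(1, 2, 6, 2)  # (batch, num_heads, seq_len, head_dim)
    head = slice_tensor(dummy, end=4, dim=2)      # first 4 positions along the sequence dimension
    tail = slice_tensor(dummy, start=4, dim=2)    # remaining 2 positions
    assert torch.equal(cat_tensor([head, None, tail], dim=2), dummy)  # None entries are skipped
    print(softmax([1.0, 2.0, 3.0]))               # ~[0.09, 0.24, 0.67]
    print(l1_norm([1, 3]))                        # [0.25, 0.75]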