# Factored attention
import math
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from jukebox.transformer.ops import Conv1D
from jukebox.utils.checkpoint import checkpoint
def repeat(x, n, dim):
    # Repeat each entry of x n times along `dim` (an interleaved repeat).
    if dim == -1:
        dim = len(x.shape) - 1
    return (x.view(int(np.prod(x.shape[:dim + 1])), 1, int(np.prod(x.shape[dim + 1:])))
             .repeat(1, n, 1)
             .view(*x.shape[:dim], n * x.shape[dim], *x.shape[dim + 1:]))
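# Illustrative sketch (not part of the upstream module): repeat() performs an
# interleaved repeat along `dim`, tiling each entry n times in place, which for
# this case matches t.repeat_interleave.
def _example_repeat():
    x = t.arange(6).view(2, 3)
    y = repeat(x, 2, dim=1)  # shape (2, 6): [[0, 0, 1, 1, 2, 2], [3, 3, 4, 4, 5, 5]]
    assert t.equal(y, t.repeat_interleave(x, 2, dim=1))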
def get_mask(mask, q_l, kv_l, blocks, spread, device, sample, sample_t):
# returns a mask of shape 1 x 1 x q_l x kv_l or None if masking is not needed.
if mask is None or q_l == 1:
return None
offset = sample_t - q_l if sample else max(kv_l - q_l, 0)
if mask == 'autoregressive':
# Masked dense
mask = t.ones(q_l, kv_l, device=device).tril(offset)
elif mask == 'summary':
# Masked summary
        mask = t.nn.functional.pad(
            t.ones(q_l, q_l, device=device).tril().view(q_l, blocks, q_l // blocks)[:, :-1, -kv_l // blocks:],
            (0, 0, 1, 0), value=1).contiguous().view(q_l, kv_l)
elif mask == 'prime':
mask = t.ones(q_l, kv_l, device=device).tril(offset)
return mask.view(1,1,q_l,kv_l)
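# Illustrative sketch (not part of the upstream module): for an autoregressive
# query over an equally long key/value sequence, the mask is simply lower-triangular.
def _example_autoregressive_mask():
    m = get_mask('autoregressive', q_l=3, kv_l=3, blocks=None, spread=None,
                 device='cpu', sample=False, sample_t=0)
    assert t.equal(m, t.ones(3, 3).tril().view(1, 1, 3, 3))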
class FactoredAttention(nn.Module):
    """Multi-head attention over a sequence factored into [blocks, block_ctx].

    attn_func selects the pattern: dense, block, transpose-block, previous-block,
    summary, summary-spread, encoder-decoder, or prime (prefix-only) attention.
    Keys and values are cached to support autoregressive sampling.
    """
def __init__(self, n_in, n_ctx, n_state, n_head,
attn_dropout=0.0, resid_dropout=0.0,
scale=True, mask=False,
zero_out=False, init_scale=1.0,
checkpoint_attn=0,
attn_func=0, blocks=None, spread=None,
encoder_dims=None, prime_len=None):
super().__init__()
self.n_in = n_in
        self.n_ctx = n_ctx  # NOTE: n_ctx can differ between operations; this is the full context length.
self.n_state = n_state
assert n_state % n_head == 0
self.n_head = n_head
self.scale = scale
self.mask = mask
        if attn_func == 6:
            # Encoder-decoder attention: x provides queries only; keys/values come from encoder_kv.
            self.c_attn = Conv1D(n_in, n_state, init_scale=init_scale)
            self.c_enc_kv = Conv1D(n_in, n_state * 2, init_scale=init_scale)
        else:
            # Self-attention: a single projection produces queries, keys and values.
            self.c_attn = Conv1D(n_in, n_state * 3, init_scale=init_scale)
self.c_proj = Conv1D(n_state, n_in, zero_out, init_scale=init_scale)
self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0.0 else lambda x: x
self.resid_dropout = nn.Dropout(resid_dropout) if resid_dropout > 0.0 else lambda x: x
# Sequence of length l is factored as [blocks, l // blocks]
self.attn_func = attn_func
self.qkv, self.attn, self.attn_mask = {
0: (self.factored_qkv, self.dense_attn, 'autoregressive'), # Attend to all positions
1: (self.factored_qkv, self.block_attn, 'autoregressive'), # Attend to your block
2: (self.factored_qkv, self.transpose_block_attn, 'autoregressive'), # Attend to transpose block
3: (self.factored_qkv, self.prev_block_attn, None), # Attend to previous block
4: (self.factored_qkv, self.summary_attn, 'summary'), # Attend to last position of each block
            5: (self.factored_qkv, self.summary_spread_attn, 'summary'),  # Attend to last `spread` positions of each block
            6: (self.decode_qkv, self.decode_attn, None),  # Cross-attend to the encoder keys/values
            7: (self.prime_qkv, self.prime_attn, 'prime')  # Attend only to the prime (prefix) tokens
        }[attn_func]
self.blocks = blocks
self.spread = spread
if blocks is not None:
assert n_ctx % blocks == 0
self.block_ctx = n_ctx // blocks
        self.checkpoint_attn = checkpoint_attn  # 0: no checkpointing, 1: checkpoint _attn (after heads are split), 2: checkpoint the whole attn call
self.sample_t = 0
self.cache = {}
self.encoder_dims = encoder_dims
self.prime_len = prime_len
self.record_attn = False
self.w = None
def _attn(self, q, k, v, sample):
scale = 1. / math.sqrt(math.sqrt(self.n_state // self.n_head))
        if self.training:
            # Scale q and k each by d_head ** -0.25 before the matmul so intermediate
            # values stay small (helps fp16 stability); at eval time the combined
            # scale is applied to the logits in one step instead.
            w = t.matmul(q * scale, k * scale)
        else:
            w = t.matmul(q, k)
            w.mul_(scale * scale)
wtype = w.dtype
w = w.float()
if self.mask:
# Generate appropriate mask to mask out all positions before current
# Might take up lot of memory for dense, so can cache it
mask = get_mask(self.attn_mask, q.size(-2), k.size(-1), self.blocks, self.spread, w.device, sample, self.sample_t)
if mask is not None:
                w = w * mask + -1e9 * (1 - mask)
w = F.softmax(w, dim=-1).type(wtype)
else:
w = F.softmax(w, dim=-1).type(wtype)
if self.record_attn:
self.w = w #.float().cpu().numpy()
if self.attn_func == 7:
# only keep music queries and lyrics keys/values
self.w = self.w[:,:,self.prime_len:,:self.prime_len]
w = self.attn_dropout(w)
a = t.matmul(w, v)
return a
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = (*x.size()[:-2], x.size(-2) * x.size(-1))
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
def split_heads(self, x, k=False):
new_x_shape = (*x.size()[:-1], self.n_head, x.size(-1) // self.n_head)
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
if k:
return x.permute(0, 2, 3, 1)
else:
return x.permute(0, 2, 1, 3)
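    # Shape sketch (illustrative): split_heads maps (bs, l, d) to (bs, n_head, l, d // n_head),
    # or to (bs, n_head, d // n_head, l) for keys (k=True), so that q @ k yields
    # (bs, n_head, q_l, kv_l) logits; merge_heads inverts the query/value layout.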
def dense_attn(self, query, key, value, sample):
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
if self.checkpoint_attn == 1 and not sample:
a = checkpoint(lambda q,k,v,s=sample: self._attn(q,k,v,s), (query, key, value),
(), True)
else:
a = self._attn(query,key,value,sample)
a = self.merge_heads(a)
return a
def block_attn(self, q, k, v, sample):
        blocks, block_ctx = self.blocks, self.block_ctx  # block_ctx = n_ctx // blocks for the full context; l is shorter while sampling.
bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
if sample:
assert l == self._suff_cache_len(), f"{l} != {self._suff_cache_len()}"
return self.dense_attn(q, k, v, sample).view(bs, 1, d)
else:
ql = q.shape[1]
q = q.view(bs * ql // block_ctx, block_ctx, d)
if ql < l:
l = ql
k = k[:, -l:].contiguous()
v = v[:, -l:].contiguous()
k = k.view(bs * l // block_ctx, block_ctx, d)
v = v.view(bs * l // block_ctx, block_ctx, d)
return self.dense_attn(q, k, v, sample).view(bs, l, d)
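    # Block sketch (illustrative): queries/keys/values of shape (bs, l, d) are folded
    # into (bs * l // block_ctx, block_ctx, d), so each length-block_ctx block runs
    # dense causal attention over itself only.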
def transpose_block_attn(self, q, k, v, sample):
        blocks, block_ctx = self.blocks, self.block_ctx  # block_ctx = n_ctx // blocks for the full context; l is shorter while sampling.
bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
if sample:
block_l = (l - 1) % block_ctx
k = k[:,block_l::block_ctx,:]
v = v[:,block_l::block_ctx,:]
return self.dense_attn(q, k, v, sample).view(bs, 1, d)
else:
ql = q.shape[1]
q = q.view(bs, ql // block_ctx, block_ctx, d).transpose(1,2).contiguous().view(bs * block_ctx, ql // block_ctx, d)
k = k.view(bs, l // block_ctx, block_ctx, d).transpose(1,2).contiguous().view(bs * block_ctx, l // block_ctx, d)
v = v.view(bs, l // block_ctx, block_ctx, d).transpose(1,2).contiguous().view(bs * block_ctx, l // block_ctx, d)
return self.dense_attn(q, k, v, sample).view(bs, block_ctx, ql // block_ctx, d).transpose(1,2).contiguous().view(bs, ql, d)
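    # Transpose sketch (illustrative): after transposing the (blocks, block_ctx) axes,
    # each position attends to earlier positions sharing its offset within a block,
    # i.e. a strided pattern with stride block_ctx.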
def prev_block_attn(self, q, k, v, sample):
        blocks, block_ctx = self.blocks, self.block_ctx  # block_ctx = n_ctx // blocks for the full context; l is shorter while sampling.
bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
if sample:
assert l == self._suff_cache_len(), f"{l} != {self._suff_cache_len()}"
block = (l - 1) // block_ctx
prev_l = (block - 1) * block_ctx
if block > 0:
assert prev_l == 0
k = k[:, prev_l:prev_l + block_ctx, :]
v = v[:, prev_l:prev_l + block_ctx, :]
else:
k = t.zeros(bs, block_ctx, d, device=q.device, dtype=q.dtype)
v = t.zeros(bs, block_ctx, d, device=q.device, dtype=q.dtype)
return self.dense_attn(q, k, v, sample).view(bs, 1, d)
else:
ql = q.shape[1]
q = q.view(bs * ql // block_ctx, block_ctx, d)
k = t.nn.functional.pad(k.view(bs, l // block_ctx, block_ctx, d)[:, :-1, :, :], (0,0,0,0,1,0)).view(bs * l // block_ctx, block_ctx, d)
v = t.nn.functional.pad(v.view(bs, l // block_ctx, block_ctx, d)[:, :-1, :, :], (0,0,0,0,1,0)).view(bs * l // block_ctx, block_ctx, d)
if ql < l:
qb = ql // block_ctx
kb = l // block_ctx
l = ql
k = k.view(bs, kb, block_ctx, d)[:, -qb:].contiguous().view(bs * qb, block_ctx, d)
v = v.view(bs, kb, block_ctx, d)[:, -qb:].contiguous().view(bs * qb, block_ctx, d)
return self.dense_attn(q, k, v, sample).view(bs, l, d)
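    # Previous-block sketch (illustrative): each block attends only to the block
    # immediately before it; the first block attends to an all-zero block instead.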
def summary_attn(self, q, k, v, sample):
        blocks, block_ctx = self.blocks, self.block_ctx  # block_ctx = n_ctx // blocks for the full context; l is shorter while sampling.
bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
if sample:
k = t.nn.functional.pad(k[:, block_ctx-1:blocks*block_ctx-1:block_ctx, :],(0,0,1,0))
v = t.nn.functional.pad(v[:, block_ctx-1:blocks*block_ctx-1:block_ctx, :],(0,0,1,0))
return self.dense_attn(q, k, v, sample).view(bs, 1, d)
else:
k = t.nn.functional.pad(k.view(bs, blocks, l // blocks, d)[:, :-1, -1, :],(0,0,1,0)) # bs, blocks, d
v = t.nn.functional.pad(v.view(bs, blocks, l // blocks, d)[:, :-1, -1, :],(0,0,1,0)) # bs, blocks, d
return self.dense_attn(q, k, v, sample).view(bs, l, d)
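    # Summary sketch (illustrative): the keys/values are reduced to the last position
    # of each preceding block (with a zero vector prepended for the first block), so
    # every query sees at most `blocks` summary positions.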
def summary_spread_attn(self, q, k, v, sample):
        blocks, block_ctx, spread = self.blocks, self.block_ctx, self.spread  # block_ctx = n_ctx // blocks for the full context; l is shorter while sampling.
bs, l, d = v.shape # For sample, q_l = 1, k_l = v_l = sample_t
if sample:
            raise NotImplementedError("summary_spread_attn does not support sampling yet")
# k = t.nn.functional.pad(k,(0,0,block_ctx,(-l)%block_ctx)).view(bs, -1, block_ctx, d)[:,:-1,-spread:,:].contiguous().view(bs, -1, d)
# v = t.nn.functional.pad(v,(0,0,block_ctx,(-l)%block_ctx)).view(bs, -1, block_ctx, d)[:,:-1,-spread:,:].contiguous().view(bs, -1, d)
# return self.dense_attn(q, k, v, sample).view(bs, 1, d)
else:
k = t.nn.functional.pad(k.view(bs, blocks, l // blocks, d)[:, :-1, -spread:, :],(0,0,0,0,1,0)).contiguous().view(bs, blocks * spread, d) # bs, blocks * spread, d
v = t.nn.functional.pad(v.view(bs, blocks, l // blocks, d)[:, :-1, -spread:, :],(0,0,0,0,1,0)).contiguous().view(bs, blocks * spread, d) # bs, blocks * spread, d
return self.dense_attn(q, k, v, sample).view(bs, l, d)
def prime_attn(self, q, k, v, sample):
prime_len = self._prime_len
k = k[:, :prime_len]
v = v[:, :prime_len]
return self.dense_attn(q, k, v, sample)
def decode_attn(self, q, k, v, sample):
assert k.shape[1] == v.shape[1] == self.encoder_dims, f'k: {k.shape}, v: {v.shape}, enc_dims: {self.encoder_dims}'
return self.dense_attn(q, k, v, sample)
def factored_qkv(self, x, encoder_kv=None, sample=False):
curr_ctx = x.shape[1]
assert encoder_kv is None
query, key, value = x.chunk(3, dim=2)
if sample:
self.sample_t += curr_ctx
key, value = self._append_cache(key, value)
l_cache = self._suff_cache_len()
if self._cache_len() > l_cache:
self._slice_cache(-l_cache)
if curr_ctx > 1:
if self.attn_func != 0:
query = self._pad_to_block_ctx(query, query=True)
key = self._pad_to_block_ctx(key)
value = self._pad_to_block_ctx(value)
assert key.shape[1] % self.block_ctx == 0
assert query.shape[1] % self.block_ctx == 0
assert key.shape[1] == value.shape[1]
assert query.shape[1] <= key.shape[1]
sample = False
else:
key = self.cache['key']
value = self.cache['value']
return query, key, value, sample
def prime_qkv(self, x, encoder_kv=None, sample=False):
curr_ctx = x.shape[1]
assert encoder_kv is None
query, key, value = x.chunk(3, dim=2)
if sample:
if self._cache_len() < self._prime_len:
self._append_cache(key, value)
if self._cache_len() > self._prime_len:
self._slice_cache(0, self._prime_len)
key, value = self.cache['key'], self.cache['value']
self.sample_t += curr_ctx
assert key.shape[1] == value.shape[1] == self._suff_cache_len(), f'k: {key.shape}, v: {value.shape}, prime_dims: {self._suff_cache_len()}'
else:
assert key.shape[1] == value.shape[1] == self.n_ctx, f'k: {key.shape}, v: {value.shape}, prime_dims: {self.n_ctx}'
assert key.shape[0] == value.shape[0] == query.shape[0], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
assert key.shape[2] == value.shape[2] == query.shape[2], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
return query, key, value, sample
def decode_qkv(self, x, encoder_kv=None, sample=False):
curr_ctx = x.shape[1]
assert encoder_kv is not None
query = x
if sample:
if self.sample_t == 0:
self.cache['key'], self.cache['value'] = self.c_enc_kv(encoder_kv.type_as(x)).chunk(2, dim=2)
key, value = self.cache['key'], self.cache['value']
self.sample_t += curr_ctx
else:
key, value = self.c_enc_kv(encoder_kv.type_as(x)).chunk(2, dim=2)
assert key.shape[0] == value.shape[0] == query.shape[0], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
assert key.shape[1] == value.shape[1] == self.encoder_dims, f'k: {key.shape}, v: {value.shape}, enc_dims: {self.encoder_dims}'
assert key.shape[2] == value.shape[2] == query.shape[2], f'k: {key.shape}, v: {value.shape}, q: {query.shape}'
return query, key, value, sample
def forward(self, x, encoder_kv=None, sample=False):
curr_ctx = x.shape[1]
x = self.c_attn(x)
query, key, value, sample = self.qkv(x, encoder_kv=encoder_kv, sample=sample)
if self.checkpoint_attn == 2 and not sample:
a = checkpoint(lambda q,k,v,s=sample: self.attn(q,k,v,s), (query, key, value), (), True)
else:
a = self.attn(query,key,value,sample)
if a.shape[1] != curr_ctx:
offset = self._offset(curr_ctx)
a = a[:,offset:offset + curr_ctx,:].contiguous()
a = self.c_proj(a)
return self.resid_dropout(a)
    @property
    def _prime_len(self):
        # Prime (prefix) length rounded up to a whole number of blocks.
        prime_len = self.prime_len
        assert prime_len is not None
        prime_blocks = (prime_len // self.blocks) + 1
        return prime_blocks * self.blocks
def _offset(self, curr_ctx):
if self.attn_func == 0:
return 0
return (self.sample_t - curr_ctx) % self.block_ctx
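    # Illustrative note: _offset gives the position of the current chunk within its
    # block; e.g. with block_ctx = 4, after sample_t = 6 tokens and curr_ctx = 1,
    # offset = (6 - 1) % 4 = 1, so the current token sits one position into its block.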
def _pad_to_block_ctx(self, x, query=False):
l = x.shape[1]
offset = self._offset(l) if query else 0
n_blocks = (l + offset + self.block_ctx - 1) // self.block_ctx
pad = n_blocks * self.block_ctx - l - offset
if pad == 0 and offset == 0:
return x
else:
return F.pad(x, (0, 0, offset, pad))
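    # Example (illustrative): with block_ctx = 4, a 6-token query chunk at offset 0 is
    # right-padded by 2 so that it fills exactly 2 whole blocks; the padding is sliced
    # off again in forward() via _offset().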
def _cache_len(self):
return 0 if 'key' not in self.cache else self.cache['key'].shape[1]
def _suff_cache_len(self):
"""
Precondition:
key and value are appended with the current context and
self.sample_t reflects the 1-indexed sample location in the
context.
"""
if self.attn_func == 0:
return self.sample_t
elif self.attn_func == 1:
return (self.sample_t - 1) % self.block_ctx + 1
elif self.attn_func == 2:
return self.sample_t
elif self.attn_func == 3:
if self.sample_t <= self.block_ctx:
return self.sample_t
else:
curr_block = (self.sample_t - 1) % self.block_ctx + 1
prev_block = self.block_ctx
return curr_block + prev_block
elif self.attn_func == 6:
return self.encoder_dims
elif self.attn_func == 7:
return min(self.sample_t, self._prime_len)
else:
raise NotImplementedError()
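    # Example (illustrative): with block_ctx = 4 and sample_t = 6, block attention
    # (attn_func 1) needs only (6 - 1) % 4 + 1 = 2 cached positions, while
    # previous-block attention (attn_func 3) needs 2 + 4 = 6.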
def _slice_cache(self, start, end=None):
self.cache['key'] = self.cache['key'][:, start:end]
self.cache['value'] = self.cache['value'][:, start:end]
def _append_cache(self, key, value):
if 'key' not in self.cache:
self.cache['key'] = key
self.cache['value'] = value
else:
old_key, old_value = key, value
key = t.cat([self.cache['key'], key], dim=1)
value = t.cat([self.cache['value'], value], dim=1)
del self.cache['key']
del self.cache['value']
del old_key
del old_value
self.cache['key'] = key
self.cache['value'] = value
return self.cache['key'], self.cache['value']
def del_cache(self):
self.sample_t = 0
if 'key' in self.cache:
del self.cache['key']
if 'value' in self.cache:
del self.cache['value']
self.cache = {}
def check(self):
blocks = self.blocks or 1
spread = self.spread or 1
bs, l, d = (4, self.n_ctx, self.n_in)
x = t.randn(bs, l, d).cuda()
x.requires_grad = True
x_out = self.forward(x) # bs, l, d
loss = x_out.mean(dim = -1) # bs, l
pos = 60
grad = t.autograd.grad(loss[2, pos], x)[0]
assert grad.shape == (bs, l, d)
assert (grad[:2] == 0).all()
assert (grad[3:] == 0).all()
assert (grad[2, (pos + 1):] == 0).all()
pos_grad = (t.sum(grad[2] ** 2, dim=-1) > 0).nonzero().view(-1).cpu()
block_pos = pos - (pos % (l // blocks))
exp_pos_grad = {0: t.arange(pos),
1: t.arange(block_pos, pos),
2: t.arange(pos % (l // blocks), pos, l // blocks),
3: t.arange(block_pos - l // blocks, block_pos),
4: t.arange(l // blocks - 1, pos, l // blocks),
5: ((t.arange(pos) % (l // blocks) >= (l // blocks - spread)) & (t.arange(pos) < block_pos)).nonzero().view(-1)}[self.attn_func]
exp_pos_grad = t.cat([exp_pos_grad, t.tensor([pos])], dim=-1)
assert (len(pos_grad) == len(exp_pos_grad)) and (pos_grad == exp_pos_grad).all(), \
f"Expected pos grad {exp_pos_grad} got {pos_grad} for attn_func {self.attn_func} pos {pos} l {l} blocks {blocks}"
def check_cache(self, n_samples, sample_t, fp16):
assert self.sample_t == sample_t, f"{self.sample_t} != {sample_t}"
if sample_t == 0:
assert self.cache == {}
else:
dtype = {True: t.float16, False: t.float32}[fp16]
l_cache = self._suff_cache_len()
assert self.cache['key'].shape == (n_samples, l_cache, self.n_state)
assert self.cache['value'].shape == (n_samples, l_cache, self.n_state)
assert self.cache['key'].dtype == dtype, f"Expected {dtype}, got {self.cache['key'].dtype}"
assert self.cache['value'].dtype == dtype, f"Expected {dtype}, got {self.cache['value'].dtype}"
def check_sample(self):
t.manual_seed(42)
bs, l, d = (4, self.n_ctx, self.n_in)
prime = 5
x = t.randn(bs, l, d).cuda()
xs = t.chunk(x, l, dim=1)
assert self.sample_t == 0
assert self.cache == {}
with t.no_grad():
enc_l = self.encoder_dims
encoder_kv = None
if self.attn_func == 6:
encoder_kv = t.randn(bs, enc_l, d).cuda()
# Normal path
x_out_normal = self.forward(x, encoder_kv=encoder_kv)
# Sampling path
x_out_sample = t.cat([self.forward(xs[i], encoder_kv=encoder_kv, sample=True) for i in range(l)],dim=1)
max_err = t.max(t.abs(x_out_sample - x_out_normal))
assert max_err < 1e-8, f"Max sampling err is {max_err} {[i for i in range(l) if t.max(t.abs(x_out_sample - x_out_normal)[:,i,:]) > 1e-8]}"
with t.no_grad():
x_out_normal = x_out_normal[:,:prime,:]
# Prime sampling path
self.del_cache()
x_out_sample = self.forward(x[:,:prime,:].contiguous(), encoder_kv=encoder_kv, sample=True)
self.check_cache(bs, prime, False)
max_err = t.max(t.abs(x_out_sample - x_out_normal))
assert max_err < 1e-8, f"Max prime sampling err is {max_err} {[i for i in range(prime) if t.max(t.abs(x_out_sample - x_out_normal)[:,i,:]) > 1e-8]}"
def check_chunks(self, chunk_size):
t.manual_seed(42)
bs, l, d = (4, self.n_ctx, self.n_in)
enc_l = self.encoder_dims
assert l % chunk_size == 0
n_chunks = l // chunk_size
with t.no_grad():
encoder_kv = None
x = t.randn(bs, l, d).cuda()
if self.attn_func == 6:
encoder_kv = t.randn(bs, enc_l, d).cuda()
self.del_cache()
y_forw = self.forward(x, encoder_kv=encoder_kv, sample=False)
self.del_cache()
y_forw_sample = self.forward(x, encoder_kv=encoder_kv, sample=True)
max_err = t.max(t.abs(y_forw - y_forw_sample))
assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_sample)[:, i, :]) > 1e-6]}"
self.del_cache()
x_chunks = t.chunk(x, n_chunks, dim=1)
y_chunks = []
total_len = 0
for x_chunk in x_chunks:
y_chunk = self.forward(x_chunk.contiguous(), encoder_kv=encoder_kv, sample=True)
total_len += x_chunk.shape[1]
self.check_cache(bs, total_len, False)
y_chunks.append(y_chunk)
y_forw_in_chunks = t.cat(y_chunks, dim=1)
max_err = t.max(t.abs(y_forw - y_forw_in_chunks))
assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_in_chunks)[:, i, :]) > 1e-6]}"
if __name__ == '__main__':
from jukebox.utils.dist_utils import setup_dist_from_mpi
setup_dist_from_mpi(port=29600)
n_in = 16
n_state = n_in * 2
n_ctx = 6144
n_head = 4
n_depth = 12
blocks = 64
chunk_size = 8
for attn_func in [0, 1, 2, 3, 6, 7]:
encoder_dims = {0: 0, 1: 0, 2: 0, 3: 0, 6: 64, 7: 0}[attn_func]
prime_len = {0: 0, 1: 0, 2: 0, 3: 0, 6: 0, 7: 384}[attn_func]
attn = FactoredAttention(n_in, n_ctx + prime_len, n_state, n_head, mask=True,
attn_func=attn_func, blocks=blocks,
encoder_dims=encoder_dims, prime_len=prime_len)
attn.training = False
attn.check_sample()
attn.check_chunks(chunk_size)
print(f"Checked attn_func: {attn_func}")