import torch as t
import torch.nn as nn

from jukebox.vqvae.resnet import Resnet1D
from jukebox.utils.torch_utils import assert_shape

class EncoderConvBlock(nn.Module):
def __init__(self, input_emb_width, output_emb_width, down_t,
stride_t, width, depth, m_conv,
dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
res_scale=False):
super().__init__()
blocks = []
        # Kernel size 2*stride with padding stride//2 makes each strided conv divide T exactly by stride_t
        filter_t, pad_t = stride_t * 2, stride_t // 2
if down_t > 0:
for i in range(down_t):
block = nn.Sequential(
nn.Conv1d(input_emb_width if i == 0 else width, width, filter_t, stride_t, pad_t),
Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale),
)
blocks.append(block)
        # Final 3x1 conv projects width -> output_emb_width without changing T
        block = nn.Conv1d(width, output_emb_width, 3, 1, 1)
blocks.append(block)
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
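
# Shape sketch for EncoderConvBlock (illustrative toy numbers, not a released
# Jukebox config): with input_emb_width=1, width=32, output_emb_width=64,
# down_t=2, stride_t=2, an input (N, 1, 64) flows as
# (N, 1, 64) -> (N, 32, 32) -> (N, 32, 16) -> (N, 64, 16);
# the Resnet1D stacks preserve shape.
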
class DecoderConvBlock(nn.Module):
    def __init__(self, input_emb_width, output_emb_width, down_t,
                 stride_t, width, depth, m_conv,
                 dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
                 res_scale=False, reverse_decoder_dilation=False, checkpoint_res=False):
super().__init__()
blocks = []
if down_t > 0:
filter_t, pad_t = stride_t * 2, stride_t // 2
block = nn.Conv1d(output_emb_width, width, 3, 1, 1)
blocks.append(block)
for i in range(down_t):
block = nn.Sequential(
Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out=zero_out, res_scale=res_scale, reverse_dilation=reverse_decoder_dilation, checkpoint_res=checkpoint_res),
nn.ConvTranspose1d(width, input_emb_width if i == (down_t - 1) else width, filter_t, stride_t, pad_t)
)
blocks.append(block)
self.model = nn.Sequential(*blocks)
def forward(self, x):
return self.model(x)
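
# DecoderConvBlock mirrors EncoderConvBlock: the 3x1 conv lifts the latent from
# output_emb_width back to width, then each Resnet1D + ConvTranspose1d stage
# multiplies T by stride_t. Under the same toy config as above,
# (N, 64, 16) -> (N, 32, 16) -> (N, 32, 32) -> (N, 1, 64).
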
class Encoder(nn.Module):
def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
strides_t, **block_kwargs):
super().__init__()
self.input_emb_width = input_emb_width
self.output_emb_width = output_emb_width
self.levels = levels
self.downs_t = downs_t
self.strides_t = strides_t
        block_kwargs_copy = dict(**block_kwargs)
        # The encoder never reverses dilation, so drop this decoder-only flag if present
        block_kwargs_copy.pop('reverse_decoder_dilation', None)
level_block = lambda level, down_t, stride_t: EncoderConvBlock(input_emb_width if level == 0 else output_emb_width,
output_emb_width,
down_t, stride_t,
**block_kwargs_copy)
self.level_blocks = nn.ModuleList()
        iterator = zip(range(self.levels), downs_t, strides_t)
for level, down_t, stride_t in iterator:
self.level_blocks.append(level_block(level, down_t, stride_t))
def forward(self, x):
N, T = x.shape[0], x.shape[-1]
emb = self.input_emb_width
assert_shape(x, (N, emb, T))
xs = []
        # Walk levels finest to coarsest; T shrinks by stride_t ** down_t at each level (e.g. 64, 32, ...)
        iterator = zip(range(self.levels), self.downs_t, self.strides_t)
for level, down_t, stride_t in iterator:
level_block = self.level_blocks[level]
x = level_block(x)
emb, T = self.output_emb_width, T // (stride_t ** down_t)
assert_shape(x, (N, emb, T))
xs.append(x)
return xs
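
# Encoder.forward returns one latent per level rather than a single tensor:
# xs[0] is the finest (least downsampled) code, xs[-1] the coarsest. As an
# assumed toy example, levels=2, downs_t=(2, 1), strides_t=(2, 2) on an input
# of length T=64 yields lengths [16, 8], each with output_emb_width channels.
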
class Decoder(nn.Module):
def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
strides_t, **block_kwargs):
super().__init__()
self.input_emb_width = input_emb_width
self.output_emb_width = output_emb_width
self.levels = levels
self.downs_t = downs_t
self.strides_t = strides_t
        level_block = lambda level, down_t, stride_t: DecoderConvBlock(output_emb_width,
                                                                       output_emb_width,
                                                                       down_t, stride_t,
                                                                       **block_kwargs)
self.level_blocks = nn.ModuleList()
        iterator = zip(range(self.levels), downs_t, strides_t)
for level, down_t, stride_t in iterator:
self.level_blocks.append(level_block(level, down_t, stride_t))
        # Final 3x1 conv projects the latent back to the input channel count
        self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1)
def forward(self, xs, all_levels=True):
if all_levels:
assert len(xs) == self.levels
else:
assert len(xs) == 1
x = xs[-1]
N, T = x.shape[0], x.shape[-1]
emb = self.output_emb_width
assert_shape(x, (N, emb, T))
        # Walk levels coarsest to finest; T grows by stride_t ** down_t at each level (e.g. 32, 64, ...)
        iterator = reversed(list(zip(range(self.levels), self.downs_t, self.strides_t)))
for level, down_t, stride_t in iterator:
level_block = self.level_blocks[level]
x = level_block(x)
emb, T = self.output_emb_width, T * (stride_t ** down_t)
assert_shape(x, (N, emb, T))
if level != 0 and all_levels:
x = x + xs[level - 1]
x = self.out(x)
return x
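
if __name__ == "__main__":
    # Minimal round-trip smoke test, assuming the jukebox package is importable.
    # All hyperparameters below are illustrative toy values, not the released
    # model's settings.
    block_kwargs = dict(width=32, depth=2, m_conv=1.0)
    encoder = Encoder(input_emb_width=1, output_emb_width=64, levels=2,
                      downs_t=(2, 1), strides_t=(2, 2), **block_kwargs)
    decoder = Decoder(input_emb_width=1, output_emb_width=64, levels=2,
                      downs_t=(2, 1), strides_t=(2, 2), **block_kwargs)
    x = t.randn(4, 1, 64)  # (batch, audio channels, time)
    xs = encoder(x)        # latents: (4, 64, 16) and (4, 64, 8)
    x_hat = decoder(xs)    # upsamples, sums the finer skip, projects to (4, 1, 64)
    assert_shape(x_hat, x.shape)
    print("round trip:", tuple(x.shape), "->", [tuple(z.shape) for z in xs], "->", tuple(x_hat.shape))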