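# Hierarchical VQ-VAE over raw audio (Jukebox): per-level encoders compress the
# waveform into discrete codes at several temporal resolutions, decoders map
# codes back to audio, and forward() combines reconstruction, spectral, and
# codebook-commitment losses.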
import numpy as np
import torch as t
import torch.nn as nn
from jukebox.vqvae.encdec import Encoder, Decoder, assert_shape
from jukebox.vqvae.bottleneck import NoBottleneck, Bottleneck
from jukebox.utils.logger import average_metrics
from jukebox.utils.audio_utils import spectral_convergence, spectral_loss, multispectral_loss, audio_postprocess
def dont_update(params):
    # Freeze parameters so gradient updates skip them.
    for param in params:
        param.requires_grad = False

def update(params):
    # Unfreeze parameters for training.
    for param in params:
        param.requires_grad = True

def calculate_strides(strides, downs):
    # Total downsampling factor per level: each level applies `down` strided
    # convolutions of stride `stride`, i.e. stride ** down.
    return [stride ** down for stride, down in zip(strides, downs)]
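
# Worked example (hyperparameter values assumed for illustration, not taken
# from this file): with strides_t = (2, 2, 2) and downs_t = (3, 2, 2),
# calculate_strides gives [2**3, 2**2, 2**2] = [8, 4, 4], so the cumulative
# hop lengths used below are np.cumprod([8, 4, 4]) = [8, 32, 128]: one latent
# code per 8, 32, and 128 audio samples at levels 0, 1, and 2.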

def _loss_fn(loss_fn, x_target, x_pred, hps):
    # Reconstruction distances, each normalized by a precomputed bandwidth so
    # the different losses live on comparable scales.
    if loss_fn == 'l1':
        return t.mean(t.abs(x_pred - x_target)) / hps.bandwidth['l1']
    elif loss_fn == 'l2':
        return t.mean((x_pred - x_target) ** 2) / hps.bandwidth['l2']
    elif loss_fn == 'linf':
        # Approximate L-inf norm: mean of the k largest squared residuals
        # per example.
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        values, _ = t.topk(residual, hps.linf_k, dim=1)
        return t.mean(values) / hps.bandwidth['l2']
    elif loss_fn == 'lmix':
        # Weighted mix of the three losses above.
        loss = 0.0
        if hps.lmix_l1:
            loss += hps.lmix_l1 * _loss_fn('l1', x_target, x_pred, hps)
        if hps.lmix_l2:
            loss += hps.lmix_l2 * _loss_fn('l2', x_target, x_pred, hps)
        if hps.lmix_linf:
            loss += hps.lmix_linf * _loss_fn('linf', x_target, x_pred, hps)
        return loss
    else:
        raise ValueError(f"Unknown loss_fn {loss_fn}")
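
# A minimal sketch of driving _loss_fn with a hand-rolled `hps` namespace; the
# attribute names mirror the ones read above, but the values are illustrative,
# not Jukebox's defaults.
def _demo_loss_fn():
    from types import SimpleNamespace
    hps = SimpleNamespace(bandwidth={'l1': 1.0, 'l2': 1.0}, linf_k=2048,
                          lmix_l1=0.0, lmix_l2=1.0, lmix_linf=0.02)
    x_target = t.randn(4, 8192, 1)
    x_pred = x_target + 0.01 * t.randn_like(x_target)
    return _loss_fn('lmix', x_target, x_pred, hps)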

class VQVAE(nn.Module):
    # Hierarchical VQ-VAE: one encoder/decoder pair per level plus a shared
    # (optionally disabled) vector-quantisation bottleneck.
    def __init__(self, input_shape, levels, downs_t, strides_t,
                 emb_width, l_bins, mu, commit, spectral, multispectral,
                 multipliers=None, use_bottleneck=True, **block_kwargs):
        super().__init__()
        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape
        self.downsamples = calculate_strides(strides_t, downs_t)
        # Cumulative downsampling: hop_lengths[level] audio samples per code.
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = z_shapes = [(x_shape[0] // self.hop_lengths[level],) for level in range(levels)]
        self.levels = levels
        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers

        def _block_kwargs(level):
            # Widen/deepen the conv blocks at this level by its multiplier.
            this_block_kwargs = dict(block_kwargs)
            this_block_kwargs["width"] *= self.multipliers[level]
            this_block_kwargs["depth"] *= self.multipliers[level]
            return this_block_kwargs
encoder = lambda level: Encoder(x_channels, emb_width, level + 1,
downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
decoder = lambda level: Decoder(x_channels, emb_width, level + 1,
downs_t[:level+1], strides_t[:level+1], **_block_kwargs(level))
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
for level in range(levels):
self.encoders.append(encoder(level))
self.decoders.append(decoder(level))
if use_bottleneck:
self.bottleneck = Bottleneck(l_bins, emb_width, mu, levels)
else:
self.bottleneck = NoBottleneck(levels)
self.downs_t = downs_t
self.strides_t = strides_t
self.l_bins = l_bins
self.commit = commit
self.spectral = spectral
self.multispectral = multispectral
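
    # Illustrative sizing (assumed values, not Jukebox's shipped
    # hyperparameters): with input_shape=(8192, 1), levels=3,
    # downs_t=(3, 2, 2) and strides_t=(2, 2, 2), hop_lengths is [8, 32, 128]
    # and z_shapes is [(1024,), (256,), (64,)]: coarser codes at higher levels.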

    def preprocess(self, x):
        # x: NTC [-1,1] -> NCT [-1,1]
        assert len(x.shape) == 3
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        # x: NTC [-1,1] <- NCT [-1,1]
        x = x.permute(0, 2, 1)
        return x

    def _decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        assert len(zs) == end_level - start_level
        xs_quantised = self.bottleneck.decode(zs, start_level=start_level, end_level=end_level)
        assert len(xs_quantised) == end_level - start_level
        # Decode from the start_level codes only, using that level's decoder
        decoder, x_quantised = self.decoders[start_level], xs_quantised[0:1]
        x_out = decoder(x_quantised, all_levels=False)
        x_out = self.postprocess(x_out)
        return x_out

    def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
        # Chunk along the batch dimension to bound peak memory
        z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
        x_outs = []
        for i in range(bs_chunks):
            zs_i = [z_chunk[i] for z_chunk in z_chunks]
            x_out = self._decode(zs_i, start_level=start_level, end_level=end_level)
            x_outs.append(x_out)
        return t.cat(x_outs, dim=0)
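
    # Example (illustrative): decode only the top-level codes of a 3-level
    # model, where `zs` is the list returned by encode():
    #   x_top = vqvae.decode(zs[2:3], start_level=2)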

    def _encode(self, x, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        x_in = self.preprocess(x)
        xs = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            x_out = encoder(x_in)
            xs.append(x_out[-1])  # keep only the most-downsampled output
        zs = self.bottleneck.encode(xs)
        return zs[start_level:end_level]

    def encode(self, x, start_level=0, end_level=None, bs_chunks=1):
        # Chunk along the batch dimension, encode each chunk, then regroup the
        # per-chunk code lists into one list of per-level codes
        x_chunks = t.chunk(x, bs_chunks, dim=0)
        zs_list = []
        for x_i in x_chunks:
            zs_i = self._encode(x_i, start_level=start_level, end_level=end_level)
            zs_list.append(zs_i)
        zs = [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)]
        return zs

    def sample(self, n_samples):
        zs = [t.randint(0, self.l_bins, size=(n_samples, *z_shape), device='cuda') for z_shape in self.z_shapes]
        return self.decode(zs)
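
    # Note: `sample` draws codes uniformly over the codebook (on a hard-coded
    # 'cuda' device), so decoded samples are unstructured, noise-like audio;
    # in Jukebox, musically meaningful codes come from the learned priors.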

    def forward(self, x, hps, loss_fn='l1'):
        metrics = {}
        N = x.shape[0]

        # Encode, quantise, and decode at every level
        x_in = self.preprocess(x)
        xs = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            x_out = encoder(x_in)
            xs.append(x_out[-1])  # keep only the most-downsampled output
        zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)
        x_outs = []
        for level in range(self.levels):
            decoder = self.decoders[level]
            x_out = decoder(xs_quantised[level:level+1], all_levels=False)
            assert_shape(x_out, x_in.shape)
            x_outs.append(x_out)

        # Loss
        def _spectral_loss(x_target, x_out, hps):
            # Absolute STFT-magnitude loss (bandwidth-normalized) or the
            # relative spectral-convergence distance
            if hps.use_nonrelative_specloss:
                sl = spectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
            else:
                sl = spectral_convergence(x_target, x_out, hps)
            sl = t.mean(sl)
            return sl

        def _multispectral_loss(x_target, x_out, hps):
            # Spectral loss averaged over multiple STFT resolutions
            sl = multispectral_loss(x_target, x_out, hps) / hps.bandwidth['spec']
            sl = t.mean(sl)
            return sl
recons_loss = t.zeros(()).to(x.device)
spec_loss = t.zeros(()).to(x.device)
multispec_loss = t.zeros(()).to(x.device)
x_target = audio_postprocess(x.float(), hps)
for level in reversed(range(self.levels)):
x_out = self.postprocess(x_outs[level])
x_out = audio_postprocess(x_out, hps)
this_recons_loss = _loss_fn(loss_fn, x_target, x_out, hps)
this_spec_loss = _spectral_loss(x_target, x_out, hps)
this_multispec_loss = _multispectral_loss(x_target, x_out, hps)
metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
metrics[f'spectral_loss_l{level + 1}'] = this_spec_loss
metrics[f'multispectral_loss_l{level + 1}'] = this_multispec_loss
recons_loss += this_recons_loss
spec_loss += this_spec_loss
multispec_loss += this_multispec_loss
        commit_loss = sum(commit_losses)
        # Total objective: reconstruction plus weighted spectral, multiscale
        # spectral, and codebook-commitment terms
        loss = recons_loss + self.spectral * spec_loss + self.multispectral * multispec_loss + self.commit * commit_loss

        with t.no_grad():
            # Extra diagnostics on the last x_out from the loop above, i.e. the
            # level-0 (highest-resolution) reconstruction
            sc = t.mean(spectral_convergence(x_target, x_out, hps))
            l2_loss = _loss_fn("l2", x_target, x_out, hps)
            l1_loss = _loss_fn("l1", x_target, x_out, hps)
            linf_loss = _loss_fn("linf", x_target, x_out, hps)
quantiser_metrics = average_metrics(quantiser_metrics)
metrics.update(dict(
recons_loss=recons_loss,
spectral_loss=spec_loss,
multispectral_loss=multispec_loss,
spectral_convergence=sc,
l2_loss=l2_loss,
l1_loss=l1_loss,
linf_loss=linf_loss,
commit_loss=commit_loss,
**quantiser_metrics))
        for key, val in metrics.items():
            metrics[key] = val.detach()  # detach so logged values don't retain the graph
return x_out, loss, metrics
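
# A minimal end-to-end sketch (assumed hyperparameters and shapes; nothing here
# comes from this file, and running it requires the jukebox package). It
# round-trips a batch of fake audio through encode/decode; `width`, `depth`,
# and `m_conv` are assumed conv-block settings forwarded via **block_kwargs.
if __name__ == '__main__':
    vqvae = VQVAE(input_shape=(8192, 1), levels=3, downs_t=(3, 2, 2),
                  strides_t=(2, 2, 2), emb_width=64, l_bins=2048, mu=0.99,
                  commit=0.02, spectral=0.0, multispectral=1.0,
                  width=32, depth=4, m_conv=1.0)
    x = t.rand(2, 8192, 1) * 2 - 1             # NTC audio in [-1, 1]
    zs = vqvae.encode(x)                       # one code tensor per level
    x_hat = vqvae.decode(zs)                   # reconstruction from level-0 codes
    print([z.shape for z in zs], x_hat.shape)  # (2,1024), (2,256), (2,64); (2,8192,1)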