Integrate new diffusion network
Browse files- api.py +24 -25
- models/arch_util.py +13 -8
- models/diffusion_decoder.py +155 -360
api.py
CHANGED
@@ -49,6 +49,15 @@ def download_models():
|
|
49 |
print('Done.')
|
50 |
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
|
53 |
"""
|
54 |
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
@@ -96,26 +105,25 @@ def fix_autoregressive_output(codes, stop_token):
|
|
96 |
return codes
|
97 |
|
98 |
|
99 |
-
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes,
|
100 |
"""
|
101 |
-
Uses the specified diffusion model
|
102 |
"""
|
103 |
with torch.no_grad():
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
mel = torch.nn.functional.pad(mel_codes, (0, gap))
|
111 |
|
112 |
-
output_shape = (
|
113 |
-
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes,
|
114 |
|
115 |
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
|
116 |
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
117 |
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
118 |
-
return denormalize_tacotron_mel(mel)[:,:,:
|
119 |
|
120 |
|
121 |
class TextToSpeech:
|
@@ -137,12 +145,9 @@ class TextToSpeech:
|
|
137 |
use_xformers=True).cpu().eval()
|
138 |
self.clip.load_state_dict(torch.load('.models/clip.pth'))
|
139 |
|
140 |
-
self.diffusion = DiffusionTts(model_channels=
|
141 |
-
|
142 |
-
|
143 |
-
dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2,
|
144 |
-
time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
|
145 |
-
conditioning_expansion=1).cpu().eval()
|
146 |
self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
|
147 |
|
148 |
self.vocoder = UnivNetGenerator().cpu()
|
@@ -164,12 +169,6 @@ class TextToSpeech:
|
|
164 |
for vs in voice_samples:
|
165 |
conds.append(load_conditioning(vs))
|
166 |
conds = torch.stack(conds, dim=1)
|
167 |
-
cond_diffusion = voice_samples[0].cuda()
|
168 |
-
# The diffusion model expects = 88200 conditioning samples.
|
169 |
-
if cond_diffusion.shape[-1] < 88200:
|
170 |
-
cond_diffusion = F.pad(cond_diffusion, (0, 88200-cond_diffusion.shape[-1]))
|
171 |
-
else:
|
172 |
-
cond_diffusion = cond_diffusion[:, :88200]
|
173 |
|
174 |
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
175 |
|
@@ -211,7 +210,7 @@ class TextToSpeech:
|
|
211 |
self.vocoder = self.vocoder.cuda()
|
212 |
for b in range(best_results.shape[0]):
|
213 |
code = best_results[b].unsqueeze(0)
|
214 |
-
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code,
|
215 |
wav = self.vocoder.inference(mel)
|
216 |
wav_candidates.append(wav.cpu())
|
217 |
self.diffusion = self.diffusion.cpu()
|
|
|
49 |
print('Done.')
|
50 |
|
51 |
|
52 |
+
def pad_or_truncate(t, length):
|
53 |
+
if t.shape[-1] == length:
|
54 |
+
return t
|
55 |
+
elif t.shape[-1] < length:
|
56 |
+
return F.pad(t, (0, length-t.shape[-1]))
|
57 |
+
else:
|
58 |
+
return t[..., :length]
|
59 |
+
|
60 |
+
|
61 |
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
|
62 |
"""
|
63 |
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
|
|
105 |
return codes
|
106 |
|
107 |
|
108 |
+
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
|
109 |
"""
|
110 |
+
Uses the specified diffusion model to convert discrete codes into a spectrogram.
|
111 |
"""
|
112 |
with torch.no_grad():
|
113 |
+
cond_mels = []
|
114 |
+
for sample in conditioning_samples:
|
115 |
+
sample = pad_or_truncate(sample, 102400)
|
116 |
+
cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
|
117 |
+
cond_mels.append(cond_mel)
|
118 |
+
cond_mels = torch.stack(cond_mels, dim=1)
|
|
|
119 |
|
120 |
+
output_shape = (mel_codes.shape[0], 100, mel_codes.shape[-1]*4)
|
121 |
+
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, False)
|
122 |
|
123 |
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
|
124 |
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
125 |
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
126 |
+
return denormalize_tacotron_mel(mel)[:,:,:mel_codes.shape[-1]*4]
|
127 |
|
128 |
|
129 |
class TextToSpeech:
|
|
|
145 |
use_xformers=True).cpu().eval()
|
146 |
self.clip.load_state_dict(torch.load('.models/clip.pth'))
|
147 |
|
148 |
+
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
149 |
+
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
150 |
+
layer_drop=0, unconditioned_percentage=0).cpu().eval()
|
|
|
|
|
|
|
151 |
self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
|
152 |
|
153 |
self.vocoder = UnivNetGenerator().cpu()
|
|
|
169 |
for vs in voice_samples:
|
170 |
conds.append(load_conditioning(vs))
|
171 |
conds = torch.stack(conds, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
174 |
|
|
|
210 |
self.vocoder = self.vocoder.cuda()
|
211 |
for b in range(best_results.shape[0]):
|
212 |
code = best_results[b].unsqueeze(0)
|
213 |
+
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature)
|
214 |
wav = self.vocoder.inference(mel)
|
215 |
wav_candidates.append(wav.cpu())
|
216 |
self.diffusion = self.diffusion.cpu()
|
models/arch_util.py
CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
|
|
6 |
import torch.nn.functional as F
|
7 |
import torchaudio
|
8 |
from x_transformers import ContinuousTransformerWrapper
|
|
|
9 |
|
10 |
|
11 |
def zero_module(module):
|
@@ -49,7 +50,7 @@ class QKVAttentionLegacy(nn.Module):
|
|
49 |
super().__init__()
|
50 |
self.n_heads = n_heads
|
51 |
|
52 |
-
def forward(self, qkv, mask=None):
|
53 |
"""
|
54 |
Apply QKV attention.
|
55 |
|
@@ -64,6 +65,8 @@ class QKVAttentionLegacy(nn.Module):
|
|
64 |
weight = torch.einsum(
|
65 |
"bct,bcs->bts", q * scale, k * scale
|
66 |
) # More stable with f16 than dividing afterwards
|
|
|
|
|
67 |
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
68 |
if mask is not None:
|
69 |
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
|
@@ -87,9 +90,12 @@ class AttentionBlock(nn.Module):
|
|
87 |
channels,
|
88 |
num_heads=1,
|
89 |
num_head_channels=-1,
|
|
|
|
|
90 |
):
|
91 |
super().__init__()
|
92 |
self.channels = channels
|
|
|
93 |
if num_head_channels == -1:
|
94 |
self.num_heads = num_heads
|
95 |
else:
|
@@ -99,21 +105,20 @@ class AttentionBlock(nn.Module):
|
|
99 |
self.num_heads = channels // num_head_channels
|
100 |
self.norm = normalization(channels)
|
101 |
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
|
|
102 |
self.attention = QKVAttentionLegacy(self.num_heads)
|
103 |
|
104 |
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
105 |
-
|
106 |
-
|
107 |
-
if mask is not None:
|
108 |
-
return self._forward(x, mask)
|
109 |
else:
|
110 |
-
|
111 |
|
112 |
-
def
|
113 |
b, c, *spatial = x.shape
|
114 |
x = x.reshape(b, c, -1)
|
115 |
qkv = self.qkv(self.norm(x))
|
116 |
-
h = self.attention(qkv, mask)
|
117 |
h = self.proj_out(h)
|
118 |
return (x + h).reshape(b, c, *spatial)
|
119 |
|
|
|
6 |
import torch.nn.functional as F
|
7 |
import torchaudio
|
8 |
from x_transformers import ContinuousTransformerWrapper
|
9 |
+
from x_transformers.x_transformers import RelativePositionBias
|
10 |
|
11 |
|
12 |
def zero_module(module):
|
|
|
50 |
super().__init__()
|
51 |
self.n_heads = n_heads
|
52 |
|
53 |
+
def forward(self, qkv, mask=None, rel_pos=None):
|
54 |
"""
|
55 |
Apply QKV attention.
|
56 |
|
|
|
65 |
weight = torch.einsum(
|
66 |
"bct,bcs->bts", q * scale, k * scale
|
67 |
) # More stable with f16 than dividing afterwards
|
68 |
+
if rel_pos is not None:
|
69 |
+
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
|
70 |
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
71 |
if mask is not None:
|
72 |
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
|
|
|
90 |
channels,
|
91 |
num_heads=1,
|
92 |
num_head_channels=-1,
|
93 |
+
do_checkpoint=True,
|
94 |
+
relative_pos_embeddings=False,
|
95 |
):
|
96 |
super().__init__()
|
97 |
self.channels = channels
|
98 |
+
self.do_checkpoint = do_checkpoint
|
99 |
if num_head_channels == -1:
|
100 |
self.num_heads = num_heads
|
101 |
else:
|
|
|
105 |
self.num_heads = channels // num_head_channels
|
106 |
self.norm = normalization(channels)
|
107 |
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
108 |
+
# split heads before split qkv
|
109 |
self.attention = QKVAttentionLegacy(self.num_heads)
|
110 |
|
111 |
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
112 |
+
if relative_pos_embeddings:
|
113 |
+
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
|
|
|
|
114 |
else:
|
115 |
+
self.relative_pos_embeddings = None
|
116 |
|
117 |
+
def forward(self, x, mask=None):
|
118 |
b, c, *spatial = x.shape
|
119 |
x = x.reshape(b, c, -1)
|
120 |
qkv = self.qkv(self.norm(x))
|
121 |
+
h = self.attention(qkv, mask, self.relative_pos_embeddings)
|
122 |
h = self.proj_out(h)
|
123 |
return (x + h).reshape(b, c, *spatial)
|
124 |
|
models/diffusion_decoder.py
CHANGED
@@ -1,22 +1,13 @@
|
|
1 |
-
"""
|
2 |
-
This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal
|
3 |
-
and an audio conditioning input. It has also been simplified somewhat.
|
4 |
-
Credit: https://github.com/openai/improved-diffusion
|
5 |
-
"""
|
6 |
-
import functools
|
7 |
import math
|
|
|
8 |
from abc import abstractmethod
|
9 |
|
10 |
import torch
|
11 |
import torch.nn as nn
|
12 |
import torch.nn.functional as F
|
13 |
from torch import autocast
|
14 |
-
from torch.nn import Linear
|
15 |
-
from torch.utils.checkpoint import checkpoint
|
16 |
-
from x_transformers import ContinuousTransformerWrapper, Encoder
|
17 |
|
18 |
-
from models.arch_util import normalization,
|
19 |
-
CheckpointedXTransformerEncoder
|
20 |
|
21 |
|
22 |
def is_latent(t):
|
@@ -27,13 +18,6 @@ def is_sequence(t):
|
|
27 |
return t.dtype == torch.long
|
28 |
|
29 |
|
30 |
-
def ceil_multiple(base, multiple):
|
31 |
-
res = base % multiple
|
32 |
-
if res == 0:
|
33 |
-
return base
|
34 |
-
return base + (multiple - res)
|
35 |
-
|
36 |
-
|
37 |
def timestep_embedding(timesteps, dim, max_period=10000):
|
38 |
"""
|
39 |
Create sinusoidal timestep embeddings.
|
@@ -56,10 +40,6 @@ def timestep_embedding(timesteps, dim, max_period=10000):
|
|
56 |
|
57 |
|
58 |
class TimestepBlock(nn.Module):
|
59 |
-
"""
|
60 |
-
Any module where forward() takes timestep embeddings as a second argument.
|
61 |
-
"""
|
62 |
-
|
63 |
@abstractmethod
|
64 |
def forward(self, x, emb):
|
65 |
"""
|
@@ -68,11 +48,6 @@ class TimestepBlock(nn.Module):
|
|
68 |
|
69 |
|
70 |
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
71 |
-
"""
|
72 |
-
A sequential module that passes timestep embeddings to the children that
|
73 |
-
support it as an extra input.
|
74 |
-
"""
|
75 |
-
|
76 |
def forward(self, x, emb):
|
77 |
for layer in self:
|
78 |
if isinstance(layer, TimestepBlock):
|
@@ -89,6 +64,7 @@ class ResBlock(TimestepBlock):
|
|
89 |
emb_channels,
|
90 |
dropout,
|
91 |
out_channels=None,
|
|
|
92 |
kernel_size=3,
|
93 |
efficient_config=True,
|
94 |
use_scale_shift_norm=False,
|
@@ -111,7 +87,7 @@ class ResBlock(TimestepBlock):
|
|
111 |
|
112 |
self.emb_layers = nn.Sequential(
|
113 |
nn.SiLU(),
|
114 |
-
Linear(
|
115 |
emb_channels,
|
116 |
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
117 |
),
|
@@ -120,9 +96,7 @@ class ResBlock(TimestepBlock):
|
|
120 |
normalization(self.out_channels),
|
121 |
nn.SiLU(),
|
122 |
nn.Dropout(p=dropout),
|
123 |
-
|
124 |
-
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
|
125 |
-
),
|
126 |
)
|
127 |
|
128 |
if self.out_channels == channels:
|
@@ -131,18 +105,6 @@ class ResBlock(TimestepBlock):
|
|
131 |
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
|
132 |
|
133 |
def forward(self, x, emb):
|
134 |
-
"""
|
135 |
-
Apply the block to a Tensor, conditioned on a timestep embedding.
|
136 |
-
|
137 |
-
:param x: an [N x C x ...] Tensor of features.
|
138 |
-
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
139 |
-
:return: an [N x C x ...] Tensor of outputs.
|
140 |
-
"""
|
141 |
-
return checkpoint(
|
142 |
-
self._forward, x, emb
|
143 |
-
)
|
144 |
-
|
145 |
-
def _forward(self, x, emb):
|
146 |
h = self.in_layers(x)
|
147 |
emb_out = self.emb_layers(emb).type(h.dtype)
|
148 |
while len(emb_out.shape) < len(h.shape):
|
@@ -158,372 +120,205 @@ class ResBlock(TimestepBlock):
|
|
158 |
return self.skip_connection(x) + h
|
159 |
|
160 |
|
161 |
-
class
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
GPT-style model.
|
167 |
-
|
168 |
-
:param in_channels: channels in the input Tensor.
|
169 |
-
:param in_latent_channels: channels from the input latent.
|
170 |
-
:param model_channels: base channel count for the model.
|
171 |
-
:param out_channels: channels in the output Tensor.
|
172 |
-
:param num_res_blocks: number of residual blocks per downsample.
|
173 |
-
:param attention_resolutions: a collection of downsample rates at which
|
174 |
-
attention will take place. May be a set, list, or tuple.
|
175 |
-
For example, if this contains 4, then at 4x downsampling, attention
|
176 |
-
will be used.
|
177 |
-
:param dropout: the dropout probability.
|
178 |
-
:param channel_mult: channel multiplier for each level of the UNet.
|
179 |
-
:param conv_resample: if True, use learned convolutions for upsampling and
|
180 |
-
downsampling.
|
181 |
-
:param num_heads: the number of attention heads in each attention layer.
|
182 |
-
:param num_heads_channels: if specified, ignore num_heads and instead use
|
183 |
-
a fixed channel width per attention head.
|
184 |
-
:param num_heads_upsample: works with num_heads to set a different number
|
185 |
-
of heads for upsampling. Deprecated.
|
186 |
-
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
187 |
-
:param resblock_updown: use residual blocks for up/downsampling.
|
188 |
-
:param use_new_attention_order: use a different attention pattern for potentially
|
189 |
-
increased efficiency.
|
190 |
-
"""
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
def __init__(
|
193 |
self,
|
194 |
-
model_channels,
|
195 |
-
|
196 |
-
|
|
|
197 |
in_tokens=8193,
|
198 |
-
|
199 |
-
conditioning_expansion=4,
|
200 |
-
out_channels=2, # mean and variance
|
201 |
dropout=0,
|
202 |
-
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
203 |
-
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
|
204 |
-
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
205 |
-
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
|
206 |
-
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
|
207 |
-
token_conditioning_resolutions=(1,16,),
|
208 |
-
attention_resolutions=(512,1024,2048),
|
209 |
-
conv_resample=True,
|
210 |
use_fp16=False,
|
211 |
-
num_heads=
|
212 |
-
num_head_channels=-1,
|
213 |
-
num_heads_upsample=-1,
|
214 |
-
kernel_size=3,
|
215 |
-
scale_factor=2,
|
216 |
-
time_embed_dim_multiplier=4,
|
217 |
-
freeze_main_net=False,
|
218 |
-
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
|
219 |
-
use_scale_shift_norm=True,
|
220 |
# Parameters for regularization.
|
|
|
221 |
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
222 |
-
# Parameters for super-sampling.
|
223 |
-
super_sampling=False,
|
224 |
-
super_sampling_max_noising_factor=.1,
|
225 |
):
|
226 |
super().__init__()
|
227 |
|
228 |
-
if num_heads_upsample == -1:
|
229 |
-
num_heads_upsample = num_heads
|
230 |
-
|
231 |
-
if super_sampling:
|
232 |
-
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
|
233 |
self.in_channels = in_channels
|
234 |
self.model_channels = model_channels
|
235 |
self.out_channels = out_channels
|
236 |
-
self.attention_resolutions = attention_resolutions
|
237 |
self.dropout = dropout
|
238 |
-
self.channel_mult = channel_mult
|
239 |
-
self.conv_resample = conv_resample
|
240 |
self.num_heads = num_heads
|
241 |
-
self.num_head_channels = num_head_channels
|
242 |
-
self.num_heads_upsample = num_heads_upsample
|
243 |
-
self.super_sampling_enabled = super_sampling
|
244 |
-
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
245 |
self.unconditioned_percentage = unconditioned_percentage
|
246 |
self.enable_fp16 = use_fp16
|
247 |
-
self.
|
248 |
-
self.freeze_main_net = freeze_main_net
|
249 |
-
padding = 1 if kernel_size == 3 else 2
|
250 |
-
down_kernel = 1 if efficient_convs else 3
|
251 |
|
252 |
-
|
253 |
self.time_embed = nn.Sequential(
|
254 |
-
Linear(model_channels,
|
255 |
nn.SiLU(),
|
256 |
-
Linear(
|
257 |
)
|
258 |
|
259 |
-
conditioning_dim = model_channels * conditioning_dim_factor
|
260 |
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
261 |
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
262 |
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
263 |
# transformer network.
|
|
|
264 |
self.code_converter = nn.Sequential(
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
max_seq_len=-1,
|
269 |
-
use_pos_emb=False,
|
270 |
-
attn_layers=Encoder(
|
271 |
-
dim=conditioning_dim,
|
272 |
-
depth=3,
|
273 |
-
heads=num_heads,
|
274 |
-
ff_dropout=dropout,
|
275 |
-
attn_dropout=dropout,
|
276 |
-
use_rmsnorm=True,
|
277 |
-
ff_glu=True,
|
278 |
-
rotary_emb_dim=True,
|
279 |
-
)
|
280 |
-
))
|
281 |
-
self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
|
282 |
-
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
|
283 |
-
if in_channels > 60: # It's a spectrogram.
|
284 |
-
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
|
285 |
-
CheckpointedXTransformerEncoder(
|
286 |
-
needs_permute=True,
|
287 |
-
max_seq_len=-1,
|
288 |
-
use_pos_emb=False,
|
289 |
-
attn_layers=Encoder(
|
290 |
-
dim=conditioning_dim,
|
291 |
-
depth=4,
|
292 |
-
heads=num_heads,
|
293 |
-
ff_dropout=dropout,
|
294 |
-
attn_dropout=dropout,
|
295 |
-
use_rmsnorm=True,
|
296 |
-
ff_glu=True,
|
297 |
-
rotary_emb_dim=True,
|
298 |
-
)
|
299 |
-
))
|
300 |
-
else:
|
301 |
-
self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
|
302 |
-
attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
|
303 |
-
self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
|
304 |
-
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
|
305 |
-
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
306 |
-
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
307 |
-
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
308 |
-
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
309 |
-
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
310 |
-
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
311 |
)
|
312 |
-
self.
|
313 |
-
|
314 |
-
self.
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
self.
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
|
328 |
-
if ds in token_conditioning_resolutions:
|
329 |
-
token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1)
|
330 |
-
token_conditioning_block.weight.data *= .02
|
331 |
-
self.input_blocks.append(token_conditioning_block)
|
332 |
-
token_conditioning_blocks.append(token_conditioning_block)
|
333 |
-
|
334 |
-
for _ in range(num_blocks):
|
335 |
-
layers = [
|
336 |
-
ResBlock(
|
337 |
-
ch,
|
338 |
-
time_embed_dim,
|
339 |
-
dropout,
|
340 |
-
out_channels=int(mult * model_channels),
|
341 |
-
kernel_size=kernel_size,
|
342 |
-
efficient_config=efficient_convs,
|
343 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
344 |
-
)
|
345 |
-
]
|
346 |
-
ch = int(mult * model_channels)
|
347 |
-
if ds in attention_resolutions:
|
348 |
-
layers.append(
|
349 |
-
AttentionBlock(
|
350 |
-
ch,
|
351 |
-
num_heads=num_heads,
|
352 |
-
num_head_channels=num_head_channels,
|
353 |
-
)
|
354 |
-
)
|
355 |
-
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
356 |
-
self._feature_size += ch
|
357 |
-
input_block_chans.append(ch)
|
358 |
-
if level != len(channel_mult) - 1:
|
359 |
-
out_ch = ch
|
360 |
-
self.input_blocks.append(
|
361 |
-
TimestepEmbedSequential(
|
362 |
-
Downsample(
|
363 |
-
ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
|
364 |
-
)
|
365 |
-
)
|
366 |
-
)
|
367 |
-
ch = out_ch
|
368 |
-
input_block_chans.append(ch)
|
369 |
-
ds *= 2
|
370 |
-
self._feature_size += ch
|
371 |
-
|
372 |
-
self.middle_block = TimestepEmbedSequential(
|
373 |
-
ResBlock(
|
374 |
-
ch,
|
375 |
-
time_embed_dim,
|
376 |
-
dropout,
|
377 |
-
kernel_size=kernel_size,
|
378 |
-
efficient_config=efficient_convs,
|
379 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
380 |
-
),
|
381 |
-
AttentionBlock(
|
382 |
-
ch,
|
383 |
-
num_heads=num_heads,
|
384 |
-
num_head_channels=num_head_channels,
|
385 |
-
),
|
386 |
-
ResBlock(
|
387 |
-
ch,
|
388 |
-
time_embed_dim,
|
389 |
-
dropout,
|
390 |
-
kernel_size=kernel_size,
|
391 |
-
efficient_config=efficient_convs,
|
392 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
393 |
-
),
|
394 |
)
|
395 |
-
self.
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
ich = input_block_chans.pop()
|
401 |
-
layers = [
|
402 |
-
ResBlock(
|
403 |
-
ch + ich,
|
404 |
-
time_embed_dim,
|
405 |
-
dropout,
|
406 |
-
out_channels=int(model_channels * mult),
|
407 |
-
kernel_size=kernel_size,
|
408 |
-
efficient_config=efficient_convs,
|
409 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
410 |
-
)
|
411 |
-
]
|
412 |
-
ch = int(model_channels * mult)
|
413 |
-
if ds in attention_resolutions:
|
414 |
-
layers.append(
|
415 |
-
AttentionBlock(
|
416 |
-
ch,
|
417 |
-
num_heads=num_heads_upsample,
|
418 |
-
num_head_channels=num_head_channels,
|
419 |
-
)
|
420 |
-
)
|
421 |
-
if level and i == num_blocks:
|
422 |
-
out_ch = ch
|
423 |
-
layers.append(
|
424 |
-
Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
|
425 |
-
)
|
426 |
-
ds //= 2
|
427 |
-
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
428 |
-
self._feature_size += ch
|
429 |
|
430 |
self.out = nn.Sequential(
|
431 |
-
normalization(
|
432 |
nn.SiLU(),
|
433 |
-
|
434 |
)
|
435 |
|
436 |
-
def
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
aligned_conditioning = torch.cat([aligned_conditioning,
|
448 |
-
self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
|
449 |
-
else:
|
450 |
-
aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
|
451 |
-
return x, aligned_conditioning
|
452 |
-
|
453 |
-
def timestep_independent(self, aligned_conditioning, conditioning_input):
|
454 |
# Shuffle aligned_latent to BxCxS format
|
455 |
if is_latent(aligned_conditioning):
|
456 |
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
457 |
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
|
470 |
-
def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False):
|
471 |
-
assert x.shape[-1] % self.alignment_size == 0
|
472 |
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
code_emb = precomputed_aligned_embeddings
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
first = True
|
484 |
-
time_emb = time_emb.float()
|
485 |
-
h = x
|
486 |
-
hs = []
|
487 |
-
for k, module in enumerate(self.input_blocks):
|
488 |
-
if isinstance(module, nn.Conv1d):
|
489 |
-
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
|
490 |
-
h = h + h_tok
|
491 |
else:
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
|
|
|
|
|
|
505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
506 |
return out
|
507 |
|
508 |
|
509 |
if __name__ == '__main__':
|
510 |
-
clip = torch.randn(2,
|
511 |
-
aligned_latent = torch.randn(2,388,
|
512 |
-
aligned_sequence = torch.randint(0,8192,(2,
|
513 |
-
cond = torch.randn(2,
|
514 |
ts = torch.LongTensor([600, 600])
|
515 |
-
model = DiffusionTts(
|
516 |
-
channel_mult=[1,1.5,2, 3, 4, 6, 8],
|
517 |
-
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
|
518 |
-
token_conditioning_resolutions=[1,4,16,64],
|
519 |
-
attention_resolutions=[],
|
520 |
-
num_heads=8,
|
521 |
-
kernel_size=3,
|
522 |
-
scale_factor=2,
|
523 |
-
time_embed_dim_multiplier=4,
|
524 |
-
super_sampling=False,
|
525 |
-
efficient_convs=False)
|
526 |
# Test with latent aligned conditioning
|
527 |
-
o = model(clip, ts, aligned_latent, cond)
|
528 |
# Test with sequence aligned conditioning
|
529 |
o = model(clip, ts, aligned_sequence, cond)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import math
|
2 |
+
import random
|
3 |
from abc import abstractmethod
|
4 |
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
from torch import autocast
|
|
|
|
|
|
|
9 |
|
10 |
+
from models.arch_util import normalization, AttentionBlock
|
|
|
11 |
|
12 |
|
13 |
def is_latent(t):
|
|
|
18 |
return t.dtype == torch.long
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def timestep_embedding(timesteps, dim, max_period=10000):
|
22 |
"""
|
23 |
Create sinusoidal timestep embeddings.
|
|
|
40 |
|
41 |
|
42 |
class TimestepBlock(nn.Module):
|
|
|
|
|
|
|
|
|
43 |
@abstractmethod
|
44 |
def forward(self, x, emb):
|
45 |
"""
|
|
|
48 |
|
49 |
|
50 |
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|
|
|
|
|
|
|
|
|
|
51 |
def forward(self, x, emb):
|
52 |
for layer in self:
|
53 |
if isinstance(layer, TimestepBlock):
|
|
|
64 |
emb_channels,
|
65 |
dropout,
|
66 |
out_channels=None,
|
67 |
+
dims=2,
|
68 |
kernel_size=3,
|
69 |
efficient_config=True,
|
70 |
use_scale_shift_norm=False,
|
|
|
87 |
|
88 |
self.emb_layers = nn.Sequential(
|
89 |
nn.SiLU(),
|
90 |
+
nn.Linear(
|
91 |
emb_channels,
|
92 |
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
93 |
),
|
|
|
96 |
normalization(self.out_channels),
|
97 |
nn.SiLU(),
|
98 |
nn.Dropout(p=dropout),
|
99 |
+
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
|
|
|
|
|
100 |
)
|
101 |
|
102 |
if self.out_channels == channels:
|
|
|
105 |
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
|
106 |
|
107 |
def forward(self, x, emb):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
h = self.in_layers(x)
|
109 |
emb_out = self.emb_layers(emb).type(h.dtype)
|
110 |
while len(emb_out.shape) < len(h.shape):
|
|
|
120 |
return self.skip_connection(x) + h
|
121 |
|
122 |
|
123 |
+
class DiffusionLayer(TimestepBlock):
|
124 |
+
def __init__(self, model_channels, dropout, num_heads):
|
125 |
+
super().__init__()
|
126 |
+
self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
|
127 |
+
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
+
def forward(self, x, time_emb):
|
130 |
+
y = self.resblk(x, time_emb)
|
131 |
+
return self.attn(y)
|
132 |
+
|
133 |
+
|
134 |
+
class DiffusionTts(nn.Module):
|
135 |
def __init__(
|
136 |
self,
|
137 |
+
model_channels=512,
|
138 |
+
num_layers=8,
|
139 |
+
in_channels=100,
|
140 |
+
in_latent_channels=512,
|
141 |
in_tokens=8193,
|
142 |
+
out_channels=200, # mean and variance
|
|
|
|
|
143 |
dropout=0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
use_fp16=False,
|
145 |
+
num_heads=16,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
# Parameters for regularization.
|
147 |
+
layer_drop=.1,
|
148 |
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
|
|
|
|
|
|
149 |
):
|
150 |
super().__init__()
|
151 |
|
|
|
|
|
|
|
|
|
|
|
152 |
self.in_channels = in_channels
|
153 |
self.model_channels = model_channels
|
154 |
self.out_channels = out_channels
|
|
|
155 |
self.dropout = dropout
|
|
|
|
|
156 |
self.num_heads = num_heads
|
|
|
|
|
|
|
|
|
157 |
self.unconditioned_percentage = unconditioned_percentage
|
158 |
self.enable_fp16 = use_fp16
|
159 |
+
self.layer_drop = layer_drop
|
|
|
|
|
|
|
160 |
|
161 |
+
self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
|
162 |
self.time_embed = nn.Sequential(
|
163 |
+
nn.Linear(model_channels, model_channels),
|
164 |
nn.SiLU(),
|
165 |
+
nn.Linear(model_channels, model_channels),
|
166 |
)
|
167 |
|
|
|
168 |
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
169 |
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
170 |
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
171 |
# transformer network.
|
172 |
+
self.code_embedding = nn.Embedding(in_tokens, model_channels)
|
173 |
self.code_converter = nn.Sequential(
|
174 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
175 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
176 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
)
|
178 |
+
self.code_norm = normalization(model_channels)
|
179 |
+
self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
|
180 |
+
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
181 |
+
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
182 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
183 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
184 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
185 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
186 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
|
187 |
+
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
|
188 |
+
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
189 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
190 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
191 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
)
|
193 |
+
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
|
194 |
+
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
195 |
+
|
196 |
+
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
|
197 |
+
[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
self.out = nn.Sequential(
|
200 |
+
normalization(model_channels),
|
201 |
nn.SiLU(),
|
202 |
+
nn.Conv1d(model_channels, out_channels, 3, padding=1),
|
203 |
)
|
204 |
|
205 |
+
def get_grad_norm_parameter_groups(self):
|
206 |
+
groups = {
|
207 |
+
'minicoder': list(self.contextual_embedder.parameters()),
|
208 |
+
'layers': list(self.layers.parameters()),
|
209 |
+
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
|
210 |
+
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
|
211 |
+
'time_embed': list(self.time_embed.parameters()),
|
212 |
+
}
|
213 |
+
return groups
|
214 |
+
|
215 |
+
def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
# Shuffle aligned_latent to BxCxS format
|
217 |
if is_latent(aligned_conditioning):
|
218 |
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
219 |
|
220 |
+
# Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
|
221 |
+
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
|
222 |
+
conditioning_input.shape) == 3 else conditioning_input
|
223 |
+
conds = []
|
224 |
+
for j in range(speech_conditioning_input.shape[1]):
|
225 |
+
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
226 |
+
conds = torch.cat(conds, dim=-1)
|
227 |
+
cond_emb = conds.mean(dim=-1)
|
228 |
+
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
229 |
+
if is_latent(aligned_conditioning):
|
230 |
+
code_emb = self.latent_converter(aligned_conditioning)
|
231 |
+
else:
|
232 |
+
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
233 |
+
code_emb = self.code_converter(code_emb)
|
234 |
+
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
235 |
+
|
236 |
+
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
237 |
+
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
238 |
+
if self.training and self.unconditioned_percentage > 0:
|
239 |
+
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
240 |
+
device=code_emb.device) < self.unconditioned_percentage
|
241 |
+
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
242 |
+
code_emb)
|
243 |
+
expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
|
244 |
+
|
245 |
+
if not return_code_pred:
|
246 |
+
return expanded_code_emb
|
247 |
+
else:
|
248 |
+
mel_pred = self.mel_head(expanded_code_emb)
|
249 |
+
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
|
250 |
+
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
251 |
+
return expanded_code_emb, mel_pred
|
252 |
|
|
|
|
|
253 |
|
254 |
+
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
255 |
+
"""
|
256 |
+
Apply the model to an input batch.
|
257 |
+
|
258 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
259 |
+
:param timesteps: a 1-D batch of timesteps.
|
260 |
+
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
261 |
+
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
262 |
+
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
|
263 |
+
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
264 |
+
:return: an [N x C x ...] Tensor of outputs.
|
265 |
+
"""
|
266 |
+
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
|
267 |
+
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
268 |
+
|
269 |
+
unused_params = []
|
270 |
+
if conditioning_free:
|
271 |
+
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
272 |
+
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
273 |
+
unused_params.extend(list(self.latent_converter.parameters()))
|
274 |
+
else:
|
275 |
+
if precomputed_aligned_embeddings is not None:
|
276 |
code_emb = precomputed_aligned_embeddings
|
277 |
+
else:
|
278 |
+
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
|
279 |
+
if is_latent(aligned_conditioning):
|
280 |
+
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
else:
|
282 |
+
unused_params.extend(list(self.latent_converter.parameters()))
|
283 |
+
unused_params.append(self.unconditioned_embedding)
|
284 |
+
|
285 |
+
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
286 |
+
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
287 |
+
x = self.inp_block(x)
|
288 |
+
x = torch.cat([x, code_emb], dim=1)
|
289 |
+
x = self.integrating_conv(x)
|
290 |
+
for i, lyr in enumerate(self.layers):
|
291 |
+
# Do layer drop where applicable. Do not drop first and last layers.
|
292 |
+
if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
|
293 |
+
unused_params.extend(list(lyr.parameters()))
|
294 |
+
else:
|
295 |
+
# First and last blocks will have autocast disabled for improved precision.
|
296 |
+
with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
|
297 |
+
x = lyr(x, time_emb)
|
298 |
|
299 |
+
x = x.float()
|
300 |
+
out = self.out(x)
|
301 |
+
|
302 |
+
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
303 |
+
extraneous_addition = 0
|
304 |
+
for p in unused_params:
|
305 |
+
extraneous_addition = extraneous_addition + p.mean()
|
306 |
+
out = out + extraneous_addition * 0
|
307 |
+
|
308 |
+
if return_code_pred:
|
309 |
+
return out, mel_pred
|
310 |
return out
|
311 |
|
312 |
|
313 |
if __name__ == '__main__':
|
314 |
+
clip = torch.randn(2, 100, 400)
|
315 |
+
aligned_latent = torch.randn(2,388,512)
|
316 |
+
aligned_sequence = torch.randint(0,8192,(2,100))
|
317 |
+
cond = torch.randn(2, 100, 400)
|
318 |
ts = torch.LongTensor([600, 600])
|
319 |
+
model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
# Test with latent aligned conditioning
|
321 |
+
#o = model(clip, ts, aligned_latent, cond)
|
322 |
# Test with sequence aligned conditioning
|
323 |
o = model(clip, ts, aligned_sequence, cond)
|
324 |
+
|