|
import os |
|
import json |
|
|
|
import torch |
|
import numpy as np |
|
|
|
import audioldm.hifigan as hifigan |
|
|
|
# HiFi-GAN generator hyper-parameters for 16 kHz audio with 64 mel bins.
# Note: the product of `upsample_rates` (5 * 4 * 2 * 2 * 2 = 160) equals
# `hop_size`, presumably so that one mel frame expands to `hop_size` samples.
# The training / distributed entries (batch size, optimizer, dist_config, ...)
# are not read anywhere in this chunk; they appear to be kept so the dict
# mirrors a complete HiFi-GAN config — confirm before pruning.
HIFIGAN_16K_64 = dict(
    resblock="1",
    num_gpus=6,
    batch_size=16,
    learning_rate=0.0002,
    adam_b1=0.8,
    adam_b2=0.99,
    lr_decay=0.999,
    seed=1234,
    upsample_rates=[5, 4, 2, 2, 2],
    upsample_kernel_sizes=[16, 16, 8, 4, 4],
    upsample_initial_channel=1024,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    segment_size=8192,
    num_mels=64,
    num_freq=1025,
    n_fft=1024,
    hop_size=160,
    win_size=1024,
    sampling_rate=16000,
    fmin=0,
    fmax=8000,
    fmax_for_loss=None,
    num_workers=4,
    dist_config=dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        world_size=1,
    ),
)
|
|
|
|
|
def get_available_checkpoint_keys(model, ckpt, map_location="cpu"):
    """Return the subset of a checkpoint's weights that fit *model*.

    Loads ``torch.load(ckpt)["state_dict"]`` and keeps only the entries whose
    key also exists in ``model.state_dict()`` with an identical tensor shape.
    Every other key is reported with a warning and skipped. The model itself
    is not modified; pass the result to ``model.load_state_dict`` (presumably
    with ``strict=False``, since unmatched keys are dropped).

    Args:
        model: module whose ``state_dict()`` defines the accepted keys/shapes.
        ckpt: checkpoint path (or anything else ``torch.load`` accepts).
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on GPU can still be read on a CPU-only machine;
            ``load_state_dict`` copies the tensors onto the model's device.

    Returns:
        dict mapping each matched key to its checkpoint tensor.

    Raises:
        KeyError: if the loaded checkpoint dict has no ``"state_dict"`` entry.
    """
    print("==> Attemp to reload from %s" % ckpt)
    state_dict = torch.load(ckpt, map_location=map_location)["state_dict"]
    current_state_dict = model.state_dict()

    new_state_dict = {}
    for key, value in state_dict.items():
        current = current_state_dict.get(key)
        # Keep only keys the model knows about AND whose shape agrees; a shape
        # mismatch would make load_state_dict raise.
        if current is not None and current.size() == value.size():
            new_state_dict[key] = value
        else:
            print("==> WARNING: Skipping %s" % key)

    print("%s out of %s keys are matched" % (len(new_state_dict), len(state_dict)))
    return new_state_dict
|
|
|
|
|
def get_param_num(model):
    """Count the scalar elements in all of *model*'s parameters.

    Every tensor yielded by ``model.parameters()`` contributes ``numel()``
    elements. A model without parameters yields 0.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
|
|
|
|
|
def get_vocoder(config, device):
    """Build the 16 kHz / 64-mel HiFi-GAN generator for inference.

    Args:
        config: ignored. The generator is always built from the module-level
            ``HIFIGAN_16K_64`` hyper-parameters regardless of this argument
            (presumably kept so existing call sites keep working).
        device: target passed to ``generator.to``.

    Returns:
        The ``hifigan.Generator`` after ``eval()`` and ``remove_weight_norm()``
        have been applied and it has been moved to *device*. No checkpoint
        weights are loaded here; the caller is responsible for that.
    """
    generator = hifigan.Generator(hifigan.AttrDict(HIFIGAN_16K_64))
    # Inference setup: eval mode, then fold weight norm, then move to device.
    generator.eval()
    generator.remove_weight_norm()
    generator.to(device)
    return generator
|
|
|
|
|
def vocoder_infer(mels, vocoder, lengths=None):
    """Run *vocoder* on *mels* and return 16-bit PCM waveforms.

    Args:
        mels: input passed straight to ``vocoder`` under ``torch.no_grad()``.
        vocoder: callable returning a float waveform tensor with a singleton
            axis 1, which is squeezed away. Assumed to produce samples in
            [-1, 1] (e.g. ``(batch, 1, samples)``) — confirm against the
            generator.
        lengths: optional int. When given, every waveform is truncated to its
            first ``lengths`` samples.

    Returns:
        ``np.ndarray`` of dtype ``int16`` indexed as ``[batch, sample]``.
    """
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    # Scale to the int16 range and saturate. A sample of exactly +1.0 becomes
    # 32768, which does not fit in int16; casting it unclipped overflows
    # (typically wrapping to -32768, an audible click) instead of clipping.
    scaled = wavs.cpu().numpy() * 32768
    wavs = np.clip(scaled, -32768, 32767).astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs
|
|