File size: 2,699 Bytes
24f6ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93


import os

import argparse
import yaml
import torch

from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config


import time

def make_batch_for_text_to_audio(text, batchsize=1):
    """Build the batch tuple used for text-conditioned generation.

    Only the text entry carries real data; the tensors and file names are
    zero/empty placeholders kept so the tuple has the expected layout.

    Args:
        text: Prompt that is repeated for every item of the batch.
        batchsize: Number of items in the batch. Values below 1 are clamped
            to 1 and a warning is printed.

    Returns:
        A tuple ``(fbank, stft, None, fname, waveform, text)`` where ``text``
        is a list of ``batchsize`` copies of the prompt.
    """
    # Validate before anything is sized from batchsize; otherwise a value
    # below 1 yields an empty batch (0) or a torch.zeros error (negative).
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1

    text = [text] * batchsize
    fbank = torch.zeros((batchsize, 1024, 64))  # Not used, here to keep the code format
    stft = torch.zeros((batchsize, 1024, 512))  # Not used
    waveform = torch.zeros((batchsize, 160000))  # Not used
    fname = [""] * batchsize  # Not used
    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch



def build_model(
    ckpt_path=None,
    config=None,
    model_name="audioldm-s-full"
):
    """Create an AudioLDM ``LatentDiffusion`` model and load its checkpoint.

    Args:
        ckpt_path: Optional path to the checkpoint file. When ``None`` the
            checkpoint is read from ``ckpt/<model_name>.ckpt``.
        config: Optional path (str) to a YAML configuration file. When
            ``None`` the configuration comes from
            ``default_audioldm_config(model_name)``.
        model_name: Name used for the default checkpoint path and the
            default configuration.

    Returns:
        The ``LatentDiffusion`` instance with weights loaded, switched to
        eval mode, moved to the selected device and set to text conditioning.

    Raises:
        AssertionError: If ``config`` is given but is not a string.
    """
    print("Load AudioLDM: %s" % model_name)

    # Honour an explicitly supplied checkpoint path; fall back to the
    # conventional location only when the caller did not provide one.
    if ckpt_path is None:
        ckpt_path = "ckpt/%s.ckpt" % model_name

    # Prefer the first CUDA device, otherwise run on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if config is not None:
        assert isinstance(config, str), "config must be a path to a YAML file"
        # NOTE(review): yaml.FullLoader is not meant for untrusted input;
        # only pass configuration files from a trusted source.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Use text as condition instead of using waveform during training
    model_params = config["model"]["params"]
    model_params["device"] = device
    model_params["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**model_params)

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion

def duration_to_latent_t_size(duration):
    """Convert a duration into a latent time size.

    The duration is multiplied by 25.6 and the result is truncated to an
    integer with ``int``.
    """
    latent_frames = duration * 25.6
    return int(latent_frames)

def text_to_audio(latent_diffusion, text, seed=42, duration=10, batchsize=1, guidance_scale=2.5, n_candidate_gen_per_text=3, config=None):
    """Generate audio for a text prompt with the given model.

    Args:
        latent_diffusion: Model object; its ``latent_t_size`` attribute is
            overwritten and its ``generate_sample`` method is called.
        text: Prompt passed to ``make_batch_for_text_to_audio``.
        seed: Value converted with ``int`` and passed to ``seed_everything``.
        duration: Passed to ``duration_to_latent_t_size`` and forwarded to
            ``generate_sample``.
        batchsize: Batch size used when building the input batch.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``.
        n_candidate_gen_per_text: Forwarded to ``generate_sample``.
        config: Accepted but not used by this function.

    Returns:
        Whatever ``latent_diffusion.generate_sample`` returns.
    """
    seed_everything(int(seed))
    text_batch = make_batch_for_text_to_audio(text, batchsize=batchsize)

    # Set the latent time size on the model for the requested duration
    # before sampling.
    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)

    sampling_kwargs = {
        "unconditional_guidance_scale": guidance_scale,
        "n_candidate_gen_per_text": n_candidate_gen_per_text,
        "duration": duration,
    }
    # Sampling runs with gradient tracking disabled.
    with torch.no_grad():
        return latent_diffusion.generate_sample([text_batch], **sampling_kwargs)