# Hosting-page residue captured together with this file (not part of the program):
#   MyModel2 / tortoise / read.py -- uploaded by MasonCrinr ("Upload 580 files"),
#   commit 762a084, 5.76 kB, no virus detected.
import argparse
import os
from time import time
import torch
import torchaudio
from api import TextToSpeech, MODELS_DIR
from utils.audio import load_audio, load_voices
from utils.text import split_and_recombine_text
def parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` and ``"0"`` -- is parsed
    as ``True``. This converter interprets the text instead.

    Args:
        value: The raw option value (a string from the command line, or an
            actual ``bool`` when called programmatically).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling, so argparse reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept falsy, matching what bool('') used to yield.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (e.g. true/false), got {value!r}")


def strip_wav_suffix(name):
    """Return *name* without a trailing ``.wav`` extension (case-insensitive).

    The script always appends ``.wav`` itself when composing output file
    names, so a name that already carries the extension (such as the default
    ``combined.wav``) would otherwise be written as ``combined.wav.wav``.
    A name consisting only of the extension is returned unchanged.
    """
    suffix = '.wav'
    if len(name) > len(suffix) and name.lower().endswith(suffix):
        return name[:-len(suffix)]
    return name


if __name__ == '__main__':
    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="tortoise/data/riding_hood.txt")
    parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
                                                  'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat')
    parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
    parser.add_argument('--output_name', type=str, help='How to name the output file', default='combined.wav')
    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard')
    parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None)
    parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1)
    parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this '
                                                      'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
    parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
    # Boolean options go through parse_bool_flag so that "--half False" really
    # disables the feature (type=bool would treat any non-empty string as True).
    parser.add_argument('--produce_debug_state', type=parse_bool_flag, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
    parser.add_argument('--use_deepspeed', type=parse_bool_flag, help='Use deepspeed for speed bump.', default=False)
    parser.add_argument('--kv_cache', type=parse_bool_flag, help='If you disable this please wait for a long a time to get the output', default=True)
    parser.add_argument('--half', type=parse_bool_flag, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True)
    args = parser.parse_args()

    # DeepSpeed is forced off whenever the MPS backend is available.
    if torch.backends.mps.is_available():
        args.use_deepspeed = False
    tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)

    outpath = args.output_path
    # ".wav" is appended below, so drop it here to avoid "combined.wav.wav".
    outname = strip_wav_suffix(args.output_name)
    selected_voices = args.voice.split(',')
    regenerate = args.regenerate
    if regenerate is not None:
        regenerate = {int(e) for e in regenerate.split(',')}

    # Process text
    with open(args.textfile, 'r', encoding='utf-8') as f:
        # Lines keep their trailing newline and are additionally joined by a space.
        text = ' '.join(f.readlines())
    if '|' in text:
        print("Found the '|' character in your text, which I will use as a cue for where to split it up. If this was not "
              "your intent, please remove all '|' characters from the input.")
        texts = text.split('|')
    else:
        texts = split_and_recombine_text(text)

    # One seed is shared by every clip and voice so a run can be reproduced.
    seed = int(time()) if args.seed is None else args.seed
    for selected_voice in selected_voices:
        voice_outpath = os.path.join(outpath, selected_voice)
        os.makedirs(voice_outpath, exist_ok=True)

        # "a&b" asks for voices a and b to be combined into one.
        if '&' in selected_voice:
            voice_sel = selected_voice.split('&')
        else:
            voice_sel = [selected_voice]

        voice_samples, conditioning_latents = load_voices(voice_sel)
        all_parts = []
        for j, text in enumerate(texts):
            if regenerate is not None and j not in regenerate:
                # Clip is not scheduled for regeneration: reuse what is on disk.
                # Only the single-candidate layout stores "<j>.wav"; with several
                # candidates the clips live in "<j>/<k>.wav" and are re-read by
                # the combine step below, so there is nothing to load here.
                if args.candidates == 1:
                    all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
                continue
            gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
                                      preset=args.preset, k=args.candidates, use_deterministic_seed=seed)
            if args.candidates == 1:
                audio_ = gen.squeeze(0).cpu()
                torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), audio_, 24000)
            else:
                # One sub-directory per clip, one file per candidate.
                candidate_dir = os.path.join(voice_outpath, str(j))
                os.makedirs(candidate_dir, exist_ok=True)
                for k, g in enumerate(gen):
                    torchaudio.save(os.path.join(candidate_dir, f'{k}.wav'), g.squeeze(0).cpu(), 24000)
                audio_ = gen[0].squeeze(0).cpu()
            all_parts.append(audio_)

        # Single-candidate mode: concatenate the in-memory clips along the last dim.
        if args.candidates == 1:
            full_audio = torch.cat(all_parts, dim=-1)
            torchaudio.save(os.path.join(voice_outpath, f"{outname}.wav"), full_audio, 24000)

        if args.produce_debug_state:
            os.makedirs('debug_states', exist_ok=True)
            dbg_state = (seed, texts, voice_samples, conditioning_latents)
            torch.save(dbg_state, f'debug_states/read_debug_{selected_voice}.pth')

        # Combine each candidate's audio clips.
        if args.candidates > 1:
            for candidate in range(args.candidates):
                # Re-read candidate number `candidate` of every clip from disk, in text order.
                audio_clips = [
                    load_audio(os.path.join(voice_outpath, str(line), f"{candidate}.wav"), 24000)
                    for line in range(len(texts))
                ]
                combined = torch.cat(audio_clips, dim=-1)
                torchaudio.save(os.path.join(voice_outpath, f"{outname}_{candidate:02d}.wav"), combined, 24000)