Enable redaction by default

- tortoise/api.py +1 -2
- tortoise/utils/wav2vec_alignment.py +85 -30
tortoise/api.py

@@ -165,7 +165,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=False):
+    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -275,7 +275,6 @@ class TextToSpeech:
         """
         # Use generally found best tuning knobs for generation.
         kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
-                       #'typical_sampling': True,
                        'top_p': .8,
                        'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
         # Presets are defined here.
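The net effect on the public API: redaction switches from opt-in to opt-out. A minimal construction sketch (construction only; model loading and generation are unchanged by this commit):

    from tortoise.api import TextToSpeech

    # Text inside [brackets] in a prompt is now stripped from the rendered
    # audio unless the caller opts out.
    tts = TextToSpeech()  # enable_redaction=True by default
    tts_unredacted = TextToSpeech(enable_redaction=False)  # the previous behavior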
tortoise/utils/wav2vec_alignment.py

@@ -7,13 +7,52 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
 from tortoise.utils.audio import load_audio
 
 
+def max_alignment(s1, s2, skip_character='~', record={}):
+    """
+    A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
+    used to replace that character.
+
+    Finally got to use my DP skills!
+    """
+    assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
+    if len(s1) == 0:
+        return ''
+    if len(s2) == 0:
+        return skip_character * len(s1)
+    if s1 == s2:
+        return s1
+    if s1[0] == s2[0]:
+        return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record)
+
+    take_s1_key = (len(s1), len(s2) - 1)
+    if take_s1_key in record:
+        take_s1, take_s1_score = record[take_s1_key]
+    else:
+        take_s1 = max_alignment(s1, s2[1:], skip_character, record)
+        take_s1_score = len(take_s1.replace(skip_character, ''))
+        record[take_s1_key] = (take_s1, take_s1_score)
+
+    take_s2_key = (len(s1) - 1, len(s2))
+    if take_s2_key in record:
+        take_s2, take_s2_score = record[take_s2_key]
+    else:
+        take_s2 = max_alignment(s1[1:], s2, skip_character, record)
+        take_s2_score = len(take_s2.replace(skip_character, ''))
+        record[take_s2_key] = (take_s2, take_s2_score)
+
+    return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2
+
+
 class Wav2VecAlignment:
+    """
+    Uses wav2vec2 to perform audio<->text alignment.
+    """
     def __init__(self):
         self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
         self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
 
-    def align(self, audio, expected_text, audio_sample_rate=24000, topk=3, return_partial=False):
+    def align(self, audio, expected_text, audio_sample_rate=24000):
         orig_len = audio.shape[-1]
 
         with torch.no_grad():
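Reviewer's note on max_alignment: it is a memoized, LCS-style recursion that returns a copy of s1 in which every character that could not be matched in s2 is replaced by the skip character, so the result always has len(s1). A quick illustration (hypothetical inputs; outputs hand-traced from the recursion, worth verifying):

    from tortoise.utils.wav2vec_alignment import max_alignment

    print(max_alignment('abc', 'abc'))                 # 'abc'  (exact match)
    print(max_alignment('abc', ''))                    # '~~~'  (nothing matches)
    print(max_alignment('hello world', 'helo world'))  # 'hel~o world' (one 'l' unmatched)

One caveat worth flagging: the mutable default record={} is keyed on suffix lengths only, so the memo can carry stale entries across top-level calls made with different strings.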
@@ -25,32 +64,59 @@ class Wav2VecAlignment:
             self.model = self.model.cpu()
 
         logits = logits[0]
+        pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
+
+        fixed_expectation = max_alignment(expected_text, pred_string)
         w2v_compression = orig_len // logits.shape[0]
-        expected_tokens = self.tokenizer.encode(expected_text)
+        expected_tokens = self.tokenizer.encode(fixed_expectation)
+        expected_chars = list(fixed_expectation)
         if len(expected_tokens) == 1:
             return [0]  # The alignment is simple; there is only one token.
         expected_tokens.pop(0)  # The first token is a given.
-        next_expected_token = expected_tokens.pop(0)
+        expected_chars.pop(0)
+
         alignments = [0]
+        def pop_till_you_win():
+            if len(expected_tokens) == 0:
+                return None
+            popped = expected_tokens.pop(0)
+            popped_char = expected_chars.pop(0)
+            while popped_char == '~':
+                alignments.append(-1)
+                if len(expected_tokens) == 0:
+                    return None
+                popped = expected_tokens.pop(0)
+                popped_char = expected_chars.pop(0)
+            return popped
+
+        next_expected_token = pop_till_you_win()
         for i, logit in enumerate(logits):
-            top = logit.topk(topk).indices.tolist()
-            if next_expected_token in top:
+            top = logit.argmax()
+            if next_expected_token == top:
                 alignments.append(i * w2v_compression)
                 if len(expected_tokens) > 0:
-                    next_expected_token = expected_tokens.pop(0)
+                    next_expected_token = pop_till_you_win()
                 else:
                     break
 
-        if len(expected_tokens) > 0:
-            print(f"Alignment did not work. {len(expected_tokens)} tokens were not found, with the following string un-aligned:"
-                  f" `{self.tokenizer.decode(expected_tokens)}`. Here's what wav2vec thought it heard:"
-                  f"`{self.tokenizer.decode(logits.argmax(-1).tolist())}`")
-            if not return_partial:
-                return None
+        pop_till_you_win()
+        assert len(expected_tokens) == 0, "This shouldn't happen. My coding sucks."
 
-        return alignments
+        # Now fix up alignments. Anything with -1 should be interpolated.
+        alignments.append(orig_len)  # This'll get removed but makes the algorithm below more readable.
+        for i in range(len(alignments)):
+            if alignments[i] == -1:
+                for j in range(i+1, len(alignments)):
+                    if alignments[j] != -1:
+                        next_found_token = j
+                        break
+                for j in range(i, next_found_token):
+                    gap = alignments[next_found_token] - alignments[i-1]
+                    alignments[j] = (j-i+1) * gap // (next_found_token-i+1) + alignments[i-1]
 
-    def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
+        return alignments[:-1]
+
+    def redact(self, audio, expected_text, audio_sample_rate=24000):
         if '[' not in expected_text:
             return audio
         splitted = expected_text.split('[')
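The rewritten align first decodes wav2vec's full transcription (pred_string), aligns the expected text against it with max_alignment, and records -1 for characters wav2vec never emitted; a final pass linearly interpolates those -1 entries between their aligned neighbors. A standalone sketch of just that interpolation pass, with fabricated alignment values:

    # Fabricated example: entries 2 and 3 were never heard (-1) and get spaced
    # evenly between their aligned neighbors at 320 and 1280.
    alignments = [0, 320, -1, -1, 1280]
    orig_len = 1600

    alignments.append(orig_len)  # sentinel; dropped again below
    for i in range(len(alignments)):
        if alignments[i] == -1:
            for j in range(i + 1, len(alignments)):
                if alignments[j] != -1:
                    next_found_token = j
                    break
            for j in range(i, next_found_token):
                gap = alignments[next_found_token] - alignments[i - 1]
                alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]

    print(alignments[:-1])  # [0, 320, 640, 960, 1280]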
@@ -58,33 +124,22 @@
         for spl in splitted[1:]:
             assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
             fully_split.extend(spl.split(']'))
-        # Remove any non-alphabetic character in the input text. This makes matching more likely.
-        fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split]
+
         # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
         non_redacted_intervals = []
         last_point = 0
         for i in range(len(fully_split)):
             if i % 2 == 0:
-                non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
+                end_interval = max(0, last_point + len(fully_split[i]) - 1)
+                non_redacted_intervals.append((last_point, end_interval))
             last_point += len(fully_split[i])
 
         bare_text = ''.join(fully_split)
-        alignments = self.align(audio, bare_text, audio_sample_rate, topk, return_partial=True)
-        # If alignment fails, we will attempt to recover by assuming the remaining alignments consume the rest of the string.
-        def get_alignment(i):
-            if i >= len(alignments):
-                return audio.shape[-1]
+        alignments = self.align(audio, bare_text, audio_sample_rate)
 
         output_audio = []
         for nri in non_redacted_intervals:
             start, stop = nri
-            output_audio.append(audio[:, get_alignment(start):get_alignment(stop)])
+            output_audio.append(audio[:, alignments[start]:alignments[stop]])
         return torch.cat(output_audio, dim=-1)
 
-
-if __name__ == '__main__':
-    some_audio = load_audio('../../results/train_dotrice_0.wav', 24000)
-    aligner = Wav2VecAlignment()
-    text = "[God fucking damn it I'm so angry] The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them."
-    redact = aligner.redact(some_audio, text)
-    torchaudio.save(f'test_output.wav', redact, 24000)
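The removed __main__ demo still describes the intended end-to-end use; an equivalent minimal driver against the new signatures (file paths are illustrative):

    import torchaudio
    from tortoise.utils.audio import load_audio
    from tortoise.utils.wav2vec_alignment import Wav2VecAlignment

    aligner = Wav2VecAlignment()
    audio = load_audio('some_clip.wav', 24000)  # illustrative path
    # Text inside [brackets] is spoken in the clip but cut from the output.
    text = "[This part gets cut.] This part is kept."
    redacted = aligner.redact(audio, text, audio_sample_rate=24000)
    torchaudio.save('redacted.wav', redacted, 24000)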