apsys commited on
Commit
14e7fb1
1 Parent(s): a80906b

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/main-checkpoint.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from speechtokenizer import SpeechTokenizer
6
+ from audiotools import AudioSignal
7
+ import bitsandbytes as bnb # Import bitsandbytes for INT8 quantization
8
+ import numpy as np
9
+ from uuid import uuid4
10
+
11
+ # Load the necessary models and tokenizers
12
+ model_path = "Vikhrmodels/llama_asr_tts_24000"
13
+ tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")
14
+ # Специальные токены
15
+ start_audio_token = "<soa>"
16
+ end_audio_token = "<eoa>"
17
+ end_sequence_token = "<eos>"
18
+
19
+ # Константы
20
+ n_codebooks = 3
21
+ max_seq_length = 1024
22
+ top_k = 20
23
+
24
+ from safetensors.torch import load_file
25
+
26
def convert_to_16_bit_wav(data):
    """Convert a numpy audio array to 16-bit PCM (int16).

    Scaling follows the dtype conventions described for
    scipy.io.wavfile.write: float32 is peak-normalized to [-1, 1] and
    scaled, int32 is rescaled by 2**16, uint8 is shifted/stretched,
    and int16 passes through unchanged.

    Args:
        data: np.ndarray with dtype float32, int32, int16 or uint8.

    Returns:
        np.ndarray of dtype int16.

    Raises:
        ValueError: for any other input dtype.
    """
    if data.dtype == np.float32:
        peak = np.abs(data).max()
        # Guard against division by zero on all-silence input.
        if peak > 0:
            data = data / peak
        data = (data * 32767).astype(np.int16)
    elif data.dtype == np.int32:
        # 2**32 / 2**16 == 65536; the original divisor 65538 slightly
        # under-scaled the signal.
        data = (data / 65536).astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        # Widen to int32 first so 255 * 257 cannot overflow a narrow
        # intermediate dtype; 0..255 maps onto -32768..32767.
        data = (data.astype(np.int32) * 257 - 32768).astype(np.int16)
    else:
        raise ValueError("Audio data cannot be converted to " "16-bit int format.")
    return data
56
+
57
+ # Load the model with INT8 quantization
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ model_path,
60
+ cache_dir=".",
61
+ load_in_8bit=True, # Enable loading in INT8
62
+ device_map="auto" # Automatically map model to available devices
63
+ )
64
+
65
+ # Configurations for Speech Tokenizer
66
+ config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
67
+ ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
68
+ quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
69
+ quantizer.eval()
70
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+
72
+ # Перемещение всех слоев квантизатора на устройство и их заморозка
73
+ def freeze_entire_model(model):
74
+ for n, p in model.named_parameters():
75
+ p.requires_grad = False
76
+ return model
77
+
78
+ for n, child in quantizer.named_children():
79
+ child.to(device)
80
+ child = freeze_entire_model(child)
81
+
82
+ # Функция для создания токенов заполнения для аудио
83
+ def get_audio_padding_tokens(quantizer):
84
+ audio = torch.zeros((1, 1, 1)).to(device)
85
+ codes = quantizer.encode(audio)
86
+ del audio
87
+ torch.cuda.empty_cache()
88
+ return {"audio_tokens": codes.squeeze(1)}
89
+
90
+ # Функция для декодирования аудио из токенов
91
+ def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
92
+ start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
93
+ end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
94
+ start = start[0, -1] + 1 if len(start) else 0
95
+ end = end[0, -1] if len(end) else tokens.shape[-1]
96
+
97
+ audio_tokens = tokens[start:end] % n_original_tokens
98
+ reminder = audio_tokens.shape[-1] % n_codebooks
99
+
100
+ if reminder:
101
+ audio_tokens = torch.cat([audio_tokens, pad_tokens[reminder:n_codebooks]], dim=0)
102
+
103
+ transposed = audio_tokens.view(-1, n_codebooks).t()
104
+ codes = transposed.view(n_codebooks, 1, -1).to(device)
105
+
106
+ audio = quantizer.decode(codes).squeeze(0)
107
+ torch.cuda.empty_cache()
108
+ xp = str(uuid4())+'.wav'
109
+ AudioSignal(audio.detach().cpu().numpy(),quantizer.sample_rate).write(xp)
110
+ return xp
111
+
112
+
113
+ # Пример использования
114
+
115
+ # Функция инференса для текста на входе и аудио на выходе
116
+ def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
117
+ text_tokenized = tokenizer(text, return_tensors="pt")
118
+ text_input_tokens = text_tokenized["input_ids"].to(device)
119
+
120
+ soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
121
+ eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
122
+
123
+ text_tokens = torch.cat([text_input_tokens, soa], dim=1)
124
+ attention_mask = torch.ones(text_tokens.size(), device=device)
125
+
126
+ output_audio_tokens = model.generate(text_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)
127
+
128
+ padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
129
+ audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)
130
+
131
+ return audio_signal
132
+
133
+ # Функция инференса для аудио на входе и текста на выходе
134
+ def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
135
+ audio_data, sample_rate = torchaudio.load(audio_path)
136
+
137
+ audio = audio_data.view(1, 1, -1).float().to(device)
138
+ codes = quantizer.encode(audio)
139
+ n_codebooks_a = 1
140
+ raw_audio_tokens = codes[:, :n_codebooks_a] + len(tokenizer) - 1024
141
+
142
+ soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
143
+ eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
144
+ audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
145
+
146
+ attention_mask = torch.ones(audio_tokens.size(), device=device)
147
+
148
+ output_text_tokens = model.generate(audio_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)
149
+
150
+ output_text_tokens = output_text_tokens.cpu()[0]
151
+ output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
152
+ decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)
153
+
154
+ return decoded_text
155
+
156
+ # Functions for inference
157
+ def infer_text_to_audio_gr(text):
158
+ audio_signal = infer_text_to_audio(text.strip().upper(), model, tokenizer, quantizer)
159
+ return audio_signal
160
+
161
+ def infer_audio_to_text_gr(audio_path):
162
+ generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer)
163
+ return generated_text
164
+
165
+ # Gradio Interface
166
+ text_to_audio_interface = gr.Interface(
167
+ fn=infer_text_to_audio_gr,
168
+ inputs=gr.Textbox(label="Input Text"),
169
+ outputs=gr.Audio(label="Аудио Ответ"),
170
+ title="T2S",
171
+ description="Модель в режиме ответа в аудио",
172
+ allow_flagging='never',
173
+ )
174
+
175
+ audio_to_text_interface = gr.Interface(
176
+ fn=infer_audio_to_text_gr,
177
+ inputs=gr.Audio(type="filepath", label="Input Audio"),
178
+ outputs=gr.Textbox(label="Текстовый ответ"),
179
+ title="S2T",
180
+ description="Модель в режиме ответа в тексте",
181
+ allow_flagging='never'
182
+ )
183
+
184
+ # Launch Gradio App
185
+ demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Текст - Аудио", "Аудио - Текст"])
186
+ demo.launch(share=True)
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==0.34.2
3
+ aiofiles==23.2.1
4
+ aiohappyeyeballs==2.4.0
5
+ aiohttp==3.10.5
6
+ aiosignal==1.3.1
7
+ annotated-types==0.7.0
8
+ antlr4-python3-runtime==4.9.3
9
+ anyio==4.0.0
10
+ argbind==0.3.9
11
+ argon2-cffi==23.1.0
12
+ argon2-cffi-bindings==21.2.0
13
+ arrow==1.3.0
14
+ asttokens==2.4.1
15
+ async-lru==2.0.4
16
+ async-timeout==4.0.3
17
+ attrs==23.1.0
18
+ audioread==3.0.1
19
+ autobahn==21.11.1
20
+ Automat==20.2.0
21
+ Babel==2.13.1
22
+ base58==1.0.3
23
+ bcrypt==3.2.0
24
+ beartype==0.18.5
25
+ beautifulsoup4==4.12.2
26
+ bitsandbytes==0.43.3
27
+ bleach==6.1.0
28
+ blinker==1.4
29
+ cachetools==5.5.0
30
+ cbor==1.0.0
31
+ certifi==2022.12.7
32
+ cffi==1.16.0
33
+ charset-normalizer==2.1.1
34
+ click==8.1.7
35
+ colorama==0.4.4
36
+ comm==0.2.0
37
+ constantly==15.1.0
38
+ contourpy==1.3.0
39
+ cryptography==3.4.8
40
+ cycler==0.12.1
41
+ Cython==0.29.28
42
+ datasets==2.21.0
43
+ dbus-python==1.2.18
44
+ debugpy==1.8.0
45
+ decorator==5.1.1
46
+ defusedxml==0.7.1
47
+ descript-audiotools @ git+https://github.com/descriptinc/audiotools@7776c296c711db90176a63ff808c26e0ee087263
48
+ dill==0.3.8
49
+ distro==1.7.0
50
+ docstring_parser==0.16
51
+ ecdsa==0.18.0b1
52
+ einops==0.8.0
53
+ entrypoints==0.4
54
+ exceptiongroup==1.1.3
55
+ executing==2.0.1
56
+ fastapi==0.112.4
57
+ fastjsonschema==2.18.1
58
+ ffmpy==0.4.0
59
+ filelock==3.9.0
60
+ fire==0.6.0
61
+ flatbuffers===1.12.1-git20200711.33e2d80-dfsg1-0.6
62
+ flatten-dict==0.4.2
63
+ fonttools==4.53.1
64
+ fqdn==1.5.1
65
+ frozenlist==1.4.1
66
+ fsspec==2024.6.1
67
+ future==1.0.0
68
+ GeoIP==1.3.2
69
+ gradio==4.43.0
70
+ gradio_client==1.3.0
71
+ grpcio==1.66.1
72
+ h11==0.14.0
73
+ hkdf==0.0.3
74
+ httpcore==1.0.5
75
+ httplib2==0.20.2
76
+ httpx==0.27.2
77
+ huggingface-hub==0.24.6
78
+ humanize==0.0.0
79
+ hyperlink==21.0.0
80
+ idna==3.4
81
+ importlib-metadata==4.6.4
82
+ importlib_resources==6.4.5
83
+ incremental==21.3.0
84
+ iotop==0.6
85
+ ipykernel==6.26.0
86
+ ipython==8.17.2
87
+ ipython-genutils==0.2.0
88
+ ipywidgets==8.1.1
89
+ isoduration==20.11.0
90
+ jedi==0.19.1
91
+ jeepney==0.7.1
92
+ Jinja2==3.1.2
93
+ joblib==1.4.2
94
+ json5==0.9.14
95
+ jsonpointer==2.4
96
+ jsonschema==4.19.2
97
+ jsonschema-specifications==2023.7.1
98
+ keyring==23.5.0
99
+ kiwisolver==1.4.7
100
+ launchpadlib==1.10.16
101
+ lazr.restfulclient==0.14.4
102
+ lazr.uri==1.0.6
103
+ lazy_loader==0.4
104
+ librosa==0.10.2.post1
105
+ lion-pytorch==0.2.2
106
+ Markdown==3.7
107
+ markdown-it-py==3.0.0
108
+ markdown2==2.5.0
109
+ MarkupSafe==2.1.2
110
+ matplotlib==3.5.0
111
+ matplotlib-inline==0.1.6
112
+ mdurl==0.1.2
113
+ mistune==3.0.2
114
+ mnemonic==0.19
115
+ more-itertools==8.10.0
116
+ mpmath==1.3.0
117
+ msgpack==1.0.8
118
+ multidict==6.0.5
119
+ multiprocess==0.70.16
120
+ nbclassic==1.0.0
121
+ nbclient==0.9.0
122
+ nbconvert==7.11.0
123
+ nbformat==5.9.2
124
+ nest-asyncio==1.5.8
125
+ networkx==3.0
126
+ notebook==6.5.5
127
+ notebook_shim==0.2.3
128
+ numba==0.60.0
129
+ numpy==1.24.1
130
+ nvidia-ml-py==12.535.161
131
+ nvitop==1.3.2
132
+ oauthlib==3.2.0
133
+ omegaconf==2.3.0
134
+ orjson==3.10.7
135
+ overrides==7.4.0
136
+ packaging==23.2
137
+ pandas==2.2.2
138
+ pandocfilters==1.5.0
139
+ parso==0.8.3
140
+ passlib==1.7.4
141
+ pexpect==4.8.0
142
+ Pillow==9.3.0
143
+ platformdirs==3.11.0
144
+ ply==3.11
145
+ pooch==1.8.2
146
+ prometheus-client==0.18.0
147
+ prompt-toolkit==3.0.39
148
+ protobuf==3.19.6
149
+ psutil==5.9.6
150
+ ptyprocess==0.7.0
151
+ pure-eval==0.2.2
152
+ py-ubjson==0.16.1
153
+ pyarrow==17.0.0
154
+ pyasn1==0.4.8
155
+ pyasn1-modules==0.2.1
156
+ pycparser==2.21
157
+ pydantic==2.9.1
158
+ pydantic_core==2.23.3
159
+ pydub==0.25.1
160
+ Pygments==2.16.1
161
+ PyGObject==3.42.1
162
+ PyHamcrest==2.0.2
163
+ PyJWT==2.3.0
164
+ pyloudnorm==0.1.1
165
+ PyNaCl==1.5.0
166
+ pyOpenSSL==21.0.0
167
+ pyparsing==2.4.7
168
+ pypng==0.0.20
169
+ PyQRCode==1.2.1
170
+ pystoi==0.4.1
171
+ python-apt==2.4.0+ubuntu2
172
+ python-dateutil==2.8.2
173
+ python-json-logger==2.0.7
174
+ python-multipart==0.0.9
175
+ python-snappy==0.5.3
176
+ PyTrie==0.4.0
177
+ pytz==2024.1
178
+ PyYAML==6.0.1
179
+ pyzmq==24.0.1
180
+ randomname==0.2.1
181
+ referencing==0.30.2
182
+ regex==2024.7.24
183
+ requests==2.32.3
184
+ rfc3339-validator==0.1.4
185
+ rfc3986-validator==0.1.1
186
+ rich==13.8.0
187
+ rpds-py==0.12.0
188
+ ruff==0.6.4
189
+ safetensors==0.4.4
190
+ scikit-learn==1.5.1
191
+ scipy==1.14.1
192
+ SecretStorage==3.3.1
193
+ semantic-version==2.10.0
194
+ Send2Trash==1.8.2
195
+ service-identity==18.1.0
196
+ setuptools-scm==8.1.0
197
+ shellingham==1.5.4
198
+ six==1.16.0
199
+ sniffio==1.3.0
200
+ sortedcontainers==2.1.0
201
+ soundfile==0.12.1
202
+ soupsieve==2.5
203
+ soxr==0.5.0.post1
204
+ spake2==0.8
205
+ speechtokenizer==1.0.1
206
+ stack-data==0.6.3
207
+ starlette==0.38.5
208
+ sympy==1.12
209
+ tensorboard==2.17.1
210
+ tensorboard-data-server==0.7.2
211
+ termcolor==2.4.0
212
+ terminado==0.17.1
213
+ threadpoolctl==3.5.0
214
+ tinycss2==1.2.1
215
+ tokenizers==0.19.1
216
+ tomli==2.0.1
217
+ tomlkit==0.12.0
218
+ torch==2.1.0+cu118
219
+ torch-stoi==0.2.1
220
+ torchaudio==2.1.0+cu118
221
+ torchvision==0.16.0+cu118
222
+ tornado==6.3.3
223
+ tqdm==4.66.5
224
+ traitlets==5.13.0
225
+ transformers==4.44.2
226
+ triton==2.1.0
227
+ Twisted==22.1.0
228
+ txaio==21.2.1
229
+ txtorcon==20.0.0
230
+ typer==0.12.5
231
+ types-python-dateutil==2.8.19.14
232
+ typing_extensions==4.12.2
233
+ tzdata==2024.1
234
+ u-msgpack-python==2.3.0
235
+ ujson==5.1.0
236
+ uri-template==1.3.0
237
+ urllib3==2.2.2
238
+ uvicorn==0.30.6
239
+ wadllib==1.3.6
240
+ wcwidth==0.2.9
241
+ webcolors==1.13
242
+ webencodings==0.5.1
243
+ websocket-client==1.6.4
244
+ websockets==12.0
245
+ Werkzeug==3.0.4
246
+ widgetsnbextension==4.0.9
247
+ wsaccel==0.6.3
248
+ xxhash==3.5.0
249
+ yarl==1.9.8
250
+ zipp==1.0.0
251
+ zope.interface==5.4.0
audiotokenizer/SpeechTokenizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d04593b6c9a4b475f91ca481141a6ef5b23e6ac112f347dd2b2717f193c1c728
3
+ size 481906997
audiotokenizer/speechtokenizer_hubert_avg_config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 3,
4
+ "batch_size": 60,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.5,
7
+ "adam_b2": 0.9,
8
+ "lr_decay": 0.98,
9
+ "seed": 1234,
10
+ "lambda_distill": 0.15,
11
+
12
+ "n_filters": 64,
13
+ "strides": [8,5,4,2],
14
+ "dimension": 1024,
15
+ "semantic_dimension": 768,
16
+ "bidirectional": true,
17
+ "dilation_base": 2,
18
+ "residual_kernel_size": 3,
19
+ "n_residual_layers": 1,
20
+ "lstm_layers": 2,
21
+ "activation": "ELU",
22
+
23
+
24
+ "segment_size": 48000,
25
+ "num_mels": 80,
26
+ "num_freq": 1025,
27
+ "n_fft": 1024,
28
+ "hop_size": 240,
29
+ "win_size": 1024,
30
+
31
+ "sampling_rate": 16000,
32
+ "sample_rate": 16000,
33
+
34
+ "codebook_size": 1024,
35
+ "n_q": 8,
36
+
37
+ "fmin": 0,
38
+ "fmax": 8000,
39
+ "fmax_for_loss": null,
40
+
41
+ "num_workers": 12,
42
+
43
+ "dist_config": {
44
+ "dist_backend": "nccl",
45
+ "dist_url": "tcp://localhost:54322",
46
+ "world_size": 1
47
+ }
48
+ }
main.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
import bitsandbytes as bnb  # Import bitsandbytes for INT8 quantization
import numpy as np
from uuid import uuid4

# Load the necessary models and tokenizers
model_path = "Vikhrmodels/llama_asr_tts_24000"
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")
# Special tokens that delimit the audio region inside the joint vocabulary
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

# Constants
n_codebooks = 3  # codec codebooks used when decoding generated audio tokens
max_seq_length = 1024  # generation budget (max new tokens)
top_k = 20  # sampling breadth for model.generate

from safetensors.torch import load_file  # NOTE(review): appears unused in this file
25
+
26
def convert_to_16_bit_wav(data):
    """Convert a numpy audio array to 16-bit PCM (int16).

    Scaling follows the dtype conventions described for
    scipy.io.wavfile.write: float32 is peak-normalized to [-1, 1] and
    scaled, int32 is rescaled by 2**16, uint8 is shifted/stretched,
    and int16 passes through unchanged.

    Args:
        data: np.ndarray with dtype float32, int32, int16 or uint8.

    Returns:
        np.ndarray of dtype int16.

    Raises:
        ValueError: for any other input dtype.
    """
    if data.dtype == np.float32:
        peak = np.abs(data).max()
        # Guard against division by zero on all-silence input.
        if peak > 0:
            data = data / peak
        data = (data * 32767).astype(np.int16)
    elif data.dtype == np.int32:
        # 2**32 / 2**16 == 65536; the original divisor 65538 slightly
        # under-scaled the signal.
        data = (data / 65536).astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        # Widen to int32 first so 255 * 257 cannot overflow a narrow
        # intermediate dtype; 0..255 maps onto -32768..32767.
        data = (data.astype(np.int32) * 257 - 32768).astype(np.int16)
    else:
        raise ValueError("Audio data cannot be converted to " "16-bit int format.")
    return data
56
+
57
# Load the causal LM with INT8 quantization to reduce GPU memory use.
# NOTE(review): `load_in_8bit=True` is deprecated in recent transformers
# releases in favor of `quantization_config=BitsAndBytesConfig(...)` —
# confirm against the pinned transformers version before upgrading.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=".",
    load_in_8bit=True,  # Enable loading in INT8
    device_map="auto"  # Automatically map model to available devices
)
64
+
65
# Configurations for Speech Tokenizer (neural audio codec used for
# encoding input audio and decoding generated audio tokens)
config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()  # inference only: disable dropout / norm-statistic updates
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+
72
# Move every quantizer layer to the target device and freeze it.
def freeze_entire_model(model):
    """Disable gradient tracking on all parameters of *model* and return it."""
    for _name, param in model.named_parameters():
        param.requires_grad = False
    return model
77
+
78
# Move each submodule of the speech codec onto the device and freeze its
# parameters so no gradients are tracked during inference.
for n, child in quantizer.named_children():
    child.to(device)
    child = freeze_entire_model(child)
81
+
82
# Build padding tokens by encoding a single sample of silence.
def get_audio_padding_tokens(quantizer):
    """Return ``{"audio_tokens": codes}`` obtained from one silent frame."""
    silence = torch.zeros((1, 1, 1)).to(device)
    codes = quantizer.encode(silence)
    del silence
    torch.cuda.empty_cache()
    return {"audio_tokens": codes.squeeze(1)}
89
+
90
# Decode generated audio tokens back into a waveform file.
def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    """Decode a 1-D tensor of generated token ids into a wav file.

    Extracts the region between the <soa> and <eoa> markers, maps the
    ids back into the codec codebook range, pads the sequence to a
    multiple of ``n_codebooks`` and runs the codec decoder.

    Args:
        tokens: 1-D tensor of generated token ids.
        quantizer: speech codec providing ``decode`` and ``sample_rate``.
        pad_tokens: tokens used to pad an incomplete codebook group.
        n_original_tokens: size of the text vocabulary; audio ids are
            offset by this amount, hence the modulo below.

    Returns:
        Path (str) of the wav file written to the working directory.
    """
    # Locate the audio span delimited by <soa>/<eoa>; fall back to the
    # whole sequence when a marker is missing.
    start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
    end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # Map joint-vocabulary ids back into codec ids.
    audio_tokens = tokens[start:end] % n_original_tokens
    reminder = audio_tokens.shape[-1] % n_codebooks

    if reminder:
        # Pad so the length divides evenly into n_codebooks groups.
        audio_tokens = torch.cat([audio_tokens, pad_tokens[reminder:n_codebooks]], dim=0)

    # Interleaved layout -> (n_codebooks, batch=1, frames) as the codec expects.
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)
    torch.cuda.empty_cache()
    # Unique filename so concurrent requests do not collide.
    xp = str(uuid4())+'.wav'
    AudioSignal(audio.detach().cpu().numpy(),quantizer.sample_rate).write(xp)
    return xp
111
+
112
+
113
# Example usage

# Inference: text in -> audio out.
def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    """Generate speech for *text* and return the path of the written wav.

    The prompt is the tokenized text followed by the <soa> marker; the
    model then samples audio tokens which are decoded by the codec.

    Args:
        text: input prompt string.
        model: causal LM used for generation.
        tokenizer: shared text/audio tokenizer.
        quantizer: speech codec used to decode audio tokens.
        max_seq_length: max new tokens to generate.
        top_k: sampling breadth.

    Returns:
        Path (str) of the generated wav file.
    """
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    # Append the start-of-audio marker so generation continues with audio
    # tokens. (The original also computed an unused <eoa> tensor; removed.)
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
    )

    # Audio ids live above the text vocabulary; decode_audio maps them back.
    padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
    audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)

    return audio_signal
132
+
133
# Inference: audio in -> text out.
def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    """Transcribe the audio file at *audio_path* into text.

    The audio is codec-encoded, its token ids are shifted into the audio
    region of the joint vocabulary and wrapped in <soa>/<eoa>, then the
    model generates text tokens which are decoded to a string.
    """
    audio_data, sample_rate = torchaudio.load(audio_path)
    # NOTE(review): sample_rate is ignored — the input is not resampled to
    # the quantizer's expected rate; confirm callers provide matching audio.

    audio = audio_data.view(1, 1, -1).float().to(device)
    codes = quantizer.encode(audio)
    n_codebooks_a = 1  # only the first codebook is used on the input side
    # Shift codec ids into the audio region of the joint vocabulary.
    raw_audio_tokens = codes[:, :n_codebooks_a] + len(tokenizer) - 1024

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)

    attention_mask = torch.ones(audio_tokens.size(), device=device)

    output_text_tokens = model.generate(audio_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)

    output_text_tokens = output_text_tokens.cpu()[0]
    # Keep only ids below the first audio special token, i.e. plain text.
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)

    return decoded_text
155
+
156
# Gradio-facing wrappers around the core inference functions.
def infer_text_to_audio_gr(text):
    """Normalize the prompt and synthesize speech; returns a wav path."""
    prompt = text.strip().upper()
    return infer_text_to_audio(prompt, model, tokenizer, quantizer)
160
+
161
def infer_audio_to_text_gr(audio_path):
    """Transcribe an uploaded audio file into response text."""
    return infer_audio_to_text(audio_path, model, tokenizer, quantizer)
164
+
165
# Gradio Interface: one tab per direction (text -> speech, speech -> text).
# Labels/descriptions are user-facing runtime strings and are kept as-is.
text_to_audio_interface = gr.Interface(
    fn=infer_text_to_audio_gr,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Audio(label="Аудио Ответ"),
    title="T2S",
    description="Модель в режиме ответа в аудио",
    allow_flagging='never',
)

audio_to_text_interface = gr.Interface(
    fn=infer_audio_to_text_gr,
    inputs=gr.Audio(type="filepath", label="Input Audio"),
    outputs=gr.Textbox(label="Текстовый ответ"),
    title="S2T",
    description="Модель в режиме ответа в тексте",
    allow_flagging='never'
)

# Launch Gradio App
demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Текст - Аудио", "Аудио - Текст"])
demo.launch(share=True)  # share=True exposes a temporary public URL
requirements.txt ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==0.34.2
3
+ aiofiles==23.2.1
4
+ aiohappyeyeballs==2.4.0
5
+ aiohttp==3.10.5
6
+ aiosignal==1.3.1
7
+ annotated-types==0.7.0
8
+ antlr4-python3-runtime==4.9.3
9
+ anyio==4.0.0
10
+ argbind==0.3.9
11
+ argon2-cffi==23.1.0
12
+ argon2-cffi-bindings==21.2.0
13
+ arrow==1.3.0
14
+ asttokens==2.4.1
15
+ async-lru==2.0.4
16
+ async-timeout==4.0.3
17
+ attrs==23.1.0
18
+ audioread==3.0.1
19
+ autobahn==21.11.1
20
+ Automat==20.2.0
21
+ Babel==2.13.1
22
+ base58==1.0.3
23
+ bcrypt==3.2.0
24
+ beartype==0.18.5
25
+ beautifulsoup4==4.12.2
26
+ bitsandbytes==0.43.3
27
+ bleach==6.1.0
28
+ blinker==1.4
29
+ cachetools==5.5.0
30
+ cbor==1.0.0
31
+ certifi==2022.12.7
32
+ cffi==1.16.0
33
+ charset-normalizer==2.1.1
34
+ click==8.1.7
35
+ colorama==0.4.4
36
+ comm==0.2.0
37
+ constantly==15.1.0
38
+ contourpy==1.3.0
39
+ cryptography==3.4.8
40
+ cycler==0.12.1
41
+ Cython==0.29.28
42
+ datasets==2.21.0
43
+ dbus-python==1.2.18
44
+ debugpy==1.8.0
45
+ decorator==5.1.1
46
+ defusedxml==0.7.1
47
+ descript-audiotools @ git+https://github.com/descriptinc/audiotools@7776c296c711db90176a63ff808c26e0ee087263
48
+ dill==0.3.8
49
+ distro==1.7.0
50
+ docstring_parser==0.16
51
+ ecdsa==0.18.0b1
52
+ einops==0.8.0
53
+ entrypoints==0.4
54
+ exceptiongroup==1.1.3
55
+ executing==2.0.1
56
+ fastapi==0.112.4
57
+ fastjsonschema==2.18.1
58
+ ffmpy==0.4.0
59
+ filelock==3.9.0
60
+ fire==0.6.0
61
+ flatbuffers===1.12.1-git20200711.33e2d80-dfsg1-0.6
62
+ flatten-dict==0.4.2
63
+ fonttools==4.53.1
64
+ fqdn==1.5.1
65
+ frozenlist==1.4.1
66
+ fsspec==2024.6.1
67
+ future==1.0.0
68
+ GeoIP==1.3.2
69
+ gradio==4.43.0
70
+ gradio_client==1.3.0
71
+ grpcio==1.66.1
72
+ h11==0.14.0
73
+ hkdf==0.0.3
74
+ httpcore==1.0.5
75
+ httplib2==0.20.2
76
+ httpx==0.27.2
77
+ huggingface-hub==0.24.6
78
+ humanize==0.0.0
79
+ hyperlink==21.0.0
80
+ idna==3.4
81
+ importlib-metadata==4.6.4
82
+ importlib_resources==6.4.5
83
+ incremental==21.3.0
84
+ iotop==0.6
85
+ ipykernel==6.26.0
86
+ ipython==8.17.2
87
+ ipython-genutils==0.2.0
88
+ ipywidgets==8.1.1
89
+ isoduration==20.11.0
90
+ jedi==0.19.1
91
+ jeepney==0.7.1
92
+ Jinja2==3.1.2
93
+ joblib==1.4.2
94
+ json5==0.9.14
95
+ jsonpointer==2.4
96
+ jsonschema==4.19.2
97
+ jsonschema-specifications==2023.7.1
98
+ keyring==23.5.0
99
+ kiwisolver==1.4.7
100
+ launchpadlib==1.10.16
101
+ lazr.restfulclient==0.14.4
102
+ lazr.uri==1.0.6
103
+ lazy_loader==0.4
104
+ librosa==0.10.2.post1
105
+ lion-pytorch==0.2.2
106
+ Markdown==3.7
107
+ markdown-it-py==3.0.0
108
+ markdown2==2.5.0
109
+ MarkupSafe==2.1.2
110
+ matplotlib==3.5.0
111
+ matplotlib-inline==0.1.6
112
+ mdurl==0.1.2
113
+ mistune==3.0.2
114
+ mnemonic==0.19
115
+ more-itertools==8.10.0
116
+ mpmath==1.3.0
117
+ msgpack==1.0.8
118
+ multidict==6.0.5
119
+ multiprocess==0.70.16
120
+ nbclassic==1.0.0
121
+ nbclient==0.9.0
122
+ nbconvert==7.11.0
123
+ nbformat==5.9.2
124
+ nest-asyncio==1.5.8
125
+ networkx==3.0
126
+ notebook==6.5.5
127
+ notebook_shim==0.2.3
128
+ numba==0.60.0
129
+ numpy==1.24.1
130
+ nvidia-ml-py==12.535.161
131
+ nvitop==1.3.2
132
+ oauthlib==3.2.0
133
+ omegaconf==2.3.0
134
+ orjson==3.10.7
135
+ overrides==7.4.0
136
+ packaging==23.2
137
+ pandas==2.2.2
138
+ pandocfilters==1.5.0
139
+ parso==0.8.3
140
+ passlib==1.7.4
141
+ pexpect==4.8.0
142
+ Pillow==9.3.0
143
+ platformdirs==3.11.0
144
+ ply==3.11
145
+ pooch==1.8.2
146
+ prometheus-client==0.18.0
147
+ prompt-toolkit==3.0.39
148
+ protobuf==3.19.6
149
+ psutil==5.9.6
150
+ ptyprocess==0.7.0
151
+ pure-eval==0.2.2
152
+ py-ubjson==0.16.1
153
+ pyarrow==17.0.0
154
+ pyasn1==0.4.8
155
+ pyasn1-modules==0.2.1
156
+ pycparser==2.21
157
+ pydantic==2.9.1
158
+ pydantic_core==2.23.3
159
+ pydub==0.25.1
160
+ Pygments==2.16.1
161
+ PyGObject==3.42.1
162
+ PyHamcrest==2.0.2
163
+ PyJWT==2.3.0
164
+ pyloudnorm==0.1.1
165
+ PyNaCl==1.5.0
166
+ pyOpenSSL==21.0.0
167
+ pyparsing==2.4.7
168
+ pypng==0.0.20
169
+ PyQRCode==1.2.1
170
+ pystoi==0.4.1
171
+ python-apt==2.4.0+ubuntu2
172
+ python-dateutil==2.8.2
173
+ python-json-logger==2.0.7
174
+ python-multipart==0.0.9
175
+ python-snappy==0.5.3
176
+ PyTrie==0.4.0
177
+ pytz==2024.1
178
+ PyYAML==6.0.1
179
+ pyzmq==24.0.1
180
+ randomname==0.2.1
181
+ referencing==0.30.2
182
+ regex==2024.7.24
183
+ requests==2.32.3
184
+ rfc3339-validator==0.1.4
185
+ rfc3986-validator==0.1.1
186
+ rich==13.8.0
187
+ rpds-py==0.12.0
188
+ ruff==0.6.4
189
+ safetensors==0.4.4
190
+ scikit-learn==1.5.1
191
+ scipy==1.14.1
192
+ SecretStorage==3.3.1
193
+ semantic-version==2.10.0
194
+ Send2Trash==1.8.2
195
+ service-identity==18.1.0
196
+ setuptools-scm==8.1.0
197
+ shellingham==1.5.4
198
+ six==1.16.0
199
+ sniffio==1.3.0
200
+ sortedcontainers==2.1.0
201
+ soundfile==0.12.1
202
+ soupsieve==2.5
203
+ soxr==0.5.0.post1
204
+ spake2==0.8
205
+ speechtokenizer==1.0.1
206
+ stack-data==0.6.3
207
+ starlette==0.38.5
208
+ sympy==1.12
209
+ tensorboard==2.17.1
210
+ tensorboard-data-server==0.7.2
211
+ termcolor==2.4.0
212
+ terminado==0.17.1
213
+ threadpoolctl==3.5.0
214
+ tinycss2==1.2.1
215
+ tokenizers==0.19.1
216
+ tomli==2.0.1
217
+ tomlkit==0.12.0
218
+ torch==2.1.0+cu118
219
+ torch-stoi==0.2.1
220
+ torchaudio==2.1.0+cu118
221
+ torchvision==0.16.0+cu118
222
+ tornado==6.3.3
223
+ tqdm==4.66.5
224
+ traitlets==5.13.0
225
+ transformers==4.44.2
226
+ triton==2.1.0
227
+ Twisted==22.1.0
228
+ txaio==21.2.1
229
+ txtorcon==20.0.0
230
+ typer==0.12.5
231
+ types-python-dateutil==2.8.19.14
232
+ typing_extensions==4.12.2
233
+ tzdata==2024.1
234
+ u-msgpack-python==2.3.0
235
+ ujson==5.1.0
236
+ uri-template==1.3.0
237
+ urllib3==2.2.2
238
+ uvicorn==0.30.6
239
+ wadllib==1.3.6
240
+ wcwidth==0.2.9
241
+ webcolors==1.13
242
+ webencodings==0.5.1
243
+ websocket-client==1.6.4
244
+ websockets==12.0
245
+ Werkzeug==3.0.4
246
+ widgetsnbextension==4.0.9
247
+ wsaccel==0.6.3
248
+ xxhash==3.5.0
249
+ yarl==1.9.8
250
+ zipp==1.0.0
251
+ zope.interface==5.4.0