code checkin for HF
Browse files- .gitattributes +0 -28
- .gitignore +133 -0
- LICENSE +201 -0
- README.md +77 -3
- api.py +317 -0
- data/mel_norms.pth +3 -0
- data/riding_hood.txt +54 -0
- data/tokenizer.json +1 -0
- do_tts.py +30 -0
- eval_multiple.py +38 -0
- models/arch_util.py +368 -0
- models/autoregressive.py +577 -0
- models/clvp.py +155 -0
- models/cvvp.py +133 -0
- models/diffusion_decoder.py +331 -0
- models/transformer.py +219 -0
- models/vocoder.py +325 -0
- models/xtransformers.py +1302 -0
- read.py +75 -0
- requirements.txt +10 -0
- sweep.py +65 -0
- tortoise_tts.ipynb +248 -0
- utils/__init__.py +0 -0
- utils/audio.py +143 -0
- utils/diffusion.py +1250 -0
- utils/stft.py +193 -0
- utils/tokenizer.py +187 -0
- utils/typical_sampling.py +33 -0
.gitattributes
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
26 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
results/*
|
24 |
+
pip-wheel-metadata/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.nox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
|
55 |
+
# Translations
|
56 |
+
*.mo
|
57 |
+
*.pot
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
.python-version
|
87 |
+
|
88 |
+
# pipenv
|
89 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
90 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
91 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
92 |
+
# install all needed dependencies.
|
93 |
+
#Pipfile.lock
|
94 |
+
|
95 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
96 |
+
__pypackages__/
|
97 |
+
|
98 |
+
# Celery stuff
|
99 |
+
celerybeat-schedule
|
100 |
+
celerybeat.pid
|
101 |
+
|
102 |
+
# SageMath parsed files
|
103 |
+
*.sage.py
|
104 |
+
|
105 |
+
# Environments
|
106 |
+
.env
|
107 |
+
.venv
|
108 |
+
env/
|
109 |
+
venv/
|
110 |
+
ENV/
|
111 |
+
env.bak/
|
112 |
+
venv.bak/
|
113 |
+
|
114 |
+
# Spyder project settings
|
115 |
+
.spyderproject
|
116 |
+
.spyproject
|
117 |
+
|
118 |
+
# Rope project settings
|
119 |
+
.ropeproject
|
120 |
+
|
121 |
+
# mkdocs documentation
|
122 |
+
/site
|
123 |
+
|
124 |
+
# mypy
|
125 |
+
.mypy_cache/
|
126 |
+
.dmypy.json
|
127 |
+
dmypy.json
|
128 |
+
|
129 |
+
# Pyre type checker
|
130 |
+
.pyre/
|
131 |
+
|
132 |
+
.idea/*
|
133 |
+
.models/*
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,3 +1,77 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Tortoise-TTS
|
2 |
+
|
3 |
+
Tortoise TTS is an experimental text-to-speech program that uses recent machine learning techniques to generate
|
4 |
+
high-quality speech samples.
|
5 |
+
|
6 |
+
This repo contains all the code needed to run Tortoise TTS in inference mode.
|
7 |
+
|
8 |
+
## What's in a name?
|
9 |
+
|
10 |
+
I'm naming my speech-related repos after Mojave desert flora and fauna. Tortoise is a bit tongue in cheek: this model
|
11 |
+
is insanely slow. It leverages both an autoregressive speech alignment model and a diffusion model, both of which
|
12 |
+
are known for their slow inference. It also performs CLIP sampling, which slows things down even further. You can
|
13 |
+
expect ~5 seconds of speech to take ~30 seconds to produce on the latest hardware. Still, the results are pretty cool.
|
14 |
+
|
15 |
+
## What the heck is this?
|
16 |
+
|
17 |
+
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together.
|
18 |
+
These models are all derived from different repositories which are all linked. All the models have been modified
|
19 |
+
for this use case (some substantially so).
|
20 |
+
|
21 |
+
First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very
|
22 |
+
similar to the GPT model used by DALLE, except it operates on speech data.
|
23 |
+
Based on: [GPT2 from Transformers](https://huggingface.co/docs/transformers/model_doc/gpt2)
|
24 |
+
|
25 |
+
Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack
|
26 |
+
ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP
|
27 |
+
decoding creates considerably better results.
|
28 |
+
Based on [CLIP from lucidrains](https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py)
|
29 |
+
|
30 |
+
Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE.
|
31 |
+
Based on [VQVAE2 by rosinality](https://github.com/rosinality/vq-vae-2-pytorch)
|
32 |
+
|
33 |
+
Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in
|
34 |
+
a wav file.
|
35 |
+
Based on [ImprovedDiffusion by openai](https://github.com/openai/improved-diffusion)
|
36 |
+
|
37 |
+
## How do I use this?
|
38 |
+
|
39 |
+
Check out the colab: https://colab.research.google.com/drive/1wVVqUPqwiDBUVeWWOUNglpGhU3hg_cbR?usp=sharing
|
40 |
+
|
41 |
+
Or on a computer with a GPU (with >=16GB of VRAM):
|
42 |
+
```shell
|
43 |
+
git clone https://github.com/neonbjb/tortoise-tts.git
|
44 |
+
cd tortoise-tts
|
45 |
+
pip install -r requirements.txt
|
46 |
+
python do_tts.py
|
47 |
+
```
|
48 |
+
|
49 |
+
## Hand-picked TTS samples
|
50 |
+
|
51 |
+
I generated ~250 samples from 23 text prompts and 8 voices. The text prompts have never been seen by the model. The
|
52 |
+
voices were pulled from the training set.
|
53 |
+
|
54 |
+
All of the samples can be found in the results/ folder of this repo. I handpicked a few to show what the model is capable of:
|
55 |
+
|
56 |
+
- [Atkins - Road not taken](results/favorites/atkins_road_not_taken.wav)
|
57 |
+
- [Dotrice - Rolling Stone interview](results/favorites/dotrice_rollingstone.wav)
|
58 |
+
- [Dotrice - 'Ornaments' from tacotron test set](results/favorites/dotrice_tacotron_samp1.wav)
|
59 |
+
- [Kennard - 'Acute emotional intelligence' from tacotron test set](results/favorites/kennard_tacotron_samp2.wav)
|
60 |
+
- [Mol - Because I could not stop for death](results/favorites/mol_dickenson.wav)
|
61 |
+
- [Mol - Obama](results/favorites/mol_obama.wav)
|
62 |
+
|
63 |
+
Prosody is remarkably good for poetry, despite the fact that it was never trained on poetry.
|
64 |
+
|
65 |
+
## How do I train this?
|
66 |
+
|
67 |
+
Frankly - you don't. Building this model has been a labor of love for me, consuming most of my 6 RTX3090s worth of
|
68 |
+
resources for the better part of 6 months. It uses a dataset I've gathered, refined and transcribed that consists of
|
69 |
+
a lot of audio data which I cannot distribute because of copywrite or no open licenses.
|
70 |
+
|
71 |
+
With that said, I'm willing to help you out if you really want to give it a shot. DM me.
|
72 |
+
|
73 |
+
## Looking forward
|
74 |
+
|
75 |
+
I'm not satisfied with this yet. Treat this as a "sneak peek" and check back in a couple of months. I think the concept
|
76 |
+
is sound, but there are a few hurdles to overcome to get sample quality up. I have been doing major tweaks to the
|
77 |
+
diffusion model and should have something new and much better soon.
|
api.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from urllib import request
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import progressbar
|
9 |
+
|
10 |
+
from models.cvvp import CVVP
|
11 |
+
from models.diffusion_decoder import DiffusionTts
|
12 |
+
from models.autoregressive import UnifiedVoice
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from models.arch_util import TorchMelSpectrogram
|
16 |
+
from models.clvp import CLVP
|
17 |
+
from models.vocoder import UnivNetGenerator
|
18 |
+
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
|
19 |
+
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
20 |
+
from utils.tokenizer import VoiceBpeTokenizer, lev_distance
|
21 |
+
|
22 |
+
|
23 |
+
pbar = None
|
24 |
+
def download_models():
|
25 |
+
MODELS = {
|
26 |
+
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
|
27 |
+
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
|
28 |
+
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
|
29 |
+
}
|
30 |
+
os.makedirs('.models', exist_ok=True)
|
31 |
+
def show_progress(block_num, block_size, total_size):
|
32 |
+
global pbar
|
33 |
+
if pbar is None:
|
34 |
+
pbar = progressbar.ProgressBar(maxval=total_size)
|
35 |
+
pbar.start()
|
36 |
+
|
37 |
+
downloaded = block_num * block_size
|
38 |
+
if downloaded < total_size:
|
39 |
+
pbar.update(downloaded)
|
40 |
+
else:
|
41 |
+
pbar.finish()
|
42 |
+
pbar = None
|
43 |
+
for model_name, url in MODELS.items():
|
44 |
+
if os.path.exists(f'.models/{model_name}'):
|
45 |
+
continue
|
46 |
+
print(f'Downloading {model_name} from {url}...')
|
47 |
+
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
48 |
+
print('Done.')
|
49 |
+
|
50 |
+
|
51 |
+
def pad_or_truncate(t, length):
|
52 |
+
if t.shape[-1] == length:
|
53 |
+
return t
|
54 |
+
elif t.shape[-1] < length:
|
55 |
+
return F.pad(t, (0, length-t.shape[-1]))
|
56 |
+
else:
|
57 |
+
return t[..., :length]
|
58 |
+
|
59 |
+
|
60 |
+
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
|
61 |
+
"""
|
62 |
+
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
63 |
+
"""
|
64 |
+
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
65 |
+
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
66 |
+
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
|
67 |
+
|
68 |
+
|
69 |
+
def load_conditioning(clip, cond_length=132300):
|
70 |
+
gap = clip.shape[-1] - cond_length
|
71 |
+
if gap < 0:
|
72 |
+
clip = F.pad(clip, pad=(0, abs(gap)))
|
73 |
+
elif gap > 0:
|
74 |
+
rand_start = random.randint(0, gap)
|
75 |
+
clip = clip[:, rand_start:rand_start + cond_length]
|
76 |
+
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
|
77 |
+
return mel_clip.unsqueeze(0).cuda()
|
78 |
+
|
79 |
+
|
80 |
+
def clip_guided_generation(autoregressive_model, clip_model, conditioning_input, text_input, num_batches, stop_mel_token,
|
81 |
+
tokens_per_clip_inference=10, clip_results_to_reduce_to=8, **generation_kwargs):
|
82 |
+
"""
|
83 |
+
Uses a CLVP model trained to associate full text with **partial** audio clips to pick the best generation candidates
|
84 |
+
every few iterations. The top results are then propagated forward through the generation process. Rinse and repeat.
|
85 |
+
This is a hybrid between beam search and sampling.
|
86 |
+
"""
|
87 |
+
token_goal = tokens_per_clip_inference
|
88 |
+
finished = False
|
89 |
+
while not finished and token_goal < autoregressive_model.max_mel_tokens:
|
90 |
+
samples = []
|
91 |
+
for b in tqdm(range(num_batches)):
|
92 |
+
codes = autoregressive_model.inference_speech(conditioning_input, text_input, **generation_kwargs)
|
93 |
+
samples.append(codes)
|
94 |
+
for batch in samples:
|
95 |
+
for i in range(batch.shape[0]):
|
96 |
+
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token, complain=False)
|
97 |
+
clip_results.append(clip_model(text_input.repeat(batch.shape[0], 1), batch, return_loss=False))
|
98 |
+
clip_results = torch.cat(clip_results, dim=0)
|
99 |
+
samples = torch.cat(samples, dim=0)
|
100 |
+
best_results = samples[torch.topk(clip_results, k=clip_results_to_reduce_to).indices]
|
101 |
+
|
102 |
+
|
103 |
+
def fix_autoregressive_output(codes, stop_token, complain=True):
|
104 |
+
"""
|
105 |
+
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
|
106 |
+
trained on and what the autoregressive code generator creates (which has no padding or end).
|
107 |
+
This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
|
108 |
+
a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
|
109 |
+
and copying out the last few codes.
|
110 |
+
|
111 |
+
Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
|
112 |
+
"""
|
113 |
+
# Strip off the autoregressive stop token and add padding.
|
114 |
+
stop_token_indices = (codes == stop_token).nonzero()
|
115 |
+
if len(stop_token_indices) == 0:
|
116 |
+
if complain:
|
117 |
+
print("No stop tokens found, enjoy that output of yours!")
|
118 |
+
return codes
|
119 |
+
else:
|
120 |
+
codes[stop_token_indices] = 83
|
121 |
+
stm = stop_token_indices.min().item()
|
122 |
+
codes[stm:] = 83
|
123 |
+
if stm - 3 < codes.shape[0]:
|
124 |
+
codes[-3] = 45
|
125 |
+
codes[-2] = 45
|
126 |
+
codes[-1] = 248
|
127 |
+
|
128 |
+
return codes
|
129 |
+
|
130 |
+
|
131 |
+
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
|
132 |
+
"""
|
133 |
+
Uses the specified diffusion model to convert discrete codes into a spectrogram.
|
134 |
+
"""
|
135 |
+
with torch.no_grad():
|
136 |
+
cond_mels = []
|
137 |
+
for sample in conditioning_samples:
|
138 |
+
sample = pad_or_truncate(sample, 102400)
|
139 |
+
cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
|
140 |
+
cond_mels.append(cond_mel)
|
141 |
+
cond_mels = torch.stack(cond_mels, dim=1)
|
142 |
+
|
143 |
+
output_seq_len = mel_codes.shape[1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
144 |
+
output_shape = (mel_codes.shape[0], 100, output_seq_len)
|
145 |
+
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
|
146 |
+
|
147 |
+
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
|
148 |
+
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
149 |
+
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
150 |
+
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
151 |
+
|
152 |
+
|
153 |
+
class TextToSpeech:
|
154 |
+
def __init__(self, autoregressive_batch_size=16):
|
155 |
+
self.autoregressive_batch_size = autoregressive_batch_size
|
156 |
+
self.tokenizer = VoiceBpeTokenizer()
|
157 |
+
download_models()
|
158 |
+
|
159 |
+
self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
160 |
+
model_dim=1024,
|
161 |
+
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
162 |
+
train_solo_embeddings=False,
|
163 |
+
average_conditioning_embeddings=True).cpu().eval()
|
164 |
+
self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
|
165 |
+
'''
|
166 |
+
self.autoregressive = UnifiedVoice(max_mel_tokens=2048, max_text_tokens=1024, max_conditioning_inputs=1, layers=42,
|
167 |
+
model_dim=1152, heads=18, number_text_tokens=256, train_solo_embeddings=False,
|
168 |
+
average_conditioning_embeddings=True, types=2).cpu().eval()
|
169 |
+
self.autoregressive.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_tts_xl\\models\\15250_gpt_ema.pth'))
|
170 |
+
'''
|
171 |
+
|
172 |
+
self.autoregressive_for_diffusion = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
|
173 |
+
model_dim=1024,
|
174 |
+
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
|
175 |
+
train_solo_embeddings=False,
|
176 |
+
average_conditioning_embeddings=True).cpu().eval()
|
177 |
+
self.autoregressive_for_diffusion.load_state_dict(torch.load('.models/autoregressive.pth'))
|
178 |
+
|
179 |
+
self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
|
180 |
+
text_seq_len=350, text_heads=8,
|
181 |
+
num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
|
182 |
+
use_xformers=True).cpu().eval()
|
183 |
+
self.clvp.load_state_dict(torch.load('.models/clip.pth'))
|
184 |
+
|
185 |
+
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
|
186 |
+
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
|
187 |
+
self.cvvp.load_state_dict(torch.load('.models/cvvp.pth'))
|
188 |
+
|
189 |
+
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
190 |
+
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
191 |
+
layer_drop=0, unconditioned_percentage=0).cpu().eval()
|
192 |
+
self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder.pth'))
|
193 |
+
|
194 |
+
self.vocoder = UnivNetGenerator().cpu()
|
195 |
+
self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
|
196 |
+
self.vocoder.eval(inference=True)
|
197 |
+
|
198 |
+
def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs):
|
199 |
+
"""
|
200 |
+
Calls TTS with one of a set of preset generation parameters. Options:
|
201 |
+
'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest).
|
202 |
+
'fast': Decent quality speech at a decent inference rate. A good choice for mass inference.
|
203 |
+
'standard': Very good quality. This is generally about as good as you are going to get.
|
204 |
+
'high_quality': Use if you want the absolute best. This is not really worth the compute, though.
|
205 |
+
"""
|
206 |
+
# Use generally found best tuning knobs for generation.
|
207 |
+
kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
|
208 |
+
#'typical_sampling': True,
|
209 |
+
'top_p': .8,
|
210 |
+
'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
|
211 |
+
# Presets are defined here.
|
212 |
+
presets = {
|
213 |
+
'ultra_fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 16, 'cond_free': False},
|
214 |
+
'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32},
|
215 |
+
'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128},
|
216 |
+
'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 2048},
|
217 |
+
}
|
218 |
+
kwargs.update(presets[preset])
|
219 |
+
return self.tts(text, voice_samples, **kwargs)
|
220 |
+
|
221 |
+
def tts(self, text, voice_samples, k=1,
|
222 |
+
# autoregressive generation parameters follow
|
223 |
+
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
|
224 |
+
# CLVP & CVVP parameters
|
225 |
+
clvp_cvvp_slider=.5,
|
226 |
+
# diffusion generation parameters follow
|
227 |
+
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
|
228 |
+
**hf_generate_kwargs):
|
229 |
+
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
230 |
+
text = F.pad(text, (0, 1)) # This may not be necessary.
|
231 |
+
|
232 |
+
conds = []
|
233 |
+
if not isinstance(voice_samples, list):
|
234 |
+
voice_samples = [voice_samples]
|
235 |
+
for vs in voice_samples:
|
236 |
+
conds.append(load_conditioning(vs))
|
237 |
+
conds = torch.stack(conds, dim=1)
|
238 |
+
|
239 |
+
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
240 |
+
|
241 |
+
with torch.no_grad():
|
242 |
+
samples = []
|
243 |
+
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
|
244 |
+
stop_mel_token = self.autoregressive.stop_mel_token
|
245 |
+
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
|
246 |
+
self.autoregressive = self.autoregressive.cuda()
|
247 |
+
for b in tqdm(range(num_batches)):
|
248 |
+
codes = self.autoregressive.inference_speech(conds, text,
|
249 |
+
do_sample=True,
|
250 |
+
top_p=top_p,
|
251 |
+
temperature=temperature,
|
252 |
+
num_return_sequences=self.autoregressive_batch_size,
|
253 |
+
length_penalty=length_penalty,
|
254 |
+
repetition_penalty=repetition_penalty,
|
255 |
+
max_generate_length=max_mel_tokens,
|
256 |
+
**hf_generate_kwargs)
|
257 |
+
padding_needed = max_mel_tokens - codes.shape[1]
|
258 |
+
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
|
259 |
+
samples.append(codes)
|
260 |
+
self.autoregressive = self.autoregressive.cpu()
|
261 |
+
|
262 |
+
clip_results = []
|
263 |
+
self.clvp = self.clvp.cuda()
|
264 |
+
self.cvvp = self.cvvp.cuda()
|
265 |
+
for batch in samples:
|
266 |
+
for i in range(batch.shape[0]):
|
267 |
+
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
268 |
+
clvp = self.clvp(text.repeat(batch.shape[0], 1), batch, return_loss=False)
|
269 |
+
cvvp_accumulator = 0
|
270 |
+
for cl in range(conds.shape[1]):
|
271 |
+
cvvp_accumulator = cvvp_accumulator + self.cvvp(conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False )
|
272 |
+
cvvp = cvvp_accumulator / conds.shape[1]
|
273 |
+
clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider))
|
274 |
+
clip_results = torch.cat(clip_results, dim=0)
|
275 |
+
samples = torch.cat(samples, dim=0)
|
276 |
+
best_results = samples[torch.topk(clip_results, k=k).indices]
|
277 |
+
self.clvp = self.clvp.cpu()
|
278 |
+
self.cvvp = self.cvvp.cpu()
|
279 |
+
del samples
|
280 |
+
|
281 |
+
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
|
282 |
+
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
283 |
+
# results, but will increase memory usage.
|
284 |
+
self.autoregressive_for_diffusion = self.autoregressive_for_diffusion.cuda()
|
285 |
+
best_latents = self.autoregressive_for_diffusion(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results,
|
286 |
+
torch.tensor([best_results.shape[-1]*self.autoregressive_for_diffusion.mel_length_compression], device=conds.device),
|
287 |
+
return_latent=True, clip_inputs=False)
|
288 |
+
self.autoregressive_for_diffusion = self.autoregressive_for_diffusion.cpu()
|
289 |
+
|
290 |
+
print("Performing vocoding..")
|
291 |
+
wav_candidates = []
|
292 |
+
self.diffusion = self.diffusion.cuda()
|
293 |
+
self.vocoder = self.vocoder.cuda()
|
294 |
+
for b in range(best_results.shape[0]):
|
295 |
+
codes = best_results[b].unsqueeze(0)
|
296 |
+
latents = best_latents[b].unsqueeze(0)
|
297 |
+
|
298 |
+
# Find the first occurrence of the "calm" token and trim the codes to that.
|
299 |
+
ctokens = 0
|
300 |
+
for k in range(codes.shape[-1]):
|
301 |
+
if codes[0, k] == calm_token:
|
302 |
+
ctokens += 1
|
303 |
+
else:
|
304 |
+
ctokens = 0
|
305 |
+
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
|
306 |
+
latents = latents[:, :k]
|
307 |
+
break
|
308 |
+
|
309 |
+
mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature)
|
310 |
+
wav = self.vocoder.inference(mel)
|
311 |
+
wav_candidates.append(wav.cpu())
|
312 |
+
self.diffusion = self.diffusion.cpu()
|
313 |
+
self.vocoder = self.vocoder.cpu()
|
314 |
+
|
315 |
+
if len(wav_candidates) > 1:
|
316 |
+
return wav_candidates
|
317 |
+
return wav_candidates[0]
|
data/mel_norms.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1f69422a8a8f344c4fca2f0c6b8d41d2151d6615b7321e48e6bb15ae949b119c
|
3 |
+
size 1067
|
data/riding_hood.txt
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her. It suited the girl so extremely well that everybody called her Little Red Riding Hood.
|
2 |
+
One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."
|
3 |
+
|
4 |
+
Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village.
|
5 |
+
|
6 |
+
As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest. He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother."
|
7 |
+
|
8 |
+
"Does she live far off?" said the wolf
|
9 |
+
|
10 |
+
"Oh I say," answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village."
|
11 |
+
|
12 |
+
"Well," said the wolf, "and I'll go and see her too. I'll go this way and go you that, and we shall see who will be there first."
|
13 |
+
|
14 |
+
The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers. It was not long before the wolf arrived at the old woman's house. He knocked at the door: tap, tap.
|
15 |
+
|
16 |
+
"Who's there?"
|
17 |
+
|
18 |
+
"Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."
|
19 |
+
|
20 |
+
The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."
|
21 |
+
|
22 |
+
The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten. He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap.
|
23 |
+
|
24 |
+
"Who's there?"
|
25 |
+
|
26 |
+
Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."
|
27 |
+
|
28 |
+
The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up."
|
29 |
+
|
30 |
+
Little Red Riding Hood pulled the bobbin, and the door opened.
|
31 |
+
|
32 |
+
The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me."
|
33 |
+
|
34 |
+
Little Red Riding Hood took off her clothes and got into bed. She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!"
|
35 |
+
|
36 |
+
"All the better to hug you with, my dear."
|
37 |
+
|
38 |
+
"Grandmother, what big legs you have!"
|
39 |
+
|
40 |
+
"All the better to run with, my child."
|
41 |
+
|
42 |
+
"Grandmother, what big ears you have!"
|
43 |
+
|
44 |
+
"All the better to hear with, my child."
|
45 |
+
|
46 |
+
"Grandmother, what big eyes you have!"
|
47 |
+
|
48 |
+
"All the better to see with, my child."
|
49 |
+
|
50 |
+
"Grandmother, what big teeth you have got!"
|
51 |
+
|
52 |
+
"All the better to eat you up with."
|
53 |
+
|
54 |
+
And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.
|
data/tokenizer.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
|
do_tts.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torchaudio
|
5 |
+
|
6 |
+
from api import TextToSpeech
|
7 |
+
from utils.audio import load_audio, get_voices
|
8 |
+
|
9 |
+
if __name__ == '__main__':
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument('--text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
12 |
+
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
|
13 |
+
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='patrick_stewart')
|
14 |
+
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
15 |
+
args = parser.parse_args()
|
16 |
+
os.makedirs(args.output_path, exist_ok=True)
|
17 |
+
|
18 |
+
tts = TextToSpeech()
|
19 |
+
|
20 |
+
voices = get_voices()
|
21 |
+
selected_voices = args.voice.split(',')
|
22 |
+
for voice in selected_voices:
|
23 |
+
cond_paths = voices[voice]
|
24 |
+
conds = []
|
25 |
+
for cond_path in cond_paths:
|
26 |
+
c = load_audio(cond_path, 22050)
|
27 |
+
conds.append(c)
|
28 |
+
gen = tts.tts_with_preset(args.text, conds, preset='standard')
|
29 |
+
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), gen.squeeze(0).cpu(), 24000)
|
30 |
+
|
eval_multiple.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torchaudio
|
4 |
+
|
5 |
+
from api import TextToSpeech
|
6 |
+
from utils.audio import load_audio
|
7 |
+
|
8 |
+
if __name__ == '__main__':
|
9 |
+
fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
|
10 |
+
stop_after = 128
|
11 |
+
outpath_base = 'D:\\tmp\\tortoise-tts-eval\\audiobooks'
|
12 |
+
outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
|
13 |
+
|
14 |
+
os.makedirs(outpath_real, exist_ok=True)
|
15 |
+
with open(fname, 'r', encoding='utf-8') as f:
|
16 |
+
lines = [l.strip().split('\t') for l in f.readlines()]
|
17 |
+
|
18 |
+
tts = TextToSpeech()
|
19 |
+
for k in range(3):
|
20 |
+
outpath = f'{outpath_base}_{k}'
|
21 |
+
os.makedirs(outpath, exist_ok=True)
|
22 |
+
recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8')
|
23 |
+
for e, line in enumerate(lines):
|
24 |
+
if e >= stop_after:
|
25 |
+
break
|
26 |
+
transcript = line[0]
|
27 |
+
path = os.path.join(os.path.dirname(fname), line[1])
|
28 |
+
cond_audio = load_audio(path, 22050)
|
29 |
+
torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
|
30 |
+
sample = tts.tts_with_preset(transcript, [cond_audio, cond_audio], preset='standard')
|
31 |
+
|
32 |
+
down = torchaudio.functional.resample(sample, 24000, 22050)
|
33 |
+
fout_path = os.path.join(outpath, os.path.basename(line[1]))
|
34 |
+
torchaudio.save(fout_path, down.squeeze(0), 22050)
|
35 |
+
|
36 |
+
recorder.write(f'{transcript}\t{fout_path}\n')
|
37 |
+
recorder.flush()
|
38 |
+
recorder.close()
|
models/arch_util.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
from x_transformers import ContinuousTransformerWrapper
|
9 |
+
from x_transformers.x_transformers import RelativePositionBias
|
10 |
+
|
11 |
+
|
12 |
+
def zero_module(module):
|
13 |
+
"""
|
14 |
+
Zero out the parameters of a module and return it.
|
15 |
+
"""
|
16 |
+
for p in module.parameters():
|
17 |
+
p.detach().zero_()
|
18 |
+
return module
|
19 |
+
|
20 |
+
|
21 |
+
class GroupNorm32(nn.GroupNorm):
|
22 |
+
def forward(self, x):
|
23 |
+
return super().forward(x.float()).type(x.dtype)
|
24 |
+
|
25 |
+
|
26 |
+
def normalization(channels):
|
27 |
+
"""
|
28 |
+
Make a standard normalization layer.
|
29 |
+
|
30 |
+
:param channels: number of input channels.
|
31 |
+
:return: an nn.Module for normalization.
|
32 |
+
"""
|
33 |
+
groups = 32
|
34 |
+
if channels <= 16:
|
35 |
+
groups = 8
|
36 |
+
elif channels <= 64:
|
37 |
+
groups = 16
|
38 |
+
while channels % groups != 0:
|
39 |
+
groups = int(groups / 2)
|
40 |
+
assert groups > 2
|
41 |
+
return GroupNorm32(groups, channels)
|
42 |
+
|
43 |
+
|
44 |
+
class QKVAttentionLegacy(nn.Module):
|
45 |
+
"""
|
46 |
+
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, n_heads):
|
50 |
+
super().__init__()
|
51 |
+
self.n_heads = n_heads
|
52 |
+
|
53 |
+
def forward(self, qkv, mask=None, rel_pos=None):
|
54 |
+
"""
|
55 |
+
Apply QKV attention.
|
56 |
+
|
57 |
+
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
58 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
59 |
+
"""
|
60 |
+
bs, width, length = qkv.shape
|
61 |
+
assert width % (3 * self.n_heads) == 0
|
62 |
+
ch = width // (3 * self.n_heads)
|
63 |
+
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
64 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
65 |
+
weight = torch.einsum(
|
66 |
+
"bct,bcs->bts", q * scale, k * scale
|
67 |
+
) # More stable with f16 than dividing afterwards
|
68 |
+
if rel_pos is not None:
|
69 |
+
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
|
70 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
71 |
+
if mask is not None:
|
72 |
+
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
|
73 |
+
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
|
74 |
+
weight = weight * mask
|
75 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
76 |
+
|
77 |
+
return a.reshape(bs, -1, length)
|
78 |
+
|
79 |
+
|
80 |
+
class AttentionBlock(nn.Module):
|
81 |
+
"""
|
82 |
+
An attention block that allows spatial positions to attend to each other.
|
83 |
+
|
84 |
+
Originally ported from here, but adapted to the N-d case.
|
85 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
channels,
|
91 |
+
num_heads=1,
|
92 |
+
num_head_channels=-1,
|
93 |
+
do_checkpoint=True,
|
94 |
+
relative_pos_embeddings=False,
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.channels = channels
|
98 |
+
self.do_checkpoint = do_checkpoint
|
99 |
+
if num_head_channels == -1:
|
100 |
+
self.num_heads = num_heads
|
101 |
+
else:
|
102 |
+
assert (
|
103 |
+
channels % num_head_channels == 0
|
104 |
+
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
105 |
+
self.num_heads = channels // num_head_channels
|
106 |
+
self.norm = normalization(channels)
|
107 |
+
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
108 |
+
# split heads before split qkv
|
109 |
+
self.attention = QKVAttentionLegacy(self.num_heads)
|
110 |
+
|
111 |
+
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
112 |
+
if relative_pos_embeddings:
|
113 |
+
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
|
114 |
+
else:
|
115 |
+
self.relative_pos_embeddings = None
|
116 |
+
|
117 |
+
def forward(self, x, mask=None):
|
118 |
+
b, c, *spatial = x.shape
|
119 |
+
x = x.reshape(b, c, -1)
|
120 |
+
qkv = self.qkv(self.norm(x))
|
121 |
+
h = self.attention(qkv, mask, self.relative_pos_embeddings)
|
122 |
+
h = self.proj_out(h)
|
123 |
+
return (x + h).reshape(b, c, *spatial)
|
124 |
+
|
125 |
+
|
126 |
+
class Upsample(nn.Module):
|
127 |
+
"""
|
128 |
+
An upsampling layer with an optional convolution.
|
129 |
+
|
130 |
+
:param channels: channels in the inputs and outputs.
|
131 |
+
:param use_conv: a bool determining if a convolution is applied.
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(self, channels, use_conv, out_channels=None, factor=4):
|
135 |
+
super().__init__()
|
136 |
+
self.channels = channels
|
137 |
+
self.out_channels = out_channels or channels
|
138 |
+
self.use_conv = use_conv
|
139 |
+
self.factor = factor
|
140 |
+
if use_conv:
|
141 |
+
ksize = 5
|
142 |
+
pad = 2
|
143 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
assert x.shape[1] == self.channels
|
147 |
+
x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
|
148 |
+
if self.use_conv:
|
149 |
+
x = self.conv(x)
|
150 |
+
return x
|
151 |
+
|
152 |
+
|
153 |
+
class Downsample(nn.Module):
|
154 |
+
"""
|
155 |
+
A downsampling layer with an optional convolution.
|
156 |
+
|
157 |
+
:param channels: channels in the inputs and outputs.
|
158 |
+
:param use_conv: a bool determining if a convolution is applied.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
|
162 |
+
super().__init__()
|
163 |
+
self.channels = channels
|
164 |
+
self.out_channels = out_channels or channels
|
165 |
+
self.use_conv = use_conv
|
166 |
+
|
167 |
+
stride = factor
|
168 |
+
if use_conv:
|
169 |
+
self.op = nn.Conv1d(
|
170 |
+
self.channels, self.out_channels, ksize, stride=stride, padding=pad
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
assert self.channels == self.out_channels
|
174 |
+
self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
assert x.shape[1] == self.channels
|
178 |
+
return self.op(x)
|
179 |
+
|
180 |
+
|
181 |
+
class ResBlock(nn.Module):
|
182 |
+
def __init__(
|
183 |
+
self,
|
184 |
+
channels,
|
185 |
+
dropout,
|
186 |
+
out_channels=None,
|
187 |
+
use_conv=False,
|
188 |
+
use_scale_shift_norm=False,
|
189 |
+
up=False,
|
190 |
+
down=False,
|
191 |
+
kernel_size=3,
|
192 |
+
):
|
193 |
+
super().__init__()
|
194 |
+
self.channels = channels
|
195 |
+
self.dropout = dropout
|
196 |
+
self.out_channels = out_channels or channels
|
197 |
+
self.use_conv = use_conv
|
198 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
199 |
+
padding = 1 if kernel_size == 3 else 2
|
200 |
+
|
201 |
+
self.in_layers = nn.Sequential(
|
202 |
+
normalization(channels),
|
203 |
+
nn.SiLU(),
|
204 |
+
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
|
205 |
+
)
|
206 |
+
|
207 |
+
self.updown = up or down
|
208 |
+
|
209 |
+
if up:
|
210 |
+
self.h_upd = Upsample(channels, False)
|
211 |
+
self.x_upd = Upsample(channels, False)
|
212 |
+
elif down:
|
213 |
+
self.h_upd = Downsample(channels, False)
|
214 |
+
self.x_upd = Downsample(channels, False)
|
215 |
+
else:
|
216 |
+
self.h_upd = self.x_upd = nn.Identity()
|
217 |
+
|
218 |
+
self.out_layers = nn.Sequential(
|
219 |
+
normalization(self.out_channels),
|
220 |
+
nn.SiLU(),
|
221 |
+
nn.Dropout(p=dropout),
|
222 |
+
zero_module(
|
223 |
+
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
|
224 |
+
),
|
225 |
+
)
|
226 |
+
|
227 |
+
if self.out_channels == channels:
|
228 |
+
self.skip_connection = nn.Identity()
|
229 |
+
elif use_conv:
|
230 |
+
self.skip_connection = nn.Conv1d(
|
231 |
+
channels, self.out_channels, kernel_size, padding=padding
|
232 |
+
)
|
233 |
+
else:
|
234 |
+
self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
if self.updown:
|
238 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
239 |
+
h = in_rest(x)
|
240 |
+
h = self.h_upd(h)
|
241 |
+
x = self.x_upd(x)
|
242 |
+
h = in_conv(h)
|
243 |
+
else:
|
244 |
+
h = self.in_layers(x)
|
245 |
+
h = self.out_layers(h)
|
246 |
+
return self.skip_connection(x) + h
|
247 |
+
|
248 |
+
|
249 |
+
class AudioMiniEncoder(nn.Module):
|
250 |
+
def __init__(self,
|
251 |
+
spec_dim,
|
252 |
+
embedding_dim,
|
253 |
+
base_channels=128,
|
254 |
+
depth=2,
|
255 |
+
resnet_blocks=2,
|
256 |
+
attn_blocks=4,
|
257 |
+
num_attn_heads=4,
|
258 |
+
dropout=0,
|
259 |
+
downsample_factor=2,
|
260 |
+
kernel_size=3):
|
261 |
+
super().__init__()
|
262 |
+
self.init = nn.Sequential(
|
263 |
+
nn.Conv1d(spec_dim, base_channels, 3, padding=1)
|
264 |
+
)
|
265 |
+
ch = base_channels
|
266 |
+
res = []
|
267 |
+
for l in range(depth):
|
268 |
+
for r in range(resnet_blocks):
|
269 |
+
res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
|
270 |
+
res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
|
271 |
+
ch *= 2
|
272 |
+
self.res = nn.Sequential(*res)
|
273 |
+
self.final = nn.Sequential(
|
274 |
+
normalization(ch),
|
275 |
+
nn.SiLU(),
|
276 |
+
nn.Conv1d(ch, embedding_dim, 1)
|
277 |
+
)
|
278 |
+
attn = []
|
279 |
+
for a in range(attn_blocks):
|
280 |
+
attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
|
281 |
+
self.attn = nn.Sequential(*attn)
|
282 |
+
self.dim = embedding_dim
|
283 |
+
|
284 |
+
def forward(self, x):
|
285 |
+
h = self.init(x)
|
286 |
+
h = self.res(h)
|
287 |
+
h = self.final(h)
|
288 |
+
h = self.attn(h)
|
289 |
+
return h[:, :, 0]
|
290 |
+
|
291 |
+
|
292 |
+
class TorchMelSpectrogram(nn.Module):
|
293 |
+
def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
|
294 |
+
sampling_rate=22050, normalize=False, mel_norm_file='data/mel_norms.pth'):
|
295 |
+
super().__init__()
|
296 |
+
# These are the default tacotron values for the MEL spectrogram.
|
297 |
+
self.filter_length = filter_length
|
298 |
+
self.hop_length = hop_length
|
299 |
+
self.win_length = win_length
|
300 |
+
self.n_mel_channels = n_mel_channels
|
301 |
+
self.mel_fmin = mel_fmin
|
302 |
+
self.mel_fmax = mel_fmax
|
303 |
+
self.sampling_rate = sampling_rate
|
304 |
+
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
|
305 |
+
win_length=self.win_length, power=2, normalized=normalize,
|
306 |
+
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
|
307 |
+
f_max=self.mel_fmax, n_mels=self.n_mel_channels,
|
308 |
+
norm="slaney")
|
309 |
+
self.mel_norm_file = mel_norm_file
|
310 |
+
if self.mel_norm_file is not None:
|
311 |
+
self.mel_norms = torch.load(self.mel_norm_file)
|
312 |
+
else:
|
313 |
+
self.mel_norms = None
|
314 |
+
|
315 |
+
def forward(self, inp):
|
316 |
+
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
317 |
+
inp = inp.squeeze(1)
|
318 |
+
assert len(inp.shape) == 2
|
319 |
+
self.mel_stft = self.mel_stft.to(inp.device)
|
320 |
+
mel = self.mel_stft(inp)
|
321 |
+
# Perform dynamic range compression
|
322 |
+
mel = torch.log(torch.clamp(mel, min=1e-5))
|
323 |
+
if self.mel_norms is not None:
|
324 |
+
self.mel_norms = self.mel_norms.to(mel.device)
|
325 |
+
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
326 |
+
return mel
|
327 |
+
|
328 |
+
|
329 |
+
class CheckpointedLayer(nn.Module):
|
330 |
+
"""
|
331 |
+
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
332 |
+
checkpoint for all other args.
|
333 |
+
"""
|
334 |
+
def __init__(self, wrap):
|
335 |
+
super().__init__()
|
336 |
+
self.wrap = wrap
|
337 |
+
|
338 |
+
def forward(self, x, *args, **kwargs):
|
339 |
+
for k, v in kwargs.items():
|
340 |
+
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
341 |
+
partial = functools.partial(self.wrap, **kwargs)
|
342 |
+
return torch.utils.checkpoint.checkpoint(partial, x, *args)
|
343 |
+
|
344 |
+
|
345 |
+
class CheckpointedXTransformerEncoder(nn.Module):
|
346 |
+
"""
|
347 |
+
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
348 |
+
to channels-last that XTransformer expects.
|
349 |
+
"""
|
350 |
+
def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
|
351 |
+
super().__init__()
|
352 |
+
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
353 |
+
self.needs_permute = needs_permute
|
354 |
+
self.exit_permute = exit_permute
|
355 |
+
|
356 |
+
if not checkpoint:
|
357 |
+
return
|
358 |
+
for i in range(len(self.transformer.attn_layers.layers)):
|
359 |
+
n, b, r = self.transformer.attn_layers.layers[i]
|
360 |
+
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
361 |
+
|
362 |
+
def forward(self, x, **kwargs):
|
363 |
+
if self.needs_permute:
|
364 |
+
x = x.permute(0,2,1)
|
365 |
+
h = self.transformer(x, **kwargs)
|
366 |
+
if self.exit_permute:
|
367 |
+
h = h.permute(0,2,1)
|
368 |
+
return h
|
models/autoregressive.py
ADDED
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
7 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
8 |
+
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
9 |
+
from models.arch_util import AttentionBlock
|
10 |
+
from utils.typical_sampling import TypicalLogitsWarper
|
11 |
+
|
12 |
+
|
13 |
+
def null_position_embeddings(range, dim):
|
14 |
+
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
15 |
+
|
16 |
+
|
17 |
+
class ResBlock(nn.Module):
|
18 |
+
"""
|
19 |
+
Basic residual convolutional block that uses GroupNorm.
|
20 |
+
"""
|
21 |
+
def __init__(self, chan):
|
22 |
+
super().__init__()
|
23 |
+
self.net = nn.Sequential(
|
24 |
+
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
25 |
+
nn.GroupNorm(chan//8, chan),
|
26 |
+
nn.ReLU(),
|
27 |
+
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
28 |
+
nn.GroupNorm(chan//8, chan)
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
return F.relu(self.net(x) + x)
|
33 |
+
|
34 |
+
|
35 |
+
class GPT2InferenceModel(GPT2PreTrainedModel):
|
36 |
+
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
|
37 |
+
super().__init__(config)
|
38 |
+
self.transformer = gpt
|
39 |
+
self.text_pos_embedding = text_pos_emb
|
40 |
+
self.embeddings = embeddings
|
41 |
+
self.lm_head = nn.Sequential(norm, linear)
|
42 |
+
|
43 |
+
# Model parallel
|
44 |
+
self.model_parallel = False
|
45 |
+
self.device_map = None
|
46 |
+
self.cached_mel_emb = None
|
47 |
+
|
48 |
+
def parallelize(self, device_map=None):
|
49 |
+
self.device_map = (
|
50 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
51 |
+
if device_map is None
|
52 |
+
else device_map
|
53 |
+
)
|
54 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
55 |
+
self.transformer.parallelize(self.device_map)
|
56 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
57 |
+
self.model_parallel = True
|
58 |
+
|
59 |
+
def deparallelize(self):
|
60 |
+
self.transformer.deparallelize()
|
61 |
+
self.transformer = self.transformer.to("cpu")
|
62 |
+
self.lm_head = self.lm_head.to("cpu")
|
63 |
+
self.model_parallel = False
|
64 |
+
torch.cuda.empty_cache()
|
65 |
+
|
66 |
+
def get_output_embeddings(self):
|
67 |
+
return self.lm_head
|
68 |
+
|
69 |
+
def set_output_embeddings(self, new_embeddings):
|
70 |
+
self.lm_head = new_embeddings
|
71 |
+
|
72 |
+
def store_mel_emb(self, mel_emb):
|
73 |
+
self.cached_mel_emb = mel_emb
|
74 |
+
|
75 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
76 |
+
|
77 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
78 |
+
# only last token for inputs_ids if past is defined in kwargs
|
79 |
+
if past:
|
80 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
81 |
+
if token_type_ids is not None:
|
82 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
83 |
+
|
84 |
+
attention_mask = kwargs.get("attention_mask", None)
|
85 |
+
position_ids = kwargs.get("position_ids", None)
|
86 |
+
|
87 |
+
if attention_mask is not None and position_ids is None:
|
88 |
+
# create position_ids on the fly for batch generation
|
89 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
90 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
91 |
+
if past:
|
92 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
93 |
+
else:
|
94 |
+
position_ids = None
|
95 |
+
return {
|
96 |
+
"input_ids": input_ids,
|
97 |
+
"past_key_values": past,
|
98 |
+
"use_cache": kwargs.get("use_cache"),
|
99 |
+
"position_ids": position_ids,
|
100 |
+
"attention_mask": attention_mask,
|
101 |
+
"token_type_ids": token_type_ids,
|
102 |
+
}
|
103 |
+
|
104 |
+
def forward(
|
105 |
+
self,
|
106 |
+
input_ids=None,
|
107 |
+
past_key_values=None,
|
108 |
+
attention_mask=None,
|
109 |
+
token_type_ids=None,
|
110 |
+
position_ids=None,
|
111 |
+
head_mask=None,
|
112 |
+
inputs_embeds=None,
|
113 |
+
encoder_hidden_states=None,
|
114 |
+
encoder_attention_mask=None,
|
115 |
+
labels=None,
|
116 |
+
use_cache=None,
|
117 |
+
output_attentions=None,
|
118 |
+
output_hidden_states=None,
|
119 |
+
return_dict=None,
|
120 |
+
):
|
121 |
+
assert self.cached_mel_emb is not None
|
122 |
+
assert inputs_embeds is None # Not supported by this inference model.
|
123 |
+
assert labels is None # Training not supported by this inference model.
|
124 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
125 |
+
|
126 |
+
# Create embedding
|
127 |
+
mel_len = self.cached_mel_emb.shape[1]
|
128 |
+
if input_ids.shape[1] != 1:
|
129 |
+
text_inputs = input_ids[:, mel_len:]
|
130 |
+
text_emb = self.embeddings(text_inputs)
|
131 |
+
text_emb = text_emb + self.text_pos_embedding(text_emb)
|
132 |
+
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
|
133 |
+
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
|
134 |
+
else:
|
135 |
+
mel_emb = self.cached_mel_emb
|
136 |
+
emb = torch.cat([mel_emb, text_emb], dim=1)
|
137 |
+
else:
|
138 |
+
emb = self.embeddings(input_ids)
|
139 |
+
emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
|
140 |
+
|
141 |
+
transformer_outputs = self.transformer(
|
142 |
+
inputs_embeds=emb,
|
143 |
+
past_key_values=past_key_values,
|
144 |
+
attention_mask=attention_mask,
|
145 |
+
token_type_ids=token_type_ids,
|
146 |
+
position_ids=position_ids,
|
147 |
+
head_mask=head_mask,
|
148 |
+
encoder_hidden_states=encoder_hidden_states,
|
149 |
+
encoder_attention_mask=encoder_attention_mask,
|
150 |
+
use_cache=use_cache,
|
151 |
+
output_attentions=output_attentions,
|
152 |
+
output_hidden_states=output_hidden_states,
|
153 |
+
return_dict=return_dict,
|
154 |
+
)
|
155 |
+
hidden_states = transformer_outputs[0]
|
156 |
+
|
157 |
+
# Set device for model parallelism
|
158 |
+
if self.model_parallel:
|
159 |
+
torch.cuda.set_device(self.transformer.first_device)
|
160 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
161 |
+
|
162 |
+
lm_logits = self.lm_head(hidden_states)
|
163 |
+
|
164 |
+
if not return_dict:
|
165 |
+
return (lm_logits,) + transformer_outputs[1:]
|
166 |
+
|
167 |
+
return CausalLMOutputWithCrossAttentions(
|
168 |
+
loss=None,
|
169 |
+
logits=lm_logits,
|
170 |
+
past_key_values=transformer_outputs.past_key_values,
|
171 |
+
hidden_states=transformer_outputs.hidden_states,
|
172 |
+
attentions=transformer_outputs.attentions,
|
173 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
174 |
+
)
|
175 |
+
|
176 |
+
@staticmethod
|
177 |
+
def _reorder_cache(past, beam_idx):
|
178 |
+
"""
|
179 |
+
This function is used to re-order the :obj:`past_key_values` cache if
|
180 |
+
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
181 |
+
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
182 |
+
"""
|
183 |
+
return tuple(
|
184 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
185 |
+
for layer_past in past
|
186 |
+
)
|
187 |
+
|
188 |
+
|
189 |
+
class ConditioningEncoder(nn.Module):
|
190 |
+
def __init__(self,
|
191 |
+
spec_dim,
|
192 |
+
embedding_dim,
|
193 |
+
attn_blocks=6,
|
194 |
+
num_attn_heads=4,
|
195 |
+
do_checkpointing=False,
|
196 |
+
mean=False):
|
197 |
+
super().__init__()
|
198 |
+
attn = []
|
199 |
+
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
200 |
+
for a in range(attn_blocks):
|
201 |
+
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
202 |
+
self.attn = nn.Sequential(*attn)
|
203 |
+
self.dim = embedding_dim
|
204 |
+
self.do_checkpointing = do_checkpointing
|
205 |
+
self.mean = mean
|
206 |
+
|
207 |
+
def forward(self, x):
|
208 |
+
h = self.init(x)
|
209 |
+
h = self.attn(h)
|
210 |
+
if self.mean:
|
211 |
+
return h.mean(dim=2)
|
212 |
+
else:
|
213 |
+
return h[:, :, 0]
|
214 |
+
|
215 |
+
|
216 |
+
class LearnedPositionEmbeddings(nn.Module):
|
217 |
+
def __init__(self, seq_len, model_dim, init=.02):
|
218 |
+
super().__init__()
|
219 |
+
self.emb = nn.Embedding(seq_len, model_dim)
|
220 |
+
# Initializing this way is standard for GPT-2
|
221 |
+
self.emb.weight.data.normal_(mean=0.0, std=init)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
sl = x.shape[1]
|
225 |
+
return self.emb(torch.arange(0, sl, device=x.device))
|
226 |
+
|
227 |
+
def get_fixed_embedding(self, ind, dev):
|
228 |
+
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
229 |
+
|
230 |
+
|
231 |
+
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
|
232 |
+
"""
|
233 |
+
GPT-2 implemented by the HuggingFace library.
|
234 |
+
"""
|
235 |
+
from transformers import GPT2Config, GPT2Model
|
236 |
+
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
237 |
+
n_positions=max_mel_seq_len+max_text_seq_len,
|
238 |
+
n_ctx=max_mel_seq_len+max_text_seq_len,
|
239 |
+
n_embd=model_dim,
|
240 |
+
n_layer=layers,
|
241 |
+
n_head=heads,
|
242 |
+
gradient_checkpointing=checkpointing,
|
243 |
+
use_cache=not checkpointing)
|
244 |
+
gpt = GPT2Model(gpt_config)
|
245 |
+
# Override the built in positional embeddings
|
246 |
+
del gpt.wpe
|
247 |
+
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
248 |
+
# Built-in token embeddings are unused.
|
249 |
+
del gpt.wte
|
250 |
+
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
|
251 |
+
None, None
|
252 |
+
|
253 |
+
|
254 |
+
class MelEncoder(nn.Module):
|
255 |
+
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
256 |
+
super().__init__()
|
257 |
+
self.channels = channels
|
258 |
+
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
259 |
+
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
260 |
+
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
261 |
+
nn.GroupNorm(channels//16, channels//2),
|
262 |
+
nn.ReLU(),
|
263 |
+
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
264 |
+
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
265 |
+
nn.GroupNorm(channels//8, channels),
|
266 |
+
nn.ReLU(),
|
267 |
+
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
268 |
+
)
|
269 |
+
self.reduction = 4
|
270 |
+
|
271 |
+
|
272 |
+
def forward(self, x):
|
273 |
+
for e in self.encoder:
|
274 |
+
x = e(x)
|
275 |
+
return x.permute(0,2,1)
|
276 |
+
|
277 |
+
|
278 |
+
class UnifiedVoice(nn.Module):
|
279 |
+
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
280 |
+
mel_length_compression=1024, number_text_tokens=256,
|
281 |
+
start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
|
282 |
+
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
283 |
+
checkpointing=True, average_conditioning_embeddings=False,
|
284 |
+
types=1):
|
285 |
+
"""
|
286 |
+
Args:
|
287 |
+
layers: Number of layers in transformer stack.
|
288 |
+
model_dim: Operating dimensions of the transformer
|
289 |
+
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
|
290 |
+
max_text_tokens: Maximum number of text tokens that will be encountered by model.
|
291 |
+
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
|
292 |
+
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
|
293 |
+
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
|
294 |
+
number_text_tokens:
|
295 |
+
start_text_token:
|
296 |
+
stop_text_token:
|
297 |
+
number_mel_codes:
|
298 |
+
start_mel_token:
|
299 |
+
stop_mel_token:
|
300 |
+
train_solo_embeddings:
|
301 |
+
use_mel_codes_as_input:
|
302 |
+
checkpointing:
|
303 |
+
average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
|
304 |
+
"""
|
305 |
+
super().__init__()
|
306 |
+
|
307 |
+
self.number_text_tokens = number_text_tokens
|
308 |
+
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
309 |
+
self.stop_text_token = 0
|
310 |
+
self.number_mel_codes = number_mel_codes
|
311 |
+
self.start_mel_token = start_mel_token
|
312 |
+
self.stop_mel_token = stop_mel_token
|
313 |
+
self.layers = layers
|
314 |
+
self.heads = heads
|
315 |
+
self.max_mel_tokens = max_mel_tokens
|
316 |
+
self.max_text_tokens = max_text_tokens
|
317 |
+
self.model_dim = model_dim
|
318 |
+
self.max_conditioning_inputs = max_conditioning_inputs
|
319 |
+
self.mel_length_compression = mel_length_compression
|
320 |
+
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
321 |
+
self.average_conditioning_embeddings = average_conditioning_embeddings
|
322 |
+
self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
|
323 |
+
if use_mel_codes_as_input:
|
324 |
+
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
325 |
+
else:
|
326 |
+
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
327 |
+
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
328 |
+
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
|
329 |
+
if train_solo_embeddings:
|
330 |
+
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
331 |
+
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
332 |
+
else:
|
333 |
+
self.mel_solo_embedding = 0
|
334 |
+
self.text_solo_embedding = 0
|
335 |
+
|
336 |
+
self.final_norm = nn.LayerNorm(model_dim)
|
337 |
+
self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
|
338 |
+
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
339 |
+
|
340 |
+
# Initialize the embeddings per the GPT-2 scheme
|
341 |
+
embeddings = [self.text_embedding]
|
342 |
+
if use_mel_codes_as_input:
|
343 |
+
embeddings.append(self.mel_embedding)
|
344 |
+
for module in embeddings:
|
345 |
+
module.weight.data.normal_(mean=0.0, std=.02)
|
346 |
+
|
347 |
+
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
348 |
+
inp = F.pad(input, (1,0), value=start_token)
|
349 |
+
tar = F.pad(input, (0,1), value=stop_token)
|
350 |
+
return inp, tar
|
351 |
+
|
352 |
+
def set_mel_padding(self, mel_input_tokens, wav_lengths):
|
353 |
+
"""
|
354 |
+
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
|
355 |
+
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
|
356 |
+
preformatting to create a working TTS model.
|
357 |
+
"""
|
358 |
+
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
359 |
+
mel_lengths = wav_lengths // self.mel_length_compression
|
360 |
+
for b in range(len(mel_lengths)):
|
361 |
+
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
362 |
+
if actual_end < mel_input_tokens.shape[-1]:
|
363 |
+
mel_input_tokens[b, actual_end:] = self.stop_mel_token
|
364 |
+
return mel_input_tokens
|
365 |
+
|
366 |
+
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
|
367 |
+
if second_inputs is not None:
|
368 |
+
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
369 |
+
else:
|
370 |
+
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
371 |
+
|
372 |
+
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
373 |
+
if get_attns:
|
374 |
+
return gpt_out.attentions
|
375 |
+
|
376 |
+
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
|
377 |
+
enc = self.final_norm(enc)
|
378 |
+
|
379 |
+
if return_latent:
|
380 |
+
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
|
381 |
+
|
382 |
+
first_logits = enc[:, :first_inputs.shape[1]]
|
383 |
+
first_logits = first_head(first_logits)
|
384 |
+
first_logits = first_logits.permute(0,2,1)
|
385 |
+
if second_inputs is not None:
|
386 |
+
second_logits = enc[:, -second_inputs.shape[1]:]
|
387 |
+
second_logits = second_head(second_logits)
|
388 |
+
second_logits = second_logits.permute(0,2,1)
|
389 |
+
return first_logits, second_logits
|
390 |
+
else:
|
391 |
+
return first_logits
|
392 |
+
|
393 |
+
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
|
394 |
+
return_latent=False, clip_inputs=True):
|
395 |
+
"""
|
396 |
+
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
397 |
+
(actuated by `text_first`).
|
398 |
+
|
399 |
+
speech_conditioning_input: MEL float tensor, (b,80,s)
|
400 |
+
text_inputs: long tensor, (b,t)
|
401 |
+
text_lengths: long tensor, (b,)
|
402 |
+
mel_inputs: long tensor, (b,m)
|
403 |
+
wav_lengths: long tensor, (b,)
|
404 |
+
raw_mels: MEL float tensor (b,80,s)
|
405 |
+
|
406 |
+
If return_attentions is specified, only logits are returned.
|
407 |
+
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
|
408 |
+
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
|
409 |
+
"""
|
410 |
+
# Types are expressed by expanding the text embedding space.
|
411 |
+
if types is not None:
|
412 |
+
text_inputs = text_inputs * (1+types).unsqueeze(-1)
|
413 |
+
|
414 |
+
if clip_inputs:
|
415 |
+
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
416 |
+
# chopping the inputs by the maximum actual length.
|
417 |
+
max_text_len = text_lengths.max()
|
418 |
+
text_inputs = text_inputs[:, :max_text_len]
|
419 |
+
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
420 |
+
mel_codes = mel_codes[:, :max_mel_len]
|
421 |
+
if raw_mels is not None:
|
422 |
+
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
423 |
+
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
424 |
+
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
|
425 |
+
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
|
426 |
+
|
427 |
+
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
428 |
+
conds = []
|
429 |
+
for j in range(speech_conditioning_input.shape[1]):
|
430 |
+
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
431 |
+
conds = torch.stack(conds, dim=1)
|
432 |
+
if self.average_conditioning_embeddings:
|
433 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
434 |
+
|
435 |
+
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
436 |
+
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
437 |
+
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
438 |
+
if raw_mels is not None:
|
439 |
+
mel_inp = F.pad(raw_mels, (0, 8))
|
440 |
+
else:
|
441 |
+
mel_inp = mel_codes
|
442 |
+
mel_emb = self.mel_embedding(mel_inp)
|
443 |
+
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
444 |
+
|
445 |
+
if text_first:
|
446 |
+
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
|
447 |
+
if return_latent:
|
448 |
+
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
449 |
+
else:
|
450 |
+
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
|
451 |
+
if return_latent:
|
452 |
+
return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
|
453 |
+
|
454 |
+
if return_attentions:
|
455 |
+
return mel_logits
|
456 |
+
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
457 |
+
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
458 |
+
return loss_text.mean(), loss_mel.mean(), mel_logits
|
459 |
+
|
460 |
+
def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
|
461 |
+
"""
|
462 |
+
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
|
463 |
+
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
|
464 |
+
"""
|
465 |
+
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
466 |
+
|
467 |
+
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
468 |
+
# chopping the inputs by the maximum actual length.
|
469 |
+
max_text_len = text_lengths.max()
|
470 |
+
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
|
471 |
+
|
472 |
+
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
473 |
+
conds = []
|
474 |
+
for j in range(speech_conditioning_input.shape[1]):
|
475 |
+
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
476 |
+
conds = torch.stack(conds, dim=1)
|
477 |
+
if self.average_conditioning_embeddings:
|
478 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
479 |
+
|
480 |
+
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
481 |
+
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
482 |
+
text_logits = self.get_logits(conds, text_emb, self.text_head)
|
483 |
+
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
484 |
+
return loss_text.mean()
|
485 |
+
|
486 |
+
def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
|
487 |
+
"""
|
488 |
+
Performs autoregressive modeling on only speech data.
|
489 |
+
"""
|
490 |
+
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
|
491 |
+
|
492 |
+
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
493 |
+
# chopping the inputs by the maximum actual length.
|
494 |
+
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
495 |
+
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
|
496 |
+
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
497 |
+
if raw_mels is not None:
|
498 |
+
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
499 |
+
|
500 |
+
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
501 |
+
conds = []
|
502 |
+
for j in range(speech_conditioning_input.shape[1]):
|
503 |
+
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
504 |
+
conds = torch.stack(conds, dim=1)
|
505 |
+
if self.average_conditioning_embeddings:
|
506 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
507 |
+
|
508 |
+
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
509 |
+
if raw_mels is not None:
|
510 |
+
mel_inp = F.pad(raw_mels, (0, 4))
|
511 |
+
else:
|
512 |
+
mel_inp = mel_codes
|
513 |
+
mel_emb = self.mel_embedding(mel_inp)
|
514 |
+
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
|
515 |
+
mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
|
516 |
+
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
517 |
+
return loss_mel.mean()
|
518 |
+
|
519 |
+
def inference_speech(self, speech_conditioning_input, text_inputs, input_tokens=None, num_return_sequences=1,
|
520 |
+
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
|
521 |
+
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
522 |
+
if not hasattr(self, 'inference_model'):
|
523 |
+
# TODO: Decouple gpt_config from this inference model.
|
524 |
+
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
525 |
+
n_positions=seq_length,
|
526 |
+
n_ctx=seq_length,
|
527 |
+
n_embd=self.model_dim,
|
528 |
+
n_layer=self.layers,
|
529 |
+
n_head=self.heads,
|
530 |
+
gradient_checkpointing=False,
|
531 |
+
use_cache=True)
|
532 |
+
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
533 |
+
self.gpt.wte = self.mel_embedding
|
534 |
+
|
535 |
+
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
536 |
+
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
537 |
+
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
538 |
+
|
539 |
+
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
540 |
+
conds = []
|
541 |
+
for j in range(speech_conditioning_input.shape[1]):
|
542 |
+
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
543 |
+
conds = torch.stack(conds, dim=1)
|
544 |
+
if self.average_conditioning_embeddings:
|
545 |
+
conds = conds.mean(dim=1).unsqueeze(1)
|
546 |
+
|
547 |
+
emb = torch.cat([conds, text_emb], dim=1)
|
548 |
+
self.inference_model.store_mel_emb(emb)
|
549 |
+
|
550 |
+
fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long,
|
551 |
+
device=text_inputs.device)
|
552 |
+
fake_inputs[:, -1] = self.start_mel_token
|
553 |
+
trunc_index = fake_inputs.shape[1]
|
554 |
+
if input_tokens is None:
|
555 |
+
inputs = fake_inputs
|
556 |
+
else:
|
557 |
+
assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences"
|
558 |
+
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
|
559 |
+
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
|
560 |
+
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
561 |
+
|
562 |
+
logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
|
563 |
+
max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
|
564 |
+
gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
565 |
+
max_length=max_length, logits_processor=logits_processor,
|
566 |
+
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
|
567 |
+
return gen[:, trunc_index:]
|
568 |
+
|
569 |
+
|
570 |
+
if __name__ == '__main__':
|
571 |
+
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
|
572 |
+
l = gpt(torch.randn(2, 3, 80, 800),
|
573 |
+
torch.randint(high=120, size=(2,120)),
|
574 |
+
torch.tensor([32, 120]),
|
575 |
+
torch.randint(high=8192, size=(2,250)),
|
576 |
+
torch.tensor([250*256,195*256]))
|
577 |
+
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
|
models/clvp.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import einsum
|
5 |
+
from x_transformers import Encoder
|
6 |
+
|
7 |
+
from models.arch_util import CheckpointedXTransformerEncoder
|
8 |
+
from models.transformer import Transformer
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def masked_mean(t, mask, dim = 1):
|
16 |
+
t = t.masked_fill(~mask[:, :, None], 0.)
|
17 |
+
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
18 |
+
|
19 |
+
class CLVP(nn.Module):
|
20 |
+
"""
|
21 |
+
CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
|
22 |
+
transcribed text.
|
23 |
+
|
24 |
+
Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
*,
|
30 |
+
dim_text=512,
|
31 |
+
dim_speech=512,
|
32 |
+
dim_latent=512,
|
33 |
+
num_text_tokens=256,
|
34 |
+
text_enc_depth=6,
|
35 |
+
text_seq_len=120,
|
36 |
+
text_heads=8,
|
37 |
+
num_speech_tokens=8192,
|
38 |
+
speech_enc_depth=6,
|
39 |
+
speech_heads=8,
|
40 |
+
speech_seq_len=250,
|
41 |
+
text_mask_percentage=0,
|
42 |
+
voice_mask_percentage=0,
|
43 |
+
wav_token_compression=1024,
|
44 |
+
use_xformers=False,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
48 |
+
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
|
49 |
+
|
50 |
+
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
51 |
+
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
|
52 |
+
|
53 |
+
if use_xformers:
|
54 |
+
self.text_transformer = CheckpointedXTransformerEncoder(
|
55 |
+
needs_permute=False,
|
56 |
+
exit_permute=False,
|
57 |
+
max_seq_len=-1,
|
58 |
+
attn_layers=Encoder(
|
59 |
+
dim=dim_text,
|
60 |
+
depth=text_enc_depth,
|
61 |
+
heads=text_heads,
|
62 |
+
ff_dropout=.1,
|
63 |
+
ff_mult=2,
|
64 |
+
attn_dropout=.1,
|
65 |
+
use_rmsnorm=True,
|
66 |
+
ff_glu=True,
|
67 |
+
rotary_pos_emb=True,
|
68 |
+
))
|
69 |
+
self.speech_transformer = CheckpointedXTransformerEncoder(
|
70 |
+
needs_permute=False,
|
71 |
+
exit_permute=False,
|
72 |
+
max_seq_len=-1,
|
73 |
+
attn_layers=Encoder(
|
74 |
+
dim=dim_speech,
|
75 |
+
depth=speech_enc_depth,
|
76 |
+
heads=speech_heads,
|
77 |
+
ff_dropout=.1,
|
78 |
+
ff_mult=2,
|
79 |
+
attn_dropout=.1,
|
80 |
+
use_rmsnorm=True,
|
81 |
+
ff_glu=True,
|
82 |
+
rotary_pos_emb=True,
|
83 |
+
))
|
84 |
+
else:
|
85 |
+
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
|
86 |
+
heads=text_heads)
|
87 |
+
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
88 |
+
depth=speech_enc_depth, heads=speech_heads)
|
89 |
+
|
90 |
+
self.temperature = nn.Parameter(torch.tensor(1.))
|
91 |
+
self.text_mask_percentage = text_mask_percentage
|
92 |
+
self.voice_mask_percentage = voice_mask_percentage
|
93 |
+
self.wav_token_compression = wav_token_compression
|
94 |
+
self.xformers = use_xformers
|
95 |
+
if not use_xformers:
|
96 |
+
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
97 |
+
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
98 |
+
|
99 |
+
def forward(
|
100 |
+
self,
|
101 |
+
text,
|
102 |
+
speech_tokens,
|
103 |
+
return_loss=False
|
104 |
+
):
|
105 |
+
b, device = text.shape[0], text.device
|
106 |
+
if self.training:
|
107 |
+
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
|
108 |
+
voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
|
109 |
+
else:
|
110 |
+
text_mask = torch.ones_like(text.float()).bool()
|
111 |
+
voice_mask = torch.ones_like(speech_tokens.float()).bool()
|
112 |
+
|
113 |
+
text_emb = self.text_emb(text)
|
114 |
+
speech_emb = self.speech_emb(speech_tokens)
|
115 |
+
|
116 |
+
if not self.xformers:
|
117 |
+
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
|
118 |
+
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
|
119 |
+
|
120 |
+
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
121 |
+
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|
122 |
+
|
123 |
+
text_latents = masked_mean(enc_text, text_mask, dim=1)
|
124 |
+
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
|
125 |
+
|
126 |
+
text_latents = self.to_text_latent(text_latents)
|
127 |
+
speech_latents = self.to_speech_latent(speech_latents)
|
128 |
+
|
129 |
+
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
130 |
+
|
131 |
+
temp = self.temperature.exp()
|
132 |
+
|
133 |
+
if not return_loss:
|
134 |
+
sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
|
135 |
+
return sim
|
136 |
+
|
137 |
+
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
|
138 |
+
labels = torch.arange(b, device=device)
|
139 |
+
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
140 |
+
return loss
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == '__main__':
|
144 |
+
clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
|
145 |
+
clip(torch.randint(0,256,(2,120)),
|
146 |
+
torch.tensor([50,100]),
|
147 |
+
torch.randint(0,8192,(2,250)),
|
148 |
+
torch.tensor([101,102]),
|
149 |
+
return_loss=True)
|
150 |
+
nonloss = clip(torch.randint(0,256,(2,120)),
|
151 |
+
torch.tensor([50,100]),
|
152 |
+
torch.randint(0,8192,(2,250)),
|
153 |
+
torch.tensor([101,102]),
|
154 |
+
return_loss=False)
|
155 |
+
print(nonloss.shape)
|
models/cvvp.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import einsum
|
5 |
+
from torch.utils.checkpoint import checkpoint
|
6 |
+
|
7 |
+
from models.arch_util import AttentionBlock
|
8 |
+
from models.xtransformers import ContinuousTransformerWrapper, Encoder
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def masked_mean(t, mask):
|
16 |
+
t = t.masked_fill(~mask, 0.)
|
17 |
+
return t.sum(dim = 1) / mask.sum(dim = 1)
|
18 |
+
|
19 |
+
|
20 |
+
class CollapsingTransformer(nn.Module):
|
21 |
+
def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
|
22 |
+
super().__init__()
|
23 |
+
self.transformer = ContinuousTransformerWrapper(
|
24 |
+
max_seq_len=-1,
|
25 |
+
use_pos_emb=False,
|
26 |
+
attn_layers=Encoder(
|
27 |
+
dim=model_dim,
|
28 |
+
depth=depth,
|
29 |
+
heads=heads,
|
30 |
+
ff_dropout=dropout,
|
31 |
+
ff_mult=1,
|
32 |
+
attn_dropout=dropout,
|
33 |
+
use_rmsnorm=True,
|
34 |
+
ff_glu=True,
|
35 |
+
rotary_pos_emb=True,
|
36 |
+
**encoder_kwargs,
|
37 |
+
))
|
38 |
+
self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1),
|
39 |
+
AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
|
40 |
+
nn.Conv1d(output_dims, output_dims, 1))
|
41 |
+
self.mask_percentage = mask_percentage
|
42 |
+
|
43 |
+
def forward(self, x, **transformer_kwargs):
|
44 |
+
h = self.transformer(x, **transformer_kwargs)
|
45 |
+
h = h.permute(0,2,1)
|
46 |
+
h = checkpoint(self.pre_combiner, h).permute(0,2,1)
|
47 |
+
if self.training:
|
48 |
+
mask = torch.rand_like(h.float()) > self.mask_percentage
|
49 |
+
else:
|
50 |
+
mask = torch.ones_like(h.float()).bool()
|
51 |
+
return masked_mean(h, mask)
|
52 |
+
|
53 |
+
|
54 |
+
class ConvFormatEmbedding(nn.Module):
|
55 |
+
def __init__(self, *args, **kwargs):
|
56 |
+
super().__init__()
|
57 |
+
self.emb = nn.Embedding(*args, **kwargs)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
y = self.emb(x)
|
61 |
+
return y.permute(0,2,1)
|
62 |
+
|
63 |
+
|
64 |
+
class CVVP(nn.Module):
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
model_dim=512,
|
68 |
+
transformer_heads=8,
|
69 |
+
dropout=.1,
|
70 |
+
conditioning_enc_depth=8,
|
71 |
+
cond_mask_percentage=0,
|
72 |
+
mel_channels=80,
|
73 |
+
mel_codes=None,
|
74 |
+
speech_enc_depth=8,
|
75 |
+
speech_mask_percentage=0,
|
76 |
+
latent_multiplier=1,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
latent_dim = latent_multiplier*model_dim
|
80 |
+
self.temperature = nn.Parameter(torch.tensor(1.))
|
81 |
+
|
82 |
+
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
|
83 |
+
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
|
84 |
+
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
|
85 |
+
self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
86 |
+
|
87 |
+
if mel_codes is None:
|
88 |
+
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
|
89 |
+
else:
|
90 |
+
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
|
91 |
+
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
|
92 |
+
self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
93 |
+
|
94 |
+
def get_grad_norm_parameter_groups(self):
|
95 |
+
return {
|
96 |
+
'conditioning': list(self.conditioning_transformer.parameters()),
|
97 |
+
'speech': list(self.speech_transformer.parameters()),
|
98 |
+
}
|
99 |
+
|
100 |
+
def forward(
|
101 |
+
self,
|
102 |
+
mel_cond,
|
103 |
+
mel_input,
|
104 |
+
return_loss=False
|
105 |
+
):
|
106 |
+
cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
|
107 |
+
enc_cond = self.conditioning_transformer(cond_emb)
|
108 |
+
cond_latents = self.to_conditioning_latent(enc_cond)
|
109 |
+
|
110 |
+
speech_emb = self.speech_emb(mel_input).permute(0,2,1)
|
111 |
+
enc_speech = self.speech_transformer(speech_emb)
|
112 |
+
speech_latents = self.to_speech_latent(enc_speech)
|
113 |
+
|
114 |
+
|
115 |
+
cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
|
116 |
+
temp = self.temperature.exp()
|
117 |
+
|
118 |
+
if not return_loss:
|
119 |
+
sim = einsum('n d, n d -> n', cond_latents, speech_latents) * temp
|
120 |
+
return sim
|
121 |
+
|
122 |
+
sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp
|
123 |
+
labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
|
124 |
+
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
125 |
+
|
126 |
+
return loss
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == '__main__':
|
130 |
+
clvp = CVVP()
|
131 |
+
clvp(torch.randn(2,80,100),
|
132 |
+
torch.randn(2,80,95),
|
133 |
+
return_loss=True)
|
models/diffusion_decoder.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
from abc import abstractmethod
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import autocast
|
9 |
+
|
10 |
+
from models.arch_util import normalization, AttentionBlock
|
11 |
+
|
12 |
+
|
13 |
+
def is_latent(t):
|
14 |
+
return t.dtype == torch.float
|
15 |
+
|
16 |
+
|
17 |
+
def is_sequence(t):
|
18 |
+
return t.dtype == torch.long
|
19 |
+
|
20 |
+
|
21 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
22 |
+
"""
|
23 |
+
Create sinusoidal timestep embeddings.
|
24 |
+
|
25 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
26 |
+
These may be fractional.
|
27 |
+
:param dim: the dimension of the output.
|
28 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
29 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
30 |
+
"""
|
31 |
+
half = dim // 2
|
32 |
+
freqs = torch.exp(
|
33 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
34 |
+
).to(device=timesteps.device)
|
35 |
+
args = timesteps[:, None].float() * freqs[None]
|
36 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
37 |
+
if dim % 2:
|
38 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
39 |
+
return embedding
|
40 |
+
|
41 |
+
|
42 |
+
class TimestepBlock(nn.Module):
|
43 |
+
@abstractmethod
|
44 |
+
def forward(self, x, emb):
|
45 |
+
"""
|
46 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
47 |
+
"""
|
48 |
+
|
49 |
+
|
50 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
51 |
+
def forward(self, x, emb):
|
52 |
+
for layer in self:
|
53 |
+
if isinstance(layer, TimestepBlock):
|
54 |
+
x = layer(x, emb)
|
55 |
+
else:
|
56 |
+
x = layer(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class ResBlock(TimestepBlock):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
channels,
|
64 |
+
emb_channels,
|
65 |
+
dropout,
|
66 |
+
out_channels=None,
|
67 |
+
dims=2,
|
68 |
+
kernel_size=3,
|
69 |
+
efficient_config=True,
|
70 |
+
use_scale_shift_norm=False,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
self.channels = channels
|
74 |
+
self.emb_channels = emb_channels
|
75 |
+
self.dropout = dropout
|
76 |
+
self.out_channels = out_channels or channels
|
77 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
78 |
+
padding = {1: 0, 3: 1, 5: 2}[kernel_size]
|
79 |
+
eff_kernel = 1 if efficient_config else 3
|
80 |
+
eff_padding = 0 if efficient_config else 1
|
81 |
+
|
82 |
+
self.in_layers = nn.Sequential(
|
83 |
+
normalization(channels),
|
84 |
+
nn.SiLU(),
|
85 |
+
nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
|
86 |
+
)
|
87 |
+
|
88 |
+
self.emb_layers = nn.Sequential(
|
89 |
+
nn.SiLU(),
|
90 |
+
nn.Linear(
|
91 |
+
emb_channels,
|
92 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
93 |
+
),
|
94 |
+
)
|
95 |
+
self.out_layers = nn.Sequential(
|
96 |
+
normalization(self.out_channels),
|
97 |
+
nn.SiLU(),
|
98 |
+
nn.Dropout(p=dropout),
|
99 |
+
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
|
100 |
+
)
|
101 |
+
|
102 |
+
if self.out_channels == channels:
|
103 |
+
self.skip_connection = nn.Identity()
|
104 |
+
else:
|
105 |
+
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
|
106 |
+
|
107 |
+
def forward(self, x, emb):
|
108 |
+
h = self.in_layers(x)
|
109 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
110 |
+
while len(emb_out.shape) < len(h.shape):
|
111 |
+
emb_out = emb_out[..., None]
|
112 |
+
if self.use_scale_shift_norm:
|
113 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
114 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
115 |
+
h = out_norm(h) * (1 + scale) + shift
|
116 |
+
h = out_rest(h)
|
117 |
+
else:
|
118 |
+
h = h + emb_out
|
119 |
+
h = self.out_layers(h)
|
120 |
+
return self.skip_connection(x) + h
|
121 |
+
|
122 |
+
|
123 |
+
class DiffusionLayer(TimestepBlock):
|
124 |
+
def __init__(self, model_channels, dropout, num_heads):
|
125 |
+
super().__init__()
|
126 |
+
self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
|
127 |
+
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
128 |
+
|
129 |
+
def forward(self, x, time_emb):
|
130 |
+
y = self.resblk(x, time_emb)
|
131 |
+
return self.attn(y)
|
132 |
+
|
133 |
+
|
134 |
+
class DiffusionTts(nn.Module):
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
model_channels=512,
|
138 |
+
num_layers=8,
|
139 |
+
in_channels=100,
|
140 |
+
in_latent_channels=512,
|
141 |
+
in_tokens=8193,
|
142 |
+
out_channels=200, # mean and variance
|
143 |
+
dropout=0,
|
144 |
+
use_fp16=False,
|
145 |
+
num_heads=16,
|
146 |
+
# Parameters for regularization.
|
147 |
+
layer_drop=.1,
|
148 |
+
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
|
152 |
+
self.in_channels = in_channels
|
153 |
+
self.model_channels = model_channels
|
154 |
+
self.out_channels = out_channels
|
155 |
+
self.dropout = dropout
|
156 |
+
self.num_heads = num_heads
|
157 |
+
self.unconditioned_percentage = unconditioned_percentage
|
158 |
+
self.enable_fp16 = use_fp16
|
159 |
+
self.layer_drop = layer_drop
|
160 |
+
|
161 |
+
self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
|
162 |
+
self.time_embed = nn.Sequential(
|
163 |
+
nn.Linear(model_channels, model_channels),
|
164 |
+
nn.SiLU(),
|
165 |
+
nn.Linear(model_channels, model_channels),
|
166 |
+
)
|
167 |
+
|
168 |
+
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
169 |
+
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
170 |
+
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
|
171 |
+
# transformer network.
|
172 |
+
self.code_embedding = nn.Embedding(in_tokens, model_channels)
|
173 |
+
self.code_converter = nn.Sequential(
|
174 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
175 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
176 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
177 |
+
)
|
178 |
+
self.code_norm = normalization(model_channels)
|
179 |
+
self.latent_conditioner = nn.Sequential(
|
180 |
+
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
181 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
182 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
183 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
184 |
+
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
185 |
+
)
|
186 |
+
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
187 |
+
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
188 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
189 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
190 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
191 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
192 |
+
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
|
193 |
+
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
|
194 |
+
self.conditioning_timestep_integrator = TimestepEmbedSequential(
|
195 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
196 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
197 |
+
DiffusionLayer(model_channels, dropout, num_heads),
|
198 |
+
)
|
199 |
+
|
200 |
+
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
|
201 |
+
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
202 |
+
|
203 |
+
self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
|
204 |
+
[ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
|
205 |
+
|
206 |
+
self.out = nn.Sequential(
|
207 |
+
normalization(model_channels),
|
208 |
+
nn.SiLU(),
|
209 |
+
nn.Conv1d(model_channels, out_channels, 3, padding=1),
|
210 |
+
)
|
211 |
+
|
212 |
+
def get_grad_norm_parameter_groups(self):
|
213 |
+
groups = {
|
214 |
+
'minicoder': list(self.contextual_embedder.parameters()),
|
215 |
+
'layers': list(self.layers.parameters()),
|
216 |
+
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
|
217 |
+
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
|
218 |
+
'time_embed': list(self.time_embed.parameters()),
|
219 |
+
}
|
220 |
+
return groups
|
221 |
+
|
222 |
+
def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
|
223 |
+
# Shuffle aligned_latent to BxCxS format
|
224 |
+
if is_latent(aligned_conditioning):
|
225 |
+
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
226 |
+
|
227 |
+
# Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
|
228 |
+
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
|
229 |
+
conditioning_input.shape) == 3 else conditioning_input
|
230 |
+
conds = []
|
231 |
+
for j in range(speech_conditioning_input.shape[1]):
|
232 |
+
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
233 |
+
conds = torch.cat(conds, dim=-1)
|
234 |
+
cond_emb = conds.mean(dim=-1)
|
235 |
+
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
236 |
+
if is_latent(aligned_conditioning):
|
237 |
+
code_emb = self.latent_conditioner(aligned_conditioning)
|
238 |
+
else:
|
239 |
+
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
240 |
+
code_emb = self.code_converter(code_emb)
|
241 |
+
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
242 |
+
|
243 |
+
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
244 |
+
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
245 |
+
if self.training and self.unconditioned_percentage > 0:
|
246 |
+
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
247 |
+
device=code_emb.device) < self.unconditioned_percentage
|
248 |
+
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
249 |
+
code_emb)
|
250 |
+
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
251 |
+
|
252 |
+
if not return_code_pred:
|
253 |
+
return expanded_code_emb
|
254 |
+
else:
|
255 |
+
mel_pred = self.mel_head(expanded_code_emb)
|
256 |
+
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
|
257 |
+
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
258 |
+
return expanded_code_emb, mel_pred
|
259 |
+
|
260 |
+
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
261 |
+
"""
|
262 |
+
Apply the model to an input batch.
|
263 |
+
|
264 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
265 |
+
:param timesteps: a 1-D batch of timesteps.
|
266 |
+
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
267 |
+
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
268 |
+
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
|
269 |
+
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
270 |
+
:return: an [N x C x ...] Tensor of outputs.
|
271 |
+
"""
|
272 |
+
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
|
273 |
+
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
274 |
+
|
275 |
+
unused_params = []
|
276 |
+
if conditioning_free:
|
277 |
+
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
278 |
+
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
279 |
+
unused_params.extend(list(self.latent_conditioner.parameters()))
|
280 |
+
else:
|
281 |
+
if precomputed_aligned_embeddings is not None:
|
282 |
+
code_emb = precomputed_aligned_embeddings
|
283 |
+
else:
|
284 |
+
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
|
285 |
+
if is_latent(aligned_conditioning):
|
286 |
+
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
287 |
+
else:
|
288 |
+
unused_params.extend(list(self.latent_conditioner.parameters()))
|
289 |
+
|
290 |
+
unused_params.append(self.unconditioned_embedding)
|
291 |
+
|
292 |
+
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
293 |
+
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
294 |
+
x = self.inp_block(x)
|
295 |
+
x = torch.cat([x, code_emb], dim=1)
|
296 |
+
x = self.integrating_conv(x)
|
297 |
+
for i, lyr in enumerate(self.layers):
|
298 |
+
# Do layer drop where applicable. Do not drop first and last layers.
|
299 |
+
if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
|
300 |
+
unused_params.extend(list(lyr.parameters()))
|
301 |
+
else:
|
302 |
+
# First and last blocks will have autocast disabled for improved precision.
|
303 |
+
with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
|
304 |
+
x = lyr(x, time_emb)
|
305 |
+
|
306 |
+
x = x.float()
|
307 |
+
out = self.out(x)
|
308 |
+
|
309 |
+
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
310 |
+
extraneous_addition = 0
|
311 |
+
for p in unused_params:
|
312 |
+
extraneous_addition = extraneous_addition + p.mean()
|
313 |
+
out = out + extraneous_addition * 0
|
314 |
+
|
315 |
+
if return_code_pred:
|
316 |
+
return out, mel_pred
|
317 |
+
return out
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == '__main__':
|
321 |
+
clip = torch.randn(2, 100, 400)
|
322 |
+
aligned_latent = torch.randn(2,388,512)
|
323 |
+
aligned_sequence = torch.randint(0,8192,(2,100))
|
324 |
+
cond = torch.randn(2, 100, 400)
|
325 |
+
ts = torch.LongTensor([600, 600])
|
326 |
+
model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
|
327 |
+
# Test with latent aligned conditioning
|
328 |
+
#o = model(clip, ts, aligned_latent, cond)
|
329 |
+
# Test with sequence aligned conditioning
|
330 |
+
o = model(clip, ts, aligned_sequence, cond)
|
331 |
+
|
models/transformer.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from rotary_embedding_torch import RotaryEmbedding, broadcat
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
# helpers
|
11 |
+
|
12 |
+
|
13 |
+
def exists(val):
|
14 |
+
return val is not None
|
15 |
+
|
16 |
+
|
17 |
+
def default(val, d):
|
18 |
+
return val if exists(val) else d
|
19 |
+
|
20 |
+
|
21 |
+
def cast_tuple(val, depth = 1):
|
22 |
+
if isinstance(val, list):
|
23 |
+
val = tuple(val)
|
24 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
25 |
+
|
26 |
+
|
27 |
+
def max_neg_value(t):
|
28 |
+
return -torch.finfo(t.dtype).max
|
29 |
+
|
30 |
+
|
31 |
+
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
|
32 |
+
t = t / alpha
|
33 |
+
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
|
34 |
+
return (t * alpha).softmax(dim = dim)
|
35 |
+
|
36 |
+
|
37 |
+
def route_args(router, args, depth):
|
38 |
+
routed_args = [(dict(), dict()) for _ in range(depth)]
|
39 |
+
matched_keys = [key for key in args.keys() if key in router]
|
40 |
+
|
41 |
+
for key in matched_keys:
|
42 |
+
val = args[key]
|
43 |
+
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
|
44 |
+
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
|
45 |
+
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
|
46 |
+
return routed_args
|
47 |
+
|
48 |
+
|
49 |
+
# classes
|
50 |
+
class SequentialSequence(nn.Module):
|
51 |
+
def __init__(self, layers, args_route = {}, layer_dropout = 0.):
|
52 |
+
super().__init__()
|
53 |
+
assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
|
54 |
+
self.layers = layers
|
55 |
+
self.args_route = args_route
|
56 |
+
self.layer_dropout = layer_dropout
|
57 |
+
|
58 |
+
def forward(self, x, **kwargs):
|
59 |
+
args = route_args(self.args_route, kwargs, len(self.layers))
|
60 |
+
layers_and_args = list(zip(self.layers, args))
|
61 |
+
|
62 |
+
for (f, g), (f_args, g_args) in layers_and_args:
|
63 |
+
x = x + f(x, **f_args)
|
64 |
+
x = x + g(x, **g_args)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class DivideMax(nn.Module):
|
69 |
+
def __init__(self, dim):
|
70 |
+
super().__init__()
|
71 |
+
self.dim = dim
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
maxes = x.amax(dim = self.dim, keepdim = True).detach()
|
75 |
+
return x / maxes
|
76 |
+
|
77 |
+
|
78 |
+
# https://arxiv.org/abs/2103.17239
|
79 |
+
class LayerScale(nn.Module):
|
80 |
+
def __init__(self, dim, depth, fn):
|
81 |
+
super().__init__()
|
82 |
+
if depth <= 18:
|
83 |
+
init_eps = 0.1
|
84 |
+
elif depth > 18 and depth <= 24:
|
85 |
+
init_eps = 1e-5
|
86 |
+
else:
|
87 |
+
init_eps = 1e-6
|
88 |
+
|
89 |
+
scale = torch.zeros(1, 1, dim).fill_(init_eps)
|
90 |
+
self.scale = nn.Parameter(scale)
|
91 |
+
self.fn = fn
|
92 |
+
def forward(self, x, **kwargs):
|
93 |
+
return self.fn(x, **kwargs) * self.scale
|
94 |
+
|
95 |
+
# layer norm
|
96 |
+
|
97 |
+
|
98 |
+
class PreNorm(nn.Module):
|
99 |
+
def __init__(self, dim, fn, sandwich = False):
|
100 |
+
super().__init__()
|
101 |
+
self.norm = nn.LayerNorm(dim)
|
102 |
+
self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
|
103 |
+
self.fn = fn
|
104 |
+
|
105 |
+
def forward(self, x, **kwargs):
|
106 |
+
x = self.norm(x)
|
107 |
+
x = self.fn(x, **kwargs)
|
108 |
+
return self.norm_out(x)
|
109 |
+
|
110 |
+
# feed forward
|
111 |
+
|
112 |
+
|
113 |
+
class GEGLU(nn.Module):
|
114 |
+
def forward(self, x):
|
115 |
+
x, gates = x.chunk(2, dim = -1)
|
116 |
+
return x * F.gelu(gates)
|
117 |
+
|
118 |
+
|
119 |
+
class FeedForward(nn.Module):
|
120 |
+
def __init__(self, dim, dropout = 0., mult = 4.):
|
121 |
+
super().__init__()
|
122 |
+
self.net = nn.Sequential(
|
123 |
+
nn.Linear(dim, dim * mult * 2),
|
124 |
+
GEGLU(),
|
125 |
+
nn.Dropout(dropout),
|
126 |
+
nn.Linear(dim * mult, dim)
|
127 |
+
)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
return self.net(x)
|
131 |
+
|
132 |
+
# Attention
|
133 |
+
|
134 |
+
|
135 |
+
class Attention(nn.Module):
|
136 |
+
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
|
137 |
+
super().__init__()
|
138 |
+
inner_dim = dim_head * heads
|
139 |
+
self.heads = heads
|
140 |
+
self.seq_len = seq_len
|
141 |
+
self.scale = dim_head ** -0.5
|
142 |
+
|
143 |
+
self.causal = causal
|
144 |
+
|
145 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
146 |
+
self.to_out = nn.Sequential(
|
147 |
+
nn.Linear(inner_dim, dim),
|
148 |
+
nn.Dropout(dropout)
|
149 |
+
)
|
150 |
+
|
151 |
+
def forward(self, x, mask = None):
|
152 |
+
b, n, _, h, device = *x.shape, self.heads, x.device
|
153 |
+
softmax = torch.softmax
|
154 |
+
|
155 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
156 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
157 |
+
|
158 |
+
q = q * self.scale
|
159 |
+
|
160 |
+
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
161 |
+
mask_value = max_neg_value(dots)
|
162 |
+
|
163 |
+
if exists(mask):
|
164 |
+
mask = rearrange(mask, 'b j -> b () () j')
|
165 |
+
dots.masked_fill_(~mask, mask_value)
|
166 |
+
del mask
|
167 |
+
|
168 |
+
if self.causal:
|
169 |
+
i, j = dots.shape[-2:]
|
170 |
+
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
|
171 |
+
dots.masked_fill_(mask, mask_value)
|
172 |
+
|
173 |
+
attn = softmax(dots, dim=-1)
|
174 |
+
|
175 |
+
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
|
176 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
177 |
+
out = self.to_out(out)
|
178 |
+
return out
|
179 |
+
|
180 |
+
|
181 |
+
# main transformer class
|
182 |
+
class Transformer(nn.Module):
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
*,
|
186 |
+
dim,
|
187 |
+
depth,
|
188 |
+
seq_len,
|
189 |
+
causal = True,
|
190 |
+
heads = 8,
|
191 |
+
dim_head = 64,
|
192 |
+
ff_mult = 4,
|
193 |
+
attn_dropout = 0.,
|
194 |
+
ff_dropout = 0.,
|
195 |
+
sparse_attn = False,
|
196 |
+
sandwich_norm = False,
|
197 |
+
):
|
198 |
+
super().__init__()
|
199 |
+
layers = nn.ModuleList([])
|
200 |
+
sparse_layer = cast_tuple(sparse_attn, depth)
|
201 |
+
|
202 |
+
for ind, sparse_attn in zip(range(depth), sparse_layer):
|
203 |
+
attn = Attention(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
|
204 |
+
|
205 |
+
ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
|
206 |
+
|
207 |
+
layers.append(nn.ModuleList([
|
208 |
+
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
|
209 |
+
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
|
210 |
+
]))
|
211 |
+
|
212 |
+
execute_type = SequentialSequence
|
213 |
+
route_attn = ((True, False),) * depth
|
214 |
+
attn_route_map = {'mask': route_attn}
|
215 |
+
|
216 |
+
self.layers = execute_type(layers, args_route = attn_route_map)
|
217 |
+
|
218 |
+
def forward(self, x, **kwargs):
|
219 |
+
return self.layers(x, **kwargs)
|
models/vocoder.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
MAX_WAV_VALUE = 32768.0
|
6 |
+
|
7 |
+
class KernelPredictor(torch.nn.Module):
|
8 |
+
''' Kernel predictor for the location-variable convolutions'''
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
cond_channels,
|
13 |
+
conv_in_channels,
|
14 |
+
conv_out_channels,
|
15 |
+
conv_layers,
|
16 |
+
conv_kernel_size=3,
|
17 |
+
kpnet_hidden_channels=64,
|
18 |
+
kpnet_conv_size=3,
|
19 |
+
kpnet_dropout=0.0,
|
20 |
+
kpnet_nonlinear_activation="LeakyReLU",
|
21 |
+
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
22 |
+
):
|
23 |
+
'''
|
24 |
+
Args:
|
25 |
+
cond_channels (int): number of channel for the conditioning sequence,
|
26 |
+
conv_in_channels (int): number of channel for the input sequence,
|
27 |
+
conv_out_channels (int): number of channel for the output sequence,
|
28 |
+
conv_layers (int): number of layers
|
29 |
+
'''
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
self.conv_in_channels = conv_in_channels
|
33 |
+
self.conv_out_channels = conv_out_channels
|
34 |
+
self.conv_kernel_size = conv_kernel_size
|
35 |
+
self.conv_layers = conv_layers
|
36 |
+
|
37 |
+
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
38 |
+
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
39 |
+
|
40 |
+
self.input_conv = nn.Sequential(
|
41 |
+
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
42 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
43 |
+
)
|
44 |
+
|
45 |
+
self.residual_convs = nn.ModuleList()
|
46 |
+
padding = (kpnet_conv_size - 1) // 2
|
47 |
+
for _ in range(3):
|
48 |
+
self.residual_convs.append(
|
49 |
+
nn.Sequential(
|
50 |
+
nn.Dropout(kpnet_dropout),
|
51 |
+
nn.utils.weight_norm(
|
52 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
53 |
+
bias=True)),
|
54 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
55 |
+
nn.utils.weight_norm(
|
56 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
57 |
+
bias=True)),
|
58 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
59 |
+
)
|
60 |
+
)
|
61 |
+
self.kernel_conv = nn.utils.weight_norm(
|
62 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
|
63 |
+
self.bias_conv = nn.utils.weight_norm(
|
64 |
+
nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
|
65 |
+
|
66 |
+
def forward(self, c):
|
67 |
+
'''
|
68 |
+
Args:
|
69 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
70 |
+
'''
|
71 |
+
batch, _, cond_length = c.shape
|
72 |
+
c = self.input_conv(c)
|
73 |
+
for residual_conv in self.residual_convs:
|
74 |
+
residual_conv.to(c.device)
|
75 |
+
c = c + residual_conv(c)
|
76 |
+
k = self.kernel_conv(c)
|
77 |
+
b = self.bias_conv(c)
|
78 |
+
kernels = k.contiguous().view(
|
79 |
+
batch,
|
80 |
+
self.conv_layers,
|
81 |
+
self.conv_in_channels,
|
82 |
+
self.conv_out_channels,
|
83 |
+
self.conv_kernel_size,
|
84 |
+
cond_length,
|
85 |
+
)
|
86 |
+
bias = b.contiguous().view(
|
87 |
+
batch,
|
88 |
+
self.conv_layers,
|
89 |
+
self.conv_out_channels,
|
90 |
+
cond_length,
|
91 |
+
)
|
92 |
+
|
93 |
+
return kernels, bias
|
94 |
+
|
95 |
+
def remove_weight_norm(self):
|
96 |
+
nn.utils.remove_weight_norm(self.input_conv[0])
|
97 |
+
nn.utils.remove_weight_norm(self.kernel_conv)
|
98 |
+
nn.utils.remove_weight_norm(self.bias_conv)
|
99 |
+
for block in self.residual_convs:
|
100 |
+
nn.utils.remove_weight_norm(block[1])
|
101 |
+
nn.utils.remove_weight_norm(block[3])
|
102 |
+
|
103 |
+
|
104 |
+
class LVCBlock(torch.nn.Module):
|
105 |
+
'''the location-variable convolutions'''
|
106 |
+
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
in_channels,
|
110 |
+
cond_channels,
|
111 |
+
stride,
|
112 |
+
dilations=[1, 3, 9, 27],
|
113 |
+
lReLU_slope=0.2,
|
114 |
+
conv_kernel_size=3,
|
115 |
+
cond_hop_length=256,
|
116 |
+
kpnet_hidden_channels=64,
|
117 |
+
kpnet_conv_size=3,
|
118 |
+
kpnet_dropout=0.0,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.cond_hop_length = cond_hop_length
|
123 |
+
self.conv_layers = len(dilations)
|
124 |
+
self.conv_kernel_size = conv_kernel_size
|
125 |
+
|
126 |
+
self.kernel_predictor = KernelPredictor(
|
127 |
+
cond_channels=cond_channels,
|
128 |
+
conv_in_channels=in_channels,
|
129 |
+
conv_out_channels=2 * in_channels,
|
130 |
+
conv_layers=len(dilations),
|
131 |
+
conv_kernel_size=conv_kernel_size,
|
132 |
+
kpnet_hidden_channels=kpnet_hidden_channels,
|
133 |
+
kpnet_conv_size=kpnet_conv_size,
|
134 |
+
kpnet_dropout=kpnet_dropout,
|
135 |
+
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
|
136 |
+
)
|
137 |
+
|
138 |
+
self.convt_pre = nn.Sequential(
|
139 |
+
nn.LeakyReLU(lReLU_slope),
|
140 |
+
nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
|
141 |
+
padding=stride // 2 + stride % 2, output_padding=stride % 2)),
|
142 |
+
)
|
143 |
+
|
144 |
+
self.conv_blocks = nn.ModuleList()
|
145 |
+
for dilation in dilations:
|
146 |
+
self.conv_blocks.append(
|
147 |
+
nn.Sequential(
|
148 |
+
nn.LeakyReLU(lReLU_slope),
|
149 |
+
nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
|
150 |
+
padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
|
151 |
+
nn.LeakyReLU(lReLU_slope),
|
152 |
+
)
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(self, x, c):
|
156 |
+
''' forward propagation of the location-variable convolutions.
|
157 |
+
Args:
|
158 |
+
x (Tensor): the input sequence (batch, in_channels, in_length)
|
159 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tensor: the output sequence (batch, in_channels, in_length)
|
163 |
+
'''
|
164 |
+
_, in_channels, _ = x.shape # (B, c_g, L')
|
165 |
+
|
166 |
+
x = self.convt_pre(x) # (B, c_g, stride * L')
|
167 |
+
kernels, bias = self.kernel_predictor(c)
|
168 |
+
|
169 |
+
for i, conv in enumerate(self.conv_blocks):
|
170 |
+
output = conv(x) # (B, c_g, stride * L')
|
171 |
+
|
172 |
+
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
173 |
+
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
174 |
+
|
175 |
+
output = self.location_variable_convolution(output, k, b,
|
176 |
+
hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
|
177 |
+
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
178 |
+
output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
|
179 |
+
|
180 |
+
return x
|
181 |
+
|
182 |
+
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
183 |
+
''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
184 |
+
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
185 |
+
Args:
|
186 |
+
x (Tensor): the input sequence (batch, in_channels, in_length).
|
187 |
+
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
188 |
+
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
189 |
+
dilation (int): the dilation of convolution.
|
190 |
+
hop_size (int): the hop_size of the conditioning sequence.
|
191 |
+
Returns:
|
192 |
+
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
193 |
+
'''
|
194 |
+
batch, _, in_length = x.shape
|
195 |
+
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
196 |
+
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
197 |
+
|
198 |
+
padding = dilation * int((kernel_size - 1) / 2)
|
199 |
+
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
|
200 |
+
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
201 |
+
|
202 |
+
if hop_size < dilation:
|
203 |
+
x = F.pad(x, (0, dilation), 'constant', 0)
|
204 |
+
x = x.unfold(3, dilation,
|
205 |
+
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
206 |
+
x = x[:, :, :, :, :hop_size]
|
207 |
+
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
208 |
+
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
209 |
+
|
210 |
+
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
|
211 |
+
o = o.to(memory_format=torch.channels_last_3d)
|
212 |
+
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
213 |
+
o = o + bias
|
214 |
+
o = o.contiguous().view(batch, out_channels, -1)
|
215 |
+
|
216 |
+
return o
|
217 |
+
|
218 |
+
def remove_weight_norm(self):
|
219 |
+
self.kernel_predictor.remove_weight_norm()
|
220 |
+
nn.utils.remove_weight_norm(self.convt_pre[1])
|
221 |
+
for block in self.conv_blocks:
|
222 |
+
nn.utils.remove_weight_norm(block[1])
|
223 |
+
|
224 |
+
|
225 |
+
class UnivNetGenerator(nn.Module):
|
226 |
+
"""UnivNet Generator"""
|
227 |
+
|
228 |
+
def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
|
229 |
+
# Below are MEL configurations options that this generator requires.
|
230 |
+
hop_length=256, n_mel_channels=100):
|
231 |
+
super(UnivNetGenerator, self).__init__()
|
232 |
+
self.mel_channel = n_mel_channels
|
233 |
+
self.noise_dim = noise_dim
|
234 |
+
self.hop_length = hop_length
|
235 |
+
channel_size = channel_size
|
236 |
+
kpnet_conv_size = kpnet_conv_size
|
237 |
+
|
238 |
+
self.res_stack = nn.ModuleList()
|
239 |
+
hop_length = 1
|
240 |
+
for stride in strides:
|
241 |
+
hop_length = stride * hop_length
|
242 |
+
self.res_stack.append(
|
243 |
+
LVCBlock(
|
244 |
+
channel_size,
|
245 |
+
n_mel_channels,
|
246 |
+
stride=stride,
|
247 |
+
dilations=dilations,
|
248 |
+
lReLU_slope=lReLU_slope,
|
249 |
+
cond_hop_length=hop_length,
|
250 |
+
kpnet_conv_size=kpnet_conv_size
|
251 |
+
)
|
252 |
+
)
|
253 |
+
|
254 |
+
self.conv_pre = \
|
255 |
+
nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
|
256 |
+
|
257 |
+
self.conv_post = nn.Sequential(
|
258 |
+
nn.LeakyReLU(lReLU_slope),
|
259 |
+
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
|
260 |
+
nn.Tanh(),
|
261 |
+
)
|
262 |
+
|
263 |
+
def forward(self, c, z):
|
264 |
+
'''
|
265 |
+
Args:
|
266 |
+
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
|
267 |
+
z (Tensor): the noise sequence (batch, noise_dim, in_length)
|
268 |
+
|
269 |
+
'''
|
270 |
+
z = self.conv_pre(z) # (B, c_g, L)
|
271 |
+
|
272 |
+
for res_block in self.res_stack:
|
273 |
+
res_block.to(z.device)
|
274 |
+
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
|
275 |
+
|
276 |
+
z = self.conv_post(z) # (B, 1, L * 256)
|
277 |
+
|
278 |
+
return z
|
279 |
+
|
280 |
+
def eval(self, inference=False):
|
281 |
+
super(UnivNetGenerator, self).eval()
|
282 |
+
# don't remove weight norm while validation in training loop
|
283 |
+
if inference:
|
284 |
+
self.remove_weight_norm()
|
285 |
+
|
286 |
+
def remove_weight_norm(self):
|
287 |
+
print('Removing weight norm...')
|
288 |
+
|
289 |
+
nn.utils.remove_weight_norm(self.conv_pre)
|
290 |
+
|
291 |
+
for layer in self.conv_post:
|
292 |
+
if len(layer.state_dict()) != 0:
|
293 |
+
nn.utils.remove_weight_norm(layer)
|
294 |
+
|
295 |
+
for res_block in self.res_stack:
|
296 |
+
res_block.remove_weight_norm()
|
297 |
+
|
298 |
+
def inference(self, c, z=None):
|
299 |
+
# pad input mel with zeros to cut artifact
|
300 |
+
# see https://github.com/seungwonpark/melgan/issues/8
|
301 |
+
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
|
302 |
+
mel = torch.cat((c, zero), dim=2)
|
303 |
+
|
304 |
+
if z is None:
|
305 |
+
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
|
306 |
+
|
307 |
+
audio = self.forward(mel, z)
|
308 |
+
audio = audio[:, :, :-(self.hop_length * 10)]
|
309 |
+
audio = audio.clamp(min=-1, max=1)
|
310 |
+
return audio
|
311 |
+
|
312 |
+
|
313 |
+
if __name__ == '__main__':
|
314 |
+
model = UnivNetGenerator()
|
315 |
+
|
316 |
+
c = torch.randn(3, 100, 10)
|
317 |
+
z = torch.randn(3, 64, 10)
|
318 |
+
print(c.shape)
|
319 |
+
|
320 |
+
y = model(c, z)
|
321 |
+
print(y.shape)
|
322 |
+
assert y.shape == torch.Size([3, 1, 2560])
|
323 |
+
|
324 |
+
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
325 |
+
print(pytorch_total_params)
|
models/xtransformers.py
ADDED
@@ -0,0 +1,1302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
from torch import nn, einsum
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from functools import partial
|
7 |
+
from inspect import isfunction
|
8 |
+
from collections import namedtuple
|
9 |
+
|
10 |
+
from einops import rearrange, repeat, reduce
|
11 |
+
from einops.layers.torch import Rearrange
|
12 |
+
|
13 |
+
from entmax import entmax15
|
14 |
+
from torch.utils.checkpoint import checkpoint
|
15 |
+
|
16 |
+
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
|
17 |
+
|
18 |
+
DEFAULT_DIM_HEAD = 64
|
19 |
+
|
20 |
+
Intermediates = namedtuple('Intermediates', [
|
21 |
+
'pre_softmax_attn',
|
22 |
+
'post_softmax_attn'
|
23 |
+
])
|
24 |
+
|
25 |
+
LayerIntermediates = namedtuple('Intermediates', [
|
26 |
+
'hiddens',
|
27 |
+
'attn_intermediates',
|
28 |
+
'past_key_values',
|
29 |
+
])
|
30 |
+
|
31 |
+
|
32 |
+
# helpers
|
33 |
+
|
34 |
+
def exists(val):
|
35 |
+
return val is not None
|
36 |
+
|
37 |
+
|
38 |
+
def default(val, d):
|
39 |
+
if exists(val):
|
40 |
+
return val
|
41 |
+
return d() if isfunction(d) else d
|
42 |
+
|
43 |
+
|
44 |
+
def cast_tuple(val, depth):
|
45 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
46 |
+
|
47 |
+
|
48 |
+
class always():
|
49 |
+
def __init__(self, val):
|
50 |
+
self.val = val
|
51 |
+
|
52 |
+
def __call__(self, *args, **kwargs):
|
53 |
+
return self.val
|
54 |
+
|
55 |
+
|
56 |
+
class not_equals():
|
57 |
+
def __init__(self, val):
|
58 |
+
self.val = val
|
59 |
+
|
60 |
+
def __call__(self, x, *args, **kwargs):
|
61 |
+
return x != self.val
|
62 |
+
|
63 |
+
|
64 |
+
class equals():
|
65 |
+
def __init__(self, val):
|
66 |
+
self.val = val
|
67 |
+
|
68 |
+
def __call__(self, x, *args, **kwargs):
|
69 |
+
return x == self.val
|
70 |
+
|
71 |
+
|
72 |
+
def max_neg_value(tensor):
|
73 |
+
return -torch.finfo(tensor.dtype).max
|
74 |
+
|
75 |
+
|
76 |
+
def l2norm(t):
|
77 |
+
return F.normalize(t, p=2, dim=-1)
|
78 |
+
|
79 |
+
|
80 |
+
# init helpers
|
81 |
+
|
82 |
+
def init_zero_(layer):
|
83 |
+
nn.init.constant_(layer.weight, 0.)
|
84 |
+
if exists(layer.bias):
|
85 |
+
nn.init.constant_(layer.bias, 0.)
|
86 |
+
|
87 |
+
|
88 |
+
# keyword argument helpers
|
89 |
+
|
90 |
+
def pick_and_pop(keys, d):
|
91 |
+
values = list(map(lambda key: d.pop(key), keys))
|
92 |
+
return dict(zip(keys, values))
|
93 |
+
|
94 |
+
|
95 |
+
def group_dict_by_key(cond, d):
|
96 |
+
return_val = [dict(), dict()]
|
97 |
+
for key in d.keys():
|
98 |
+
match = bool(cond(key))
|
99 |
+
ind = int(not match)
|
100 |
+
return_val[ind][key] = d[key]
|
101 |
+
return (*return_val,)
|
102 |
+
|
103 |
+
|
104 |
+
def string_begins_with(prefix, str):
|
105 |
+
return str.startswith(prefix)
|
106 |
+
|
107 |
+
|
108 |
+
def group_by_key_prefix(prefix, d):
|
109 |
+
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
110 |
+
|
111 |
+
|
112 |
+
def groupby_prefix_and_trim(prefix, d):
|
113 |
+
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
114 |
+
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
115 |
+
return kwargs_without_prefix, kwargs
|
116 |
+
|
117 |
+
|
118 |
+
# activations
|
119 |
+
|
120 |
+
class ReluSquared(nn.Module):
|
121 |
+
def forward(self, x):
|
122 |
+
return F.relu(x) ** 2
|
123 |
+
|
124 |
+
|
125 |
+
# positional embeddings
|
126 |
+
|
127 |
+
class AbsolutePositionalEmbedding(nn.Module):
|
128 |
+
def __init__(self, dim, max_seq_len):
|
129 |
+
super().__init__()
|
130 |
+
self.scale = dim ** -0.5
|
131 |
+
self.emb = nn.Embedding(max_seq_len, dim)
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
n = torch.arange(x.shape[1], device=x.device)
|
135 |
+
pos_emb = self.emb(n)
|
136 |
+
pos_emb = rearrange(pos_emb, 'n d -> () n d')
|
137 |
+
return pos_emb * self.scale
|
138 |
+
|
139 |
+
|
140 |
+
class FixedPositionalEmbedding(nn.Module):
|
141 |
+
def __init__(self, dim):
|
142 |
+
super().__init__()
|
143 |
+
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
144 |
+
self.register_buffer('inv_freq', inv_freq)
|
145 |
+
|
146 |
+
def forward(self, x, seq_dim=1, offset=0):
|
147 |
+
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
148 |
+
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
149 |
+
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
150 |
+
return rearrange(emb, 'n d -> () n d')
|
151 |
+
|
152 |
+
|
153 |
+
class RelativePositionBias(nn.Module):
|
154 |
+
def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
|
155 |
+
super().__init__()
|
156 |
+
self.scale = scale
|
157 |
+
self.causal = causal
|
158 |
+
self.num_buckets = num_buckets
|
159 |
+
self.max_distance = max_distance
|
160 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
|
164 |
+
ret = 0
|
165 |
+
n = -relative_position
|
166 |
+
if not causal:
|
167 |
+
num_buckets //= 2
|
168 |
+
ret += (n < 0).long() * num_buckets
|
169 |
+
n = torch.abs(n)
|
170 |
+
else:
|
171 |
+
n = torch.max(n, torch.zeros_like(n))
|
172 |
+
|
173 |
+
max_exact = num_buckets // 2
|
174 |
+
is_small = n < max_exact
|
175 |
+
|
176 |
+
val_if_large = max_exact + (
|
177 |
+
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
178 |
+
).long()
|
179 |
+
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
180 |
+
|
181 |
+
ret += torch.where(is_small, n, val_if_large)
|
182 |
+
return ret
|
183 |
+
|
184 |
+
def forward(self, qk_dots):
|
185 |
+
i, j, device = *qk_dots.shape[-2:], qk_dots.device
|
186 |
+
q_pos = torch.arange(i, dtype=torch.long, device=device)
|
187 |
+
k_pos = torch.arange(j, dtype=torch.long, device=device)
|
188 |
+
rel_pos = k_pos[None, :] - q_pos[:, None]
|
189 |
+
rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
|
190 |
+
max_distance=self.max_distance)
|
191 |
+
values = self.relative_attention_bias(rp_bucket)
|
192 |
+
bias = rearrange(values, 'i j h -> () h i j')
|
193 |
+
return qk_dots + (bias * self.scale)
|
194 |
+
|
195 |
+
|
196 |
+
class AlibiPositionalBias(nn.Module):
|
197 |
+
def __init__(self, heads, **kwargs):
|
198 |
+
super().__init__()
|
199 |
+
self.heads = heads
|
200 |
+
slopes = torch.Tensor(self._get_slopes(heads))
|
201 |
+
slopes = rearrange(slopes, 'h -> () h () ()')
|
202 |
+
self.register_buffer('slopes', slopes, persistent=False)
|
203 |
+
self.register_buffer('bias', None, persistent=False)
|
204 |
+
|
205 |
+
@staticmethod
|
206 |
+
def _get_slopes(heads):
|
207 |
+
def get_slopes_power_of_2(n):
|
208 |
+
start = (2 ** (-2 ** -(math.log2(n) - 3)))
|
209 |
+
ratio = start
|
210 |
+
return [start * ratio ** i for i in range(n)]
|
211 |
+
|
212 |
+
if math.log2(heads).is_integer():
|
213 |
+
return get_slopes_power_of_2(heads)
|
214 |
+
|
215 |
+
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
|
216 |
+
return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
|
217 |
+
:heads - closest_power_of_2]
|
218 |
+
|
219 |
+
def forward(self, qk_dots):
|
220 |
+
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
|
221 |
+
|
222 |
+
if exists(self.bias) and self.bias.shape[-1] >= j:
|
223 |
+
return qk_dots + self.bias[..., :j]
|
224 |
+
|
225 |
+
bias = torch.arange(j, device=device)
|
226 |
+
bias = rearrange(bias, 'j -> () () () j')
|
227 |
+
bias = bias * self.slopes
|
228 |
+
|
229 |
+
num_heads_unalibied = h - bias.shape[1]
|
230 |
+
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
|
231 |
+
|
232 |
+
self.register_buffer('bias', bias, persistent=False)
|
233 |
+
return qk_dots + self.bias
|
234 |
+
|
235 |
+
|
236 |
+
class LearnedAlibiPositionalBias(AlibiPositionalBias):
|
237 |
+
def __init__(self, heads, bidirectional=False):
|
238 |
+
super().__init__(heads)
|
239 |
+
los_slopes = torch.log(self.slopes)
|
240 |
+
self.learned_logslopes = nn.Parameter(los_slopes)
|
241 |
+
|
242 |
+
self.bidirectional = bidirectional
|
243 |
+
if self.bidirectional:
|
244 |
+
self.learned_logslopes_future = nn.Parameter(los_slopes)
|
245 |
+
|
246 |
+
def forward(self, qk_dots):
|
247 |
+
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
|
248 |
+
|
249 |
+
def get_slopes(param):
|
250 |
+
return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
|
251 |
+
|
252 |
+
if exists(self.bias) and self.bias.shape[-1] >= j:
|
253 |
+
bias = self.bias[..., :i, :j]
|
254 |
+
else:
|
255 |
+
i_arange = torch.arange(i, device=device)
|
256 |
+
j_arange = torch.arange(j, device=device)
|
257 |
+
bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
|
258 |
+
self.register_buffer('bias', bias, persistent=False)
|
259 |
+
|
260 |
+
if self.bidirectional:
|
261 |
+
past_slopes = get_slopes(self.learned_logslopes)
|
262 |
+
future_slopes = get_slopes(self.learned_logslopes_future)
|
263 |
+
bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
|
264 |
+
else:
|
265 |
+
slopes = get_slopes(self.learned_logslopes)
|
266 |
+
bias = bias * slopes
|
267 |
+
|
268 |
+
return qk_dots + bias
|
269 |
+
|
270 |
+
|
271 |
+
class RotaryEmbedding(nn.Module):
|
272 |
+
def __init__(self, dim):
|
273 |
+
super().__init__()
|
274 |
+
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
275 |
+
self.register_buffer('inv_freq', inv_freq)
|
276 |
+
|
277 |
+
def forward(self, max_seq_len, device):
|
278 |
+
t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
|
279 |
+
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
280 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
281 |
+
return rearrange(emb, 'n d -> () () n d')
|
282 |
+
|
283 |
+
|
284 |
+
def rotate_half(x):
|
285 |
+
x = rearrange(x, '... (j d) -> ... j d', j=2)
|
286 |
+
x1, x2 = x.unbind(dim=-2)
|
287 |
+
return torch.cat((-x2, x1), dim=-1)
|
288 |
+
|
289 |
+
|
290 |
+
def apply_rotary_pos_emb(t, freqs):
|
291 |
+
seq_len = t.shape[-2]
|
292 |
+
freqs = freqs[:, :, -seq_len:]
|
293 |
+
return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
|
294 |
+
|
295 |
+
|
296 |
+
# norms
|
297 |
+
|
298 |
+
class Scale(nn.Module):
|
299 |
+
def __init__(self, value, fn):
|
300 |
+
super().__init__()
|
301 |
+
self.value = value
|
302 |
+
self.fn = fn
|
303 |
+
|
304 |
+
def forward(self, x, **kwargs):
|
305 |
+
out = self.fn(x, **kwargs)
|
306 |
+
scale_fn = lambda t: t * self.value
|
307 |
+
|
308 |
+
if not isinstance(out, tuple):
|
309 |
+
return scale_fn(out)
|
310 |
+
|
311 |
+
return (scale_fn(out[0]), *out[1:])
|
312 |
+
|
313 |
+
|
314 |
+
class Rezero(nn.Module):
|
315 |
+
def __init__(self, fn):
|
316 |
+
super().__init__()
|
317 |
+
self.fn = fn
|
318 |
+
self.g = nn.Parameter(torch.zeros(1))
|
319 |
+
|
320 |
+
def forward(self, x, **kwargs):
|
321 |
+
out = self.fn(x, **kwargs)
|
322 |
+
rezero_fn = lambda t: t * self.g
|
323 |
+
|
324 |
+
if not isinstance(out, tuple):
|
325 |
+
return rezero_fn(out)
|
326 |
+
|
327 |
+
return (rezero_fn(out[0]), *out[1:])
|
328 |
+
|
329 |
+
|
330 |
+
class ScaleNorm(nn.Module):
|
331 |
+
def __init__(self, dim, eps=1e-5):
|
332 |
+
super().__init__()
|
333 |
+
self.scale = dim ** -0.5
|
334 |
+
self.eps = eps
|
335 |
+
self.g = nn.Parameter(torch.ones(1))
|
336 |
+
|
337 |
+
def forward(self, x):
|
338 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
339 |
+
return x / norm.clamp(min=self.eps) * self.g
|
340 |
+
|
341 |
+
|
342 |
+
class RMSNorm(nn.Module):
|
343 |
+
def __init__(self, dim, eps=1e-8):
|
344 |
+
super().__init__()
|
345 |
+
self.scale = dim ** -0.5
|
346 |
+
self.eps = eps
|
347 |
+
self.g = nn.Parameter(torch.ones(dim))
|
348 |
+
|
349 |
+
def forward(self, x):
|
350 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
351 |
+
return x / norm.clamp(min=self.eps) * self.g
|
352 |
+
|
353 |
+
|
354 |
+
class RMSScaleShiftNorm(nn.Module):
|
355 |
+
def __init__(self, dim, eps=1e-8):
|
356 |
+
super().__init__()
|
357 |
+
self.scale = dim ** -0.5
|
358 |
+
self.eps = eps
|
359 |
+
self.g = nn.Parameter(torch.ones(dim))
|
360 |
+
self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
|
361 |
+
|
362 |
+
def forward(self, x, norm_scale_shift_inp):
|
363 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
364 |
+
norm = x / norm.clamp(min=self.eps) * self.g
|
365 |
+
|
366 |
+
ss_emb = self.scale_shift_process(norm_scale_shift_inp)
|
367 |
+
scale, shift = torch.chunk(ss_emb, 2, dim=1)
|
368 |
+
h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
369 |
+
return h
|
370 |
+
|
371 |
+
|
372 |
+
# residual and residual gates
|
373 |
+
|
374 |
+
class Residual(nn.Module):
|
375 |
+
def __init__(self, dim, scale_residual=False):
|
376 |
+
super().__init__()
|
377 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
378 |
+
|
379 |
+
def forward(self, x, residual):
|
380 |
+
if exists(self.residual_scale):
|
381 |
+
residual = residual * self.residual_scale
|
382 |
+
|
383 |
+
return x + residual
|
384 |
+
|
385 |
+
|
386 |
+
class GRUGating(nn.Module):
|
387 |
+
def __init__(self, dim, scale_residual=False):
|
388 |
+
super().__init__()
|
389 |
+
self.gru = nn.GRUCell(dim, dim)
|
390 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
391 |
+
|
392 |
+
def forward(self, x, residual):
|
393 |
+
if exists(self.residual_scale):
|
394 |
+
residual = residual * self.residual_scale
|
395 |
+
|
396 |
+
gated_output = self.gru(
|
397 |
+
rearrange(x, 'b n d -> (b n) d'),
|
398 |
+
rearrange(residual, 'b n d -> (b n) d')
|
399 |
+
)
|
400 |
+
|
401 |
+
return gated_output.reshape_as(x)
|
402 |
+
|
403 |
+
|
404 |
+
# token shifting
|
405 |
+
|
406 |
+
def shift(t, amount, mask=None):
|
407 |
+
if amount == 0:
|
408 |
+
return t
|
409 |
+
|
410 |
+
if exists(mask):
|
411 |
+
t = t.masked_fill(~mask[..., None], 0.)
|
412 |
+
|
413 |
+
return F.pad(t, (0, 0, amount, -amount), value=0.)
|
414 |
+
|
415 |
+
|
416 |
+
class ShiftTokens(nn.Module):
|
417 |
+
def __init__(self, shifts, fn):
|
418 |
+
super().__init__()
|
419 |
+
self.fn = fn
|
420 |
+
self.shifts = tuple(shifts)
|
421 |
+
|
422 |
+
def forward(self, x, **kwargs):
|
423 |
+
mask = kwargs.get('mask', None)
|
424 |
+
shifts = self.shifts
|
425 |
+
segments = len(shifts)
|
426 |
+
feats_per_shift = x.shape[-1] // segments
|
427 |
+
splitted = x.split(feats_per_shift, dim=-1)
|
428 |
+
segments_to_shift, rest = splitted[:segments], splitted[segments:]
|
429 |
+
segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
|
430 |
+
x = torch.cat((*segments_to_shift, *rest), dim=-1)
|
431 |
+
return self.fn(x, **kwargs)
|
432 |
+
|
433 |
+
|
434 |
+
# feedforward
|
435 |
+
|
436 |
+
class GLU(nn.Module):
|
437 |
+
def __init__(self, dim_in, dim_out, activation):
|
438 |
+
super().__init__()
|
439 |
+
self.act = activation
|
440 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
441 |
+
|
442 |
+
def forward(self, x):
|
443 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
444 |
+
return x * self.act(gate)
|
445 |
+
|
446 |
+
|
447 |
+
class FeedForward(nn.Module):
|
448 |
+
def __init__(
|
449 |
+
self,
|
450 |
+
dim,
|
451 |
+
dim_out=None,
|
452 |
+
mult=4,
|
453 |
+
glu=False,
|
454 |
+
relu_squared=False,
|
455 |
+
post_act_ln=False,
|
456 |
+
dropout=0.,
|
457 |
+
zero_init_output=False
|
458 |
+
):
|
459 |
+
super().__init__()
|
460 |
+
inner_dim = int(dim * mult)
|
461 |
+
dim_out = default(dim_out, dim)
|
462 |
+
activation = ReluSquared() if relu_squared else nn.GELU()
|
463 |
+
|
464 |
+
project_in = nn.Sequential(
|
465 |
+
nn.Linear(dim, inner_dim),
|
466 |
+
activation
|
467 |
+
) if not glu else GLU(dim, inner_dim, activation)
|
468 |
+
|
469 |
+
self.net = nn.Sequential(
|
470 |
+
project_in,
|
471 |
+
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
|
472 |
+
nn.Dropout(dropout),
|
473 |
+
nn.Linear(inner_dim, dim_out)
|
474 |
+
)
|
475 |
+
|
476 |
+
# init last linear layer to 0
|
477 |
+
if zero_init_output:
|
478 |
+
init_zero_(self.net[-1])
|
479 |
+
|
480 |
+
def forward(self, x):
|
481 |
+
return self.net(x)
|
482 |
+
|
483 |
+
|
484 |
+
# attention.
|
485 |
+
|
486 |
+
class Attention(nn.Module):
|
487 |
+
def __init__(
|
488 |
+
self,
|
489 |
+
dim,
|
490 |
+
dim_head=DEFAULT_DIM_HEAD,
|
491 |
+
heads=8,
|
492 |
+
causal=False,
|
493 |
+
talking_heads=False,
|
494 |
+
head_scale=False,
|
495 |
+
collab_heads=False,
|
496 |
+
collab_compression=.3,
|
497 |
+
sparse_topk=None,
|
498 |
+
use_entmax15=False,
|
499 |
+
num_mem_kv=0,
|
500 |
+
dropout=0.,
|
501 |
+
on_attn=False,
|
502 |
+
gate_values=False,
|
503 |
+
zero_init_output=False,
|
504 |
+
max_attend_past=None,
|
505 |
+
qk_norm=False,
|
506 |
+
scale_init_value=None,
|
507 |
+
rel_pos_bias=False,
|
508 |
+
rel_pos_num_buckets=32,
|
509 |
+
rel_pos_max_distance=128,
|
510 |
+
):
|
511 |
+
super().__init__()
|
512 |
+
self.scale = dim_head ** -0.5
|
513 |
+
|
514 |
+
self.heads = heads
|
515 |
+
self.causal = causal
|
516 |
+
self.max_attend_past = max_attend_past
|
517 |
+
|
518 |
+
qk_dim = v_dim = dim_head * heads
|
519 |
+
|
520 |
+
# collaborative heads
|
521 |
+
self.collab_heads = collab_heads
|
522 |
+
if self.collab_heads:
|
523 |
+
qk_dim = int(collab_compression * qk_dim)
|
524 |
+
self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
|
525 |
+
|
526 |
+
self.to_q = nn.Linear(dim, qk_dim, bias=False)
|
527 |
+
self.to_k = nn.Linear(dim, qk_dim, bias=False)
|
528 |
+
self.to_v = nn.Linear(dim, v_dim, bias=False)
|
529 |
+
|
530 |
+
self.dropout = nn.Dropout(dropout)
|
531 |
+
|
532 |
+
# add GLU gating for aggregated values, from alphafold2
|
533 |
+
self.to_v_gate = None
|
534 |
+
if gate_values:
|
535 |
+
self.to_v_gate = nn.Linear(dim, v_dim)
|
536 |
+
nn.init.constant_(self.to_v_gate.weight, 0)
|
537 |
+
nn.init.constant_(self.to_v_gate.bias, 1)
|
538 |
+
|
539 |
+
# cosine sim attention
|
540 |
+
self.qk_norm = qk_norm
|
541 |
+
if qk_norm:
|
542 |
+
scale_init_value = default(scale_init_value,
|
543 |
+
-3) # if not provided, initialize as though it were sequence length of 1024
|
544 |
+
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
|
545 |
+
|
546 |
+
# talking heads
|
547 |
+
self.talking_heads = talking_heads
|
548 |
+
if talking_heads:
|
549 |
+
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
550 |
+
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
551 |
+
|
552 |
+
# head scaling
|
553 |
+
self.head_scale = head_scale
|
554 |
+
if head_scale:
|
555 |
+
self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
|
556 |
+
|
557 |
+
# explicit topk sparse attention
|
558 |
+
self.sparse_topk = sparse_topk
|
559 |
+
|
560 |
+
# entmax
|
561 |
+
self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
562 |
+
|
563 |
+
# add memory key / values
|
564 |
+
self.num_mem_kv = num_mem_kv
|
565 |
+
if num_mem_kv > 0:
|
566 |
+
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
567 |
+
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
568 |
+
|
569 |
+
# attention on attention
|
570 |
+
self.attn_on_attn = on_attn
|
571 |
+
self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
|
572 |
+
|
573 |
+
self.rel_pos_bias = rel_pos_bias
|
574 |
+
if rel_pos_bias:
|
575 |
+
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
576 |
+
self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
|
577 |
+
num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
|
578 |
+
|
579 |
+
# init output projection 0
|
580 |
+
if zero_init_output:
|
581 |
+
init_zero_(self.to_out)
|
582 |
+
|
583 |
+
def forward(
|
584 |
+
self,
|
585 |
+
x,
|
586 |
+
context=None,
|
587 |
+
mask=None,
|
588 |
+
context_mask=None,
|
589 |
+
attn_mask=None,
|
590 |
+
sinusoidal_emb=None,
|
591 |
+
rotary_pos_emb=None,
|
592 |
+
prev_attn=None,
|
593 |
+
mem=None,
|
594 |
+
layer_past=None,
|
595 |
+
):
|
596 |
+
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
|
597 |
+
context)
|
598 |
+
kv_input = default(context, x)
|
599 |
+
|
600 |
+
q_input = x
|
601 |
+
k_input = kv_input
|
602 |
+
v_input = kv_input
|
603 |
+
|
604 |
+
if exists(mem):
|
605 |
+
k_input = torch.cat((mem, k_input), dim=-2)
|
606 |
+
v_input = torch.cat((mem, v_input), dim=-2)
|
607 |
+
|
608 |
+
if exists(sinusoidal_emb):
|
609 |
+
# in shortformer, the query would start at a position offset depending on the past cached memory
|
610 |
+
offset = k_input.shape[-2] - q_input.shape[-2]
|
611 |
+
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
|
612 |
+
k_input = k_input + sinusoidal_emb(k_input)
|
613 |
+
|
614 |
+
q = self.to_q(q_input)
|
615 |
+
k = self.to_k(k_input)
|
616 |
+
v = self.to_v(v_input)
|
617 |
+
|
618 |
+
if not collab_heads:
|
619 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
620 |
+
else:
|
621 |
+
q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
|
622 |
+
k = rearrange(k, 'b n d -> b () n d')
|
623 |
+
v = rearrange(v, 'b n (h d) -> b h n d', h=h)
|
624 |
+
|
625 |
+
if layer_past is not None:
|
626 |
+
past_key, past_value = layer_past
|
627 |
+
k = torch.cat([past_key, k], dim=-2)
|
628 |
+
v = torch.cat([past_value, v], dim=-2)
|
629 |
+
k_cache = k
|
630 |
+
v_cache = v
|
631 |
+
|
632 |
+
if exists(rotary_pos_emb) and not has_context:
|
633 |
+
l = rotary_pos_emb.shape[-1]
|
634 |
+
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
|
635 |
+
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
|
636 |
+
q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
|
637 |
+
|
638 |
+
input_mask = None
|
639 |
+
if any(map(exists, (mask, context_mask))):
|
640 |
+
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
|
641 |
+
k_mask = q_mask if not exists(context) else context_mask
|
642 |
+
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
|
643 |
+
q_mask = rearrange(q_mask, 'b i -> b () i ()')
|
644 |
+
k_mask = rearrange(k_mask, 'b j -> b () () j')
|
645 |
+
input_mask = q_mask * k_mask
|
646 |
+
|
647 |
+
if self.num_mem_kv > 0:
|
648 |
+
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
|
649 |
+
k = torch.cat((mem_k, k), dim=-2)
|
650 |
+
v = torch.cat((mem_v, v), dim=-2)
|
651 |
+
if exists(input_mask):
|
652 |
+
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
|
653 |
+
|
654 |
+
if collab_heads:
|
655 |
+
k = k.expand(-1, h, -1, -1)
|
656 |
+
|
657 |
+
if self.qk_norm:
|
658 |
+
q, k = map(l2norm, (q, k))
|
659 |
+
scale = 1 / (self.scale.exp().clamp(min=1e-2))
|
660 |
+
|
661 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
|
662 |
+
mask_value = max_neg_value(dots)
|
663 |
+
|
664 |
+
if exists(prev_attn):
|
665 |
+
dots = dots + prev_attn
|
666 |
+
|
667 |
+
pre_softmax_attn = dots.clone()
|
668 |
+
|
669 |
+
if talking_heads:
|
670 |
+
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
|
671 |
+
|
672 |
+
if self.rel_pos_bias:
|
673 |
+
dots = self.rel_pos(dots)
|
674 |
+
|
675 |
+
if exists(input_mask):
|
676 |
+
dots.masked_fill_(~input_mask, mask_value)
|
677 |
+
del input_mask
|
678 |
+
|
679 |
+
if exists(attn_mask):
|
680 |
+
assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
|
681 |
+
if attn_mask.ndim == 2:
|
682 |
+
attn_mask = rearrange(attn_mask, 'i j -> () () i j')
|
683 |
+
elif attn_mask.ndim == 3:
|
684 |
+
attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
|
685 |
+
dots.masked_fill_(~attn_mask, mask_value)
|
686 |
+
|
687 |
+
if exists(self.max_attend_past):
|
688 |
+
i, j = dots.shape[-2:]
|
689 |
+
range_q = torch.arange(j - i, j, device=device)
|
690 |
+
range_k = torch.arange(j, device=device)
|
691 |
+
dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
|
692 |
+
mask = dist > self.max_attend_past
|
693 |
+
dots.masked_fill_(mask, mask_value)
|
694 |
+
del mask
|
695 |
+
|
696 |
+
if self.causal:
|
697 |
+
i, j = dots.shape[-2:]
|
698 |
+
r = torch.arange(i, device=device)
|
699 |
+
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
|
700 |
+
mask = F.pad(mask, (j - i, 0), value=False)
|
701 |
+
dots.masked_fill_(mask, mask_value)
|
702 |
+
del mask
|
703 |
+
|
704 |
+
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
|
705 |
+
top, _ = dots.topk(self.sparse_topk, dim=-1)
|
706 |
+
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
|
707 |
+
mask = dots < vk
|
708 |
+
dots.masked_fill_(mask, mask_value)
|
709 |
+
del mask
|
710 |
+
|
711 |
+
attn = self.attn_fn(dots, dim=-1)
|
712 |
+
post_softmax_attn = attn.clone()
|
713 |
+
|
714 |
+
attn = self.dropout(attn)
|
715 |
+
|
716 |
+
if talking_heads:
|
717 |
+
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
|
718 |
+
|
719 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
720 |
+
|
721 |
+
if head_scale:
|
722 |
+
out = out * self.head_scale_params
|
723 |
+
|
724 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
725 |
+
|
726 |
+
if exists(self.to_v_gate):
|
727 |
+
gates = self.to_v_gate(x)
|
728 |
+
out = out * gates.sigmoid()
|
729 |
+
|
730 |
+
intermediates = Intermediates(
|
731 |
+
pre_softmax_attn=pre_softmax_attn,
|
732 |
+
post_softmax_attn=post_softmax_attn
|
733 |
+
)
|
734 |
+
|
735 |
+
return self.to_out(out), intermediates, k_cache, v_cache
|
736 |
+
|
737 |
+
|
738 |
+
class AttentionLayers(nn.Module):
|
739 |
+
def __init__(
|
740 |
+
self,
|
741 |
+
dim,
|
742 |
+
depth,
|
743 |
+
heads=8,
|
744 |
+
causal=False,
|
745 |
+
cross_attend=False,
|
746 |
+
only_cross=False,
|
747 |
+
use_scalenorm=False,
|
748 |
+
use_rms_scaleshift_norm=False,
|
749 |
+
use_rmsnorm=False,
|
750 |
+
use_rezero=False,
|
751 |
+
alibi_pos_bias=False,
|
752 |
+
alibi_num_heads=None,
|
753 |
+
alibi_learned=False,
|
754 |
+
position_infused_attn=False,
|
755 |
+
rotary_pos_emb=False,
|
756 |
+
rotary_emb_dim=None,
|
757 |
+
custom_layers=None,
|
758 |
+
sandwich_coef=None,
|
759 |
+
par_ratio=None,
|
760 |
+
residual_attn=False,
|
761 |
+
cross_residual_attn=False,
|
762 |
+
macaron=False,
|
763 |
+
pre_norm=True,
|
764 |
+
gate_residual=False,
|
765 |
+
scale_residual=False,
|
766 |
+
shift_tokens=0,
|
767 |
+
sandwich_norm=False,
|
768 |
+
use_qk_norm_attn=False,
|
769 |
+
qk_norm_attn_seq_len=None,
|
770 |
+
zero_init_branch_output=False,
|
771 |
+
**kwargs
|
772 |
+
):
|
773 |
+
super().__init__()
|
774 |
+
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
775 |
+
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
|
776 |
+
|
777 |
+
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
|
778 |
+
|
779 |
+
self.dim = dim
|
780 |
+
self.depth = depth
|
781 |
+
self.layers = nn.ModuleList([])
|
782 |
+
self.causal = causal
|
783 |
+
|
784 |
+
rel_pos_bias = 'rel_pos_bias' in attn_kwargs
|
785 |
+
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
|
786 |
+
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
|
787 |
+
|
788 |
+
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
|
789 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
|
790 |
+
|
791 |
+
assert not (
|
792 |
+
alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
|
793 |
+
|
794 |
+
if alibi_pos_bias:
|
795 |
+
alibi_num_heads = default(alibi_num_heads, heads)
|
796 |
+
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
|
797 |
+
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
|
798 |
+
self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
|
799 |
+
else:
|
800 |
+
self.rel_pos = None
|
801 |
+
|
802 |
+
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
|
803 |
+
self.pre_norm = pre_norm
|
804 |
+
self.sandwich_norm = sandwich_norm
|
805 |
+
|
806 |
+
self.residual_attn = residual_attn
|
807 |
+
self.cross_residual_attn = cross_residual_attn
|
808 |
+
self.cross_attend = cross_attend
|
809 |
+
|
810 |
+
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
811 |
+
norm_class = RMSNorm if use_rmsnorm else norm_class
|
812 |
+
norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
|
813 |
+
norm_fn = partial(norm_class, dim)
|
814 |
+
|
815 |
+
norm_fn = nn.Identity if use_rezero else norm_fn
|
816 |
+
branch_fn = Rezero if use_rezero else None
|
817 |
+
|
818 |
+
if cross_attend and not only_cross:
|
819 |
+
default_block = ('a', 'c', 'f')
|
820 |
+
elif cross_attend and only_cross:
|
821 |
+
default_block = ('c', 'f')
|
822 |
+
else:
|
823 |
+
default_block = ('a', 'f')
|
824 |
+
|
825 |
+
if macaron:
|
826 |
+
default_block = ('f',) + default_block
|
827 |
+
|
828 |
+
# qk normalization
|
829 |
+
|
830 |
+
if use_qk_norm_attn:
|
831 |
+
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
|
832 |
+
qk_norm_attn_seq_len) else None
|
833 |
+
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
|
834 |
+
|
835 |
+
# zero init
|
836 |
+
|
837 |
+
if zero_init_branch_output:
|
838 |
+
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
|
839 |
+
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
|
840 |
+
|
841 |
+
# calculate layer block order
|
842 |
+
|
843 |
+
if exists(custom_layers):
|
844 |
+
layer_types = custom_layers
|
845 |
+
elif exists(par_ratio):
|
846 |
+
par_depth = depth * len(default_block)
|
847 |
+
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
848 |
+
default_block = tuple(filter(not_equals('f'), default_block))
|
849 |
+
par_attn = par_depth // par_ratio
|
850 |
+
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
851 |
+
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
852 |
+
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
853 |
+
par_block = default_block + ('f',) * (par_width - len(default_block))
|
854 |
+
par_head = par_block * par_attn
|
855 |
+
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
856 |
+
elif exists(sandwich_coef):
|
857 |
+
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
858 |
+
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
859 |
+
else:
|
860 |
+
layer_types = default_block * depth
|
861 |
+
|
862 |
+
self.layer_types = layer_types
|
863 |
+
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
864 |
+
|
865 |
+
# calculate token shifting
|
866 |
+
|
867 |
+
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
|
868 |
+
|
869 |
+
# iterate and construct layers
|
870 |
+
|
871 |
+
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
|
872 |
+
is_last_layer = ind == (len(self.layer_types) - 1)
|
873 |
+
|
874 |
+
if layer_type == 'a':
|
875 |
+
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
|
876 |
+
elif layer_type == 'c':
|
877 |
+
layer = Attention(dim, heads=heads, **attn_kwargs)
|
878 |
+
elif layer_type == 'f':
|
879 |
+
layer = FeedForward(dim, **ff_kwargs)
|
880 |
+
layer = layer if not macaron else Scale(0.5, layer)
|
881 |
+
else:
|
882 |
+
raise Exception(f'invalid layer type {layer_type}')
|
883 |
+
|
884 |
+
if layer_shift_tokens > 0:
|
885 |
+
shift_range_upper = layer_shift_tokens + 1
|
886 |
+
shift_range_lower = -layer_shift_tokens if not causal else 0
|
887 |
+
layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
|
888 |
+
|
889 |
+
if exists(branch_fn):
|
890 |
+
layer = branch_fn(layer)
|
891 |
+
|
892 |
+
residual_fn = GRUGating if gate_residual else Residual
|
893 |
+
residual = residual_fn(dim, scale_residual=scale_residual)
|
894 |
+
|
895 |
+
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
|
896 |
+
|
897 |
+
pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
|
898 |
+
post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
|
899 |
+
post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
|
900 |
+
|
901 |
+
norms = nn.ModuleList([
|
902 |
+
pre_branch_norm,
|
903 |
+
post_branch_norm,
|
904 |
+
post_main_norm
|
905 |
+
])
|
906 |
+
|
907 |
+
self.layers.append(nn.ModuleList([
|
908 |
+
norms,
|
909 |
+
layer,
|
910 |
+
residual
|
911 |
+
]))
|
912 |
+
|
913 |
+
def forward(
|
914 |
+
self,
|
915 |
+
x,
|
916 |
+
context=None,
|
917 |
+
full_context=None, # for passing a list of hidden states from an encoder
|
918 |
+
mask=None,
|
919 |
+
context_mask=None,
|
920 |
+
attn_mask=None,
|
921 |
+
mems=None,
|
922 |
+
return_hiddens=False,
|
923 |
+
norm_scale_shift_inp=None,
|
924 |
+
past_key_values=None,
|
925 |
+
expected_seq_len=None,
|
926 |
+
):
|
927 |
+
|
928 |
+
assert not (self.cross_attend ^ (exists(context) or exists(
|
929 |
+
full_context))), 'context must be passed in if cross_attend is set to True'
|
930 |
+
assert context is None or full_context is None, 'only one of full_context or context can be provided'
|
931 |
+
|
932 |
+
hiddens = []
|
933 |
+
intermediates = []
|
934 |
+
prev_attn = None
|
935 |
+
prev_cross_attn = None
|
936 |
+
|
937 |
+
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
938 |
+
norm_args = {}
|
939 |
+
if exists(norm_scale_shift_inp):
|
940 |
+
norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
|
941 |
+
|
942 |
+
rotary_pos_emb = None
|
943 |
+
if exists(self.rotary_pos_emb):
|
944 |
+
if not self.training and self.causal:
|
945 |
+
assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
|
946 |
+
elif expected_seq_len is None:
|
947 |
+
expected_seq_len = 0
|
948 |
+
seq_len = x.shape[1]
|
949 |
+
if past_key_values is not None:
|
950 |
+
seq_len += past_key_values[0][0].shape[-2]
|
951 |
+
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
|
952 |
+
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
953 |
+
|
954 |
+
present_key_values = []
|
955 |
+
cross_attn_count = 0
|
956 |
+
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
957 |
+
if layer_type == 'a':
|
958 |
+
layer_mem = mems.pop(0) if mems else None
|
959 |
+
|
960 |
+
residual = x
|
961 |
+
|
962 |
+
pre_branch_norm, post_branch_norm, post_main_norm = norm
|
963 |
+
|
964 |
+
if exists(pre_branch_norm):
|
965 |
+
x = pre_branch_norm(x, **norm_args)
|
966 |
+
|
967 |
+
if layer_type == 'a' or layer_type == 'c':
|
968 |
+
if past_key_values is not None:
|
969 |
+
layer_kv = past_key_values.pop(0)
|
970 |
+
layer_past = tuple(s.to(x.device) for s in layer_kv)
|
971 |
+
else:
|
972 |
+
layer_past = None
|
973 |
+
|
974 |
+
if layer_type == 'a':
|
975 |
+
out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
|
976 |
+
prev_attn, layer_mem, layer_past)
|
977 |
+
elif layer_type == 'c':
|
978 |
+
if exists(full_context):
|
979 |
+
out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
|
980 |
+
None, prev_attn, None, layer_past)
|
981 |
+
else:
|
982 |
+
out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
|
983 |
+
elif layer_type == 'f':
|
984 |
+
out = checkpoint(block, x)
|
985 |
+
|
986 |
+
if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
|
987 |
+
present_key_values.append((k.detach(), v.detach()))
|
988 |
+
|
989 |
+
if exists(post_branch_norm):
|
990 |
+
out = post_branch_norm(out, **norm_args)
|
991 |
+
|
992 |
+
x = residual_fn(out, residual)
|
993 |
+
|
994 |
+
if layer_type in ('a', 'c'):
|
995 |
+
intermediates.append(inter)
|
996 |
+
|
997 |
+
if layer_type == 'a' and self.residual_attn:
|
998 |
+
prev_attn = inter.pre_softmax_attn
|
999 |
+
elif layer_type == 'c' and self.cross_residual_attn:
|
1000 |
+
prev_cross_attn = inter.pre_softmax_attn
|
1001 |
+
|
1002 |
+
if exists(post_main_norm):
|
1003 |
+
x = post_main_norm(x, **norm_args)
|
1004 |
+
|
1005 |
+
if layer_type == 'c':
|
1006 |
+
cross_attn_count += 1
|
1007 |
+
|
1008 |
+
if layer_type == 'f':
|
1009 |
+
hiddens.append(x)
|
1010 |
+
|
1011 |
+
if return_hiddens:
|
1012 |
+
intermediates = LayerIntermediates(
|
1013 |
+
hiddens=hiddens,
|
1014 |
+
attn_intermediates=intermediates,
|
1015 |
+
past_key_values=present_key_values
|
1016 |
+
)
|
1017 |
+
|
1018 |
+
return x, intermediates
|
1019 |
+
|
1020 |
+
return x
|
1021 |
+
|
1022 |
+
|
1023 |
+
class Encoder(AttentionLayers):
|
1024 |
+
def __init__(self, **kwargs):
|
1025 |
+
assert 'causal' not in kwargs, 'cannot set causality on encoder'
|
1026 |
+
super().__init__(causal=False, **kwargs)
|
1027 |
+
|
1028 |
+
|
1029 |
+
class Decoder(AttentionLayers):
|
1030 |
+
def __init__(self, **kwargs):
|
1031 |
+
assert 'causal' not in kwargs, 'cannot set causality on decoder'
|
1032 |
+
super().__init__(causal=True, **kwargs)
|
1033 |
+
|
1034 |
+
|
1035 |
+
class CrossAttender(AttentionLayers):
|
1036 |
+
def __init__(self, **kwargs):
|
1037 |
+
super().__init__(cross_attend=True, only_cross=True, **kwargs)
|
1038 |
+
|
1039 |
+
|
1040 |
+
class ViTransformerWrapper(nn.Module):
|
1041 |
+
def __init__(
|
1042 |
+
self,
|
1043 |
+
*,
|
1044 |
+
image_size,
|
1045 |
+
patch_size,
|
1046 |
+
attn_layers,
|
1047 |
+
num_classes=None,
|
1048 |
+
dropout=0.,
|
1049 |
+
emb_dropout=0.
|
1050 |
+
):
|
1051 |
+
super().__init__()
|
1052 |
+
assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
|
1053 |
+
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
1054 |
+
dim = attn_layers.dim
|
1055 |
+
num_patches = (image_size // patch_size) ** 2
|
1056 |
+
patch_dim = 3 * patch_size ** 2
|
1057 |
+
|
1058 |
+
self.patch_size = patch_size
|
1059 |
+
|
1060 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
1061 |
+
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
1062 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
1063 |
+
self.dropout = nn.Dropout(emb_dropout)
|
1064 |
+
|
1065 |
+
self.attn_layers = attn_layers
|
1066 |
+
self.norm = nn.LayerNorm(dim)
|
1067 |
+
self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
|
1068 |
+
|
1069 |
+
def forward(
|
1070 |
+
self,
|
1071 |
+
img,
|
1072 |
+
return_embeddings=False
|
1073 |
+
):
|
1074 |
+
p = self.patch_size
|
1075 |
+
|
1076 |
+
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
|
1077 |
+
x = self.patch_to_embedding(x)
|
1078 |
+
b, n, _ = x.shape
|
1079 |
+
|
1080 |
+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|
1081 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
1082 |
+
x = x + self.pos_embedding[:, :(n + 1)]
|
1083 |
+
x = self.dropout(x)
|
1084 |
+
|
1085 |
+
x = self.attn_layers(x)
|
1086 |
+
x = self.norm(x)
|
1087 |
+
|
1088 |
+
if not exists(self.mlp_head) or return_embeddings:
|
1089 |
+
return x
|
1090 |
+
|
1091 |
+
return self.mlp_head(x[:, 0])
|
1092 |
+
|
1093 |
+
|
1094 |
+
class TransformerWrapper(nn.Module):
|
1095 |
+
def __init__(
|
1096 |
+
self,
|
1097 |
+
*,
|
1098 |
+
num_tokens,
|
1099 |
+
max_seq_len,
|
1100 |
+
attn_layers,
|
1101 |
+
emb_dim=None,
|
1102 |
+
max_mem_len=0.,
|
1103 |
+
shift_mem_down=0,
|
1104 |
+
emb_dropout=0.,
|
1105 |
+
num_memory_tokens=None,
|
1106 |
+
tie_embedding=False,
|
1107 |
+
use_pos_emb=True
|
1108 |
+
):
|
1109 |
+
super().__init__()
|
1110 |
+
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
1111 |
+
|
1112 |
+
dim = attn_layers.dim
|
1113 |
+
emb_dim = default(emb_dim, dim)
|
1114 |
+
|
1115 |
+
self.max_seq_len = max_seq_len
|
1116 |
+
self.max_mem_len = max_mem_len
|
1117 |
+
self.shift_mem_down = shift_mem_down
|
1118 |
+
|
1119 |
+
self.token_emb = nn.Embedding(num_tokens, emb_dim)
|
1120 |
+
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
1121 |
+
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
1122 |
+
self.emb_dropout = nn.Dropout(emb_dropout)
|
1123 |
+
|
1124 |
+
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
1125 |
+
self.attn_layers = attn_layers
|
1126 |
+
self.norm = nn.LayerNorm(dim)
|
1127 |
+
|
1128 |
+
self.init_()
|
1129 |
+
|
1130 |
+
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
1131 |
+
|
1132 |
+
# memory tokens (like [cls]) from Memory Transformers paper
|
1133 |
+
num_memory_tokens = default(num_memory_tokens, 0)
|
1134 |
+
self.num_memory_tokens = num_memory_tokens
|
1135 |
+
if num_memory_tokens > 0:
|
1136 |
+
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
1137 |
+
|
1138 |
+
def init_(self):
|
1139 |
+
nn.init.kaiming_normal_(self.token_emb.weight)
|
1140 |
+
|
1141 |
+
def forward(
|
1142 |
+
self,
|
1143 |
+
x,
|
1144 |
+
return_embeddings=False,
|
1145 |
+
mask=None,
|
1146 |
+
return_hiddens=False,
|
1147 |
+
return_attn=False,
|
1148 |
+
mems=None,
|
1149 |
+
use_cache=False,
|
1150 |
+
**kwargs
|
1151 |
+
):
|
1152 |
+
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
|
1153 |
+
x = self.token_emb(x)
|
1154 |
+
x = x + self.pos_emb(x)
|
1155 |
+
x = self.emb_dropout(x)
|
1156 |
+
|
1157 |
+
x = self.project_emb(x)
|
1158 |
+
|
1159 |
+
if num_mem > 0:
|
1160 |
+
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
|
1161 |
+
x = torch.cat((mem, x), dim=1)
|
1162 |
+
|
1163 |
+
# auto-handle masking after appending memory tokens
|
1164 |
+
if exists(mask):
|
1165 |
+
mask = F.pad(mask, (num_mem, 0), value=True)
|
1166 |
+
|
1167 |
+
if self.shift_mem_down and exists(mems):
|
1168 |
+
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
|
1169 |
+
mems = [*mems_r, *mems_l]
|
1170 |
+
|
1171 |
+
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
1172 |
+
x = self.norm(x)
|
1173 |
+
|
1174 |
+
mem, x = x[:, :num_mem], x[:, num_mem:]
|
1175 |
+
|
1176 |
+
out = self.to_logits(x) if not return_embeddings else x
|
1177 |
+
|
1178 |
+
if return_hiddens:
|
1179 |
+
hiddens = intermediates.hiddens
|
1180 |
+
return out, hiddens
|
1181 |
+
|
1182 |
+
res = [out]
|
1183 |
+
if return_attn:
|
1184 |
+
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
1185 |
+
res.append(attn_maps)
|
1186 |
+
if use_cache:
|
1187 |
+
res.append(intermediates.past_key_values)
|
1188 |
+
|
1189 |
+
if len(res) > 1:
|
1190 |
+
return tuple(res)
|
1191 |
+
return res[0]
|
1192 |
+
|
1193 |
+
|
1194 |
+
class ContinuousTransformerWrapper(nn.Module):
|
1195 |
+
def __init__(
|
1196 |
+
self,
|
1197 |
+
*,
|
1198 |
+
max_seq_len,
|
1199 |
+
attn_layers,
|
1200 |
+
dim_in=None,
|
1201 |
+
dim_out=None,
|
1202 |
+
emb_dim=None,
|
1203 |
+
emb_dropout=0.,
|
1204 |
+
use_pos_emb=True
|
1205 |
+
):
|
1206 |
+
super().__init__()
|
1207 |
+
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
1208 |
+
|
1209 |
+
dim = attn_layers.dim
|
1210 |
+
|
1211 |
+
self.max_seq_len = max_seq_len
|
1212 |
+
|
1213 |
+
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
|
1214 |
+
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
1215 |
+
self.emb_dropout = nn.Dropout(emb_dropout)
|
1216 |
+
|
1217 |
+
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
|
1218 |
+
|
1219 |
+
self.attn_layers = attn_layers
|
1220 |
+
self.norm = nn.LayerNorm(dim)
|
1221 |
+
|
1222 |
+
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
|
1223 |
+
|
1224 |
+
def forward(
|
1225 |
+
self,
|
1226 |
+
x,
|
1227 |
+
return_embeddings=False,
|
1228 |
+
mask=None,
|
1229 |
+
return_attn=False,
|
1230 |
+
mems=None,
|
1231 |
+
use_cache=False,
|
1232 |
+
**kwargs
|
1233 |
+
):
|
1234 |
+
b, n, _, device = *x.shape, x.device
|
1235 |
+
|
1236 |
+
x = self.project_in(x)
|
1237 |
+
x = x + self.pos_emb(x)
|
1238 |
+
x = self.emb_dropout(x)
|
1239 |
+
|
1240 |
+
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
1241 |
+
x = self.norm(x)
|
1242 |
+
|
1243 |
+
out = self.project_out(x) if not return_embeddings else x
|
1244 |
+
|
1245 |
+
res = [out]
|
1246 |
+
if return_attn:
|
1247 |
+
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
1248 |
+
res.append(attn_maps)
|
1249 |
+
if use_cache:
|
1250 |
+
res.append(intermediates.past_key_values)
|
1251 |
+
|
1252 |
+
if len(res) > 1:
|
1253 |
+
return tuple(res)
|
1254 |
+
return res[0]
|
1255 |
+
|
1256 |
+
|
1257 |
+
class XTransformer(nn.Module):
|
1258 |
+
def __init__(
|
1259 |
+
self,
|
1260 |
+
*,
|
1261 |
+
dim,
|
1262 |
+
tie_token_emb=False,
|
1263 |
+
**kwargs
|
1264 |
+
):
|
1265 |
+
super().__init__()
|
1266 |
+
enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
|
1267 |
+
dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
|
1268 |
+
|
1269 |
+
assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
|
1270 |
+
enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
|
1271 |
+
enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
|
1272 |
+
enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
|
1273 |
+
enc_transformer_kwargs['use_pos_emb'] = enc_kwargs.pop('use_pos_emb', True)
|
1274 |
+
|
1275 |
+
dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
|
1276 |
+
dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
|
1277 |
+
dec_transformer_kwargs['use_pos_emb'] = dec_kwargs.pop('use_pos_emb', True)
|
1278 |
+
|
1279 |
+
self.encoder = TransformerWrapper(
|
1280 |
+
**enc_transformer_kwargs,
|
1281 |
+
attn_layers=Encoder(dim=dim, **enc_kwargs)
|
1282 |
+
)
|
1283 |
+
|
1284 |
+
self.decoder = TransformerWrapper(
|
1285 |
+
**dec_transformer_kwargs,
|
1286 |
+
attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
|
1287 |
+
)
|
1288 |
+
|
1289 |
+
if tie_token_emb:
|
1290 |
+
self.decoder.token_emb = self.encoder.token_emb
|
1291 |
+
|
1292 |
+
self.decoder = AutoregressiveWrapper(self.decoder)
|
1293 |
+
|
1294 |
+
@torch.no_grad()
|
1295 |
+
def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
|
1296 |
+
encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
|
1297 |
+
return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
|
1298 |
+
|
1299 |
+
def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
|
1300 |
+
enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
|
1301 |
+
out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
|
1302 |
+
return out
|
read.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torchaudio
|
7 |
+
|
8 |
+
from api import TextToSpeech, load_conditioning
|
9 |
+
from utils.audio import load_audio, get_voices
|
10 |
+
from utils.tokenizer import VoiceBpeTokenizer
|
11 |
+
|
12 |
+
def split_and_recombine_text(texts, desired_length=200, max_len=300):
|
13 |
+
# TODO: also split across '!' and '?'. Attempt to keep quotations together.
|
14 |
+
texts = [s.strip() + "." for s in texts.split('.')]
|
15 |
+
|
16 |
+
i = 0
|
17 |
+
while i < len(texts):
|
18 |
+
ltxt = texts[i]
|
19 |
+
if len(ltxt) >= desired_length or i == len(texts)-1:
|
20 |
+
i += 1
|
21 |
+
continue
|
22 |
+
if len(ltxt) + len(texts[i+1]) > max_len:
|
23 |
+
i += 1
|
24 |
+
continue
|
25 |
+
texts[i] = f'{ltxt} {texts[i+1]}'
|
26 |
+
texts.pop(i+1)
|
27 |
+
return texts
|
28 |
+
|
29 |
+
if __name__ == '__main__':
|
30 |
+
parser = argparse.ArgumentParser()
|
31 |
+
parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood2.txt")
|
32 |
+
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
|
33 |
+
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='patrick_stewart')
|
34 |
+
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')
|
35 |
+
parser.add_argument('--generation_preset', type=str, help='Preset to use for generation', default='standard')
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
outpath = args.output_path
|
39 |
+
voices = get_voices()
|
40 |
+
selected_voices = args.voice.split(',')
|
41 |
+
for selected_voice in selected_voices:
|
42 |
+
voice_outpath = os.path.join(outpath, selected_voice)
|
43 |
+
os.makedirs(voice_outpath, exist_ok=True)
|
44 |
+
|
45 |
+
with open(args.textfile, 'r', encoding='utf-8') as f:
|
46 |
+
text = ''.join([l for l in f.readlines()])
|
47 |
+
texts = split_and_recombine_text(text)
|
48 |
+
tts = TextToSpeech()
|
49 |
+
|
50 |
+
if '&' in selected_voice:
|
51 |
+
voice_sel = selected_voice.split('&')
|
52 |
+
else:
|
53 |
+
voice_sel = [selected_voice]
|
54 |
+
cond_paths = []
|
55 |
+
for vsel in voice_sel:
|
56 |
+
if vsel not in voices.keys():
|
57 |
+
print(f'Error: voice {vsel} not available. Skipping.')
|
58 |
+
continue
|
59 |
+
cond_paths.extend(voices[vsel])
|
60 |
+
if not cond_paths:
|
61 |
+
print('Error: no valid voices specified. Try again.')
|
62 |
+
|
63 |
+
priors = []
|
64 |
+
for j, text in enumerate(texts):
|
65 |
+
conds = priors.copy()
|
66 |
+
for cond_path in cond_paths:
|
67 |
+
c = load_audio(cond_path, 22050)
|
68 |
+
conds.append(c)
|
69 |
+
gen = tts.tts_with_preset(text, conds, preset=args.generation_preset)
|
70 |
+
torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen.squeeze(0).cpu(), 24000)
|
71 |
+
|
72 |
+
priors.append(torchaudio.functional.resample(gen, 24000, 22050).squeeze(0))
|
73 |
+
while len(priors) > 2:
|
74 |
+
priors.pop(0)
|
75 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchaudio
|
3 |
+
rotary_embedding_torch
|
4 |
+
transformers
|
5 |
+
tokenizers
|
6 |
+
inflect
|
7 |
+
progressbar
|
8 |
+
einops
|
9 |
+
unidecode
|
10 |
+
x-transformers
|
sweep.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from random import shuffle
|
3 |
+
|
4 |
+
import torchaudio
|
5 |
+
|
6 |
+
from api import TextToSpeech
|
7 |
+
from utils.audio import load_audio
|
8 |
+
|
9 |
+
|
10 |
+
def permutations(args):
|
11 |
+
res = []
|
12 |
+
k = next(iter(args.keys()))
|
13 |
+
vals = args[k]
|
14 |
+
del args[k]
|
15 |
+
if not args:
|
16 |
+
return [{k: v} for v in vals]
|
17 |
+
lower = permutations(args)
|
18 |
+
for v in vals:
|
19 |
+
for l in lower:
|
20 |
+
lc = l.copy()
|
21 |
+
lc[k] = v
|
22 |
+
res.append(lc)
|
23 |
+
return res
|
24 |
+
|
25 |
+
|
26 |
+
if __name__ == '__main__':
|
27 |
+
fname = 'Y:\\clips\\books2\\subset512-oco.tsv'
|
28 |
+
stop_after = 512
|
29 |
+
outpath_base = 'D:\\tmp\\tortoise-tts-eval\\sweep-2'
|
30 |
+
outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real'
|
31 |
+
|
32 |
+
arg_ranges = {
|
33 |
+
'top_p': [.8,1],
|
34 |
+
'temperature': [.8,.9,1],
|
35 |
+
'diffusion_temperature': [.8,1],
|
36 |
+
'cond_free_k': [1,2,5,10],
|
37 |
+
}
|
38 |
+
cfgs = permutations(arg_ranges)
|
39 |
+
shuffle(cfgs)
|
40 |
+
|
41 |
+
for cfg in cfgs:
|
42 |
+
cfg_desc = '_'.join([f'{k}-{v}' for k,v in cfg.items()])
|
43 |
+
outpath = os.path.join(outpath_base, f'{cfg_desc}')
|
44 |
+
os.makedirs(outpath, exist_ok=True)
|
45 |
+
os.makedirs(outpath_real, exist_ok=True)
|
46 |
+
with open(fname, 'r', encoding='utf-8') as f:
|
47 |
+
lines = [l.strip().split('\t') for l in f.readlines()]
|
48 |
+
|
49 |
+
recorder = open(os.path.join(outpath, 'transcript.tsv'), 'w', encoding='utf-8')
|
50 |
+
tts = TextToSpeech()
|
51 |
+
for e, line in enumerate(lines):
|
52 |
+
if e >= stop_after:
|
53 |
+
break
|
54 |
+
transcript = line[0]
|
55 |
+
path = os.path.join(os.path.dirname(fname), line[1])
|
56 |
+
cond_audio = load_audio(path, 22050)
|
57 |
+
torchaudio.save(os.path.join(outpath_real, os.path.basename(line[1])), cond_audio, 22050)
|
58 |
+
sample = tts.tts(transcript, [cond_audio, cond_audio], num_autoregressive_samples=32, repetition_penalty=2.0,
|
59 |
+
k=1, diffusion_iterations=32, length_penalty=1.0, **cfg)
|
60 |
+
down = torchaudio.functional.resample(sample, 24000, 22050)
|
61 |
+
fout_path = os.path.join(outpath, os.path.basename(line[1]))
|
62 |
+
torchaudio.save(fout_path, down.squeeze(0), 22050)
|
63 |
+
recorder.write(f'{transcript}\t{fout_path}\n')
|
64 |
+
recorder.flush()
|
65 |
+
recorder.close()
|
tortoise_tts.ipynb
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"name": "tortoise-tts.ipynb",
|
7 |
+
"provenance": [],
|
8 |
+
"collapsed_sections": []
|
9 |
+
},
|
10 |
+
"kernelspec": {
|
11 |
+
"name": "python3",
|
12 |
+
"display_name": "Python 3"
|
13 |
+
},
|
14 |
+
"language_info": {
|
15 |
+
"name": "python"
|
16 |
+
},
|
17 |
+
"accelerator": "GPU"
|
18 |
+
},
|
19 |
+
"cells": [
|
20 |
+
{
|
21 |
+
"cell_type": "code",
|
22 |
+
"execution_count": null,
|
23 |
+
"metadata": {
|
24 |
+
"id": "JrK20I32grP6"
|
25 |
+
},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"!git clone https://github.com/neonbjb/tortoise-tts.git\n",
|
29 |
+
"%cd tortoise-tts\n",
|
30 |
+
"!pip install -r requirements.txt"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"source": [
|
36 |
+
"# Imports used through the rest of the notebook.\n",
|
37 |
+
"import torch\n",
|
38 |
+
"import torchaudio\n",
|
39 |
+
"import torch.nn as nn\n",
|
40 |
+
"import torch.nn.functional as F\n",
|
41 |
+
"from tqdm import tqdm\n",
|
42 |
+
"\n",
|
43 |
+
"from utils.tokenizer import VoiceBpeTokenizer\n",
|
44 |
+
"from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder\n",
|
45 |
+
"from models.text_voice_clip import VoiceCLIP\n",
|
46 |
+
"from models.dvae import DiscreteVAE\n",
|
47 |
+
"from models.autoregressive import UnifiedVoice\n",
|
48 |
+
"\n",
|
49 |
+
"# These have some fairly interesting code that is hidden in the colab. Consider checking it out.\n",
|
50 |
+
"from do_tts import download_models, load_discrete_vocoder_diffuser, load_conditioning, fix_autoregressive_output, do_spectrogram_diffusion"
|
51 |
+
],
|
52 |
+
"metadata": {
|
53 |
+
"id": "Gen09NM4hONQ"
|
54 |
+
},
|
55 |
+
"execution_count": null,
|
56 |
+
"outputs": []
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"source": [
|
61 |
+
"# Download pretrained models and set up pretrained voice bank. Feel free to upload and add your own voices here.\n",
|
62 |
+
"# To do so, upload two WAV files cropped to 5-10 seconds of someone speaking.\n",
|
63 |
+
"download_models()\n",
|
64 |
+
"preselected_cond_voices = {\n",
|
65 |
+
" # Male voices\n",
|
66 |
+
" 'dotrice': ['voices/dotrice/1.wav', 'voices/dotrice/2.wav'],\n",
|
67 |
+
" 'harris': ['voices/harris/1.wav', 'voices/harris/2.wav'],\n",
|
68 |
+
" 'lescault': ['voices/lescault/1.wav', 'voices/lescault/2.wav'],\n",
|
69 |
+
" 'otto': ['voices/otto/1.wav', 'voices/otto/2.wav'],\n",
|
70 |
+
" # Female voices\n",
|
71 |
+
" 'atkins': ['voices/atkins/1.wav', 'voices/atkins/2.wav'],\n",
|
72 |
+
" 'grace': ['voices/grace/1.wav', 'voices/grace/2.wav'],\n",
|
73 |
+
" 'kennard': ['voices/kennard/1.wav', 'voices/kennard/2.wav'],\n",
|
74 |
+
" 'mol': ['voices/mol/1.wav', 'voices/mol/2.wav'],\n",
|
75 |
+
" }"
|
76 |
+
],
|
77 |
+
"metadata": {
|
78 |
+
"id": "SSleVnRAiEE2"
|
79 |
+
},
|
80 |
+
"execution_count": null,
|
81 |
+
"outputs": []
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"source": [
|
86 |
+
"# This is the text that will be spoken.\n",
|
87 |
+
"text = \"And took the other as just as fair, and having perhaps the better claim, because it was grassy and wanted wear.\"\n",
|
88 |
+
"# This is the voice that will speak it.\n",
|
89 |
+
"voice = 'atkins'\n",
|
90 |
+
"# This is the number of samples we will generate from the DALLE-style model. More will produce better results, but will take longer to produce.\n",
|
91 |
+
"# I don't recommend going less than 128.\n",
|
92 |
+
"num_autoregressive_samples = 128"
|
93 |
+
],
|
94 |
+
"metadata": {
|
95 |
+
"id": "bt_aoxONjfL2"
|
96 |
+
},
|
97 |
+
"execution_count": null,
|
98 |
+
"outputs": []
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"source": [
|
103 |
+
"# Prepare data.\n",
|
104 |
+
"tokenizer = VoiceBpeTokenizer()\n",
|
105 |
+
"text = torch.IntTensor(tokenizer.encode(text)).unsqueeze(0).cuda()\n",
|
106 |
+
"text = F.pad(text, (0,1)) # This may not be necessary.\n",
|
107 |
+
"cond_paths = preselected_cond_voices[voice]\n",
|
108 |
+
"conds = []\n",
|
109 |
+
"for cond_path in cond_paths:\n",
|
110 |
+
" c, cond_wav = load_conditioning(cond_path)\n",
|
111 |
+
" conds.append(c)\n",
|
112 |
+
"conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model."
|
113 |
+
],
|
114 |
+
"metadata": {
|
115 |
+
"id": "KEXOKjIvn6NW"
|
116 |
+
},
|
117 |
+
"execution_count": null,
|
118 |
+
"outputs": []
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"source": [
|
123 |
+
"# Load the autoregressive model.\n",
|
124 |
+
"autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,\n",
|
125 |
+
" heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()\n",
|
126 |
+
"autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))\n",
|
127 |
+
"stop_mel_token = autoregressive.stop_mel_token"
|
128 |
+
],
|
129 |
+
"metadata": {
|
130 |
+
"id": "Z15xFT_uhP8v"
|
131 |
+
},
|
132 |
+
"execution_count": null,
|
133 |
+
"outputs": []
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"source": [
|
138 |
+
"# Perform inference with the autoregressive model, generating num_autoregressive_samples\n",
|
139 |
+
"with torch.no_grad():\n",
|
140 |
+
" samples = []\n",
|
141 |
+
" for b in tqdm(range(num_autoregressive_samples // 16)):\n",
|
142 |
+
" codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,\n",
|
143 |
+
" temperature=.9, num_return_sequences=16, length_penalty=1)\n",
|
144 |
+
" padding_needed = 250 - codes.shape[1]\n",
|
145 |
+
" codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)\n",
|
146 |
+
" samples.append(codes)\n",
|
147 |
+
"\n",
|
148 |
+
"# Delete model weights to conserve memory.\n",
|
149 |
+
"del autoregressive"
|
150 |
+
],
|
151 |
+
"metadata": {
|
152 |
+
"id": "xajqWiEik-j0"
|
153 |
+
},
|
154 |
+
"execution_count": null,
|
155 |
+
"outputs": []
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"cell_type": "code",
|
159 |
+
"source": [
|
160 |
+
"# Load the CLIP model.\n",
|
161 |
+
"clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,\n",
|
162 |
+
" num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval()\n",
|
163 |
+
"clip.load_state_dict(torch.load('.models/clip.pth'))"
|
164 |
+
],
|
165 |
+
"metadata": {
|
166 |
+
"id": "KNgYSyuyliMs"
|
167 |
+
},
|
168 |
+
"execution_count": null,
|
169 |
+
"outputs": []
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"source": [
|
174 |
+
"# Use the CLIP model to select the best autoregressive output to match the given text.\n",
|
175 |
+
"clip_results = []\n",
|
176 |
+
"with torch.no_grad():\n",
|
177 |
+
" for batch in samples:\n",
|
178 |
+
" for i in range(batch.shape[0]):\n",
|
179 |
+
" batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)\n",
|
180 |
+
" text = text[:, :120] # Ugly hack to fix the fact that I didn't train CLIP to handle long enough text.\n",
|
181 |
+
" clip_results.append(clip(text.repeat(batch.shape[0], 1),\n",
|
182 |
+
" torch.full((batch.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),\n",
|
183 |
+
" batch, torch.full((batch.shape[0],), fill_value=batch.shape[1]*1024, dtype=torch.long, device='cuda'),\n",
|
184 |
+
" return_loss=False))\n",
|
185 |
+
" clip_results = torch.cat(clip_results, dim=0)\n",
|
186 |
+
" samples = torch.cat(samples, dim=0)\n",
|
187 |
+
" best_results = samples[torch.topk(clip_results, k=1).indices]\n",
|
188 |
+
"\n",
|
189 |
+
"# Save samples to CPU memory, delete clip to conserve memory.\n",
|
190 |
+
"samples = samples.cpu()\n",
|
191 |
+
"del clip"
|
192 |
+
],
|
193 |
+
"metadata": {
|
194 |
+
"id": "DDXkM0lclp4U"
|
195 |
+
},
|
196 |
+
"execution_count": null,
|
197 |
+
"outputs": []
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"cell_type": "code",
|
201 |
+
"source": [
|
202 |
+
"# Load the DVAE and diffusion model.\n",
|
203 |
+
"dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,\n",
|
204 |
+
" record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()\n",
|
205 |
+
"dvae.load_state_dict(torch.load('.models/dvae.pth'), strict=False)\n",
|
206 |
+
"diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],\n",
|
207 |
+
" spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,\n",
|
208 |
+
" conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()\n",
|
209 |
+
"diffusion.load_state_dict(torch.load('.models/diffusion.pth'))\n",
|
210 |
+
"diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)"
|
211 |
+
],
|
212 |
+
"metadata": {
|
213 |
+
"id": "97acSnBal8Q2"
|
214 |
+
},
|
215 |
+
"execution_count": null,
|
216 |
+
"outputs": []
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"cell_type": "code",
|
220 |
+
"source": [
|
221 |
+
"# Decode the (best) discrete sequence created by the autoregressive model.\n",
|
222 |
+
"with torch.no_grad():\n",
|
223 |
+
" for b in range(best_results.shape[0]):\n",
|
224 |
+
" code = best_results[b].unsqueeze(0)\n",
|
225 |
+
" wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256, mean=True)\n",
|
226 |
+
" torchaudio.save(f'{voice}_{b}.wav', wav.squeeze(0).cpu(), 22050)"
|
227 |
+
],
|
228 |
+
"metadata": {
|
229 |
+
"id": "HEDABTrdl_kM"
|
230 |
+
},
|
231 |
+
"execution_count": null,
|
232 |
+
"outputs": []
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"source": [
|
237 |
+
"# Listen to your text! (told you that'd take a long time..)\n",
|
238 |
+
"from IPython.display import Audio\n",
|
239 |
+
"Audio(data=wav.squeeze(0).cpu().numpy(), rate=22050)"
|
240 |
+
],
|
241 |
+
"metadata": {
|
242 |
+
"id": "EyHmcdqBmSvf"
|
243 |
+
},
|
244 |
+
"execution_count": null,
|
245 |
+
"outputs": []
|
246 |
+
}
|
247 |
+
]
|
248 |
+
}
|
utils/__init__.py
ADDED
File without changes
|
utils/audio.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import numpy as np
|
7 |
+
from scipy.io.wavfile import read
|
8 |
+
|
9 |
+
from utils.stft import STFT
|
10 |
+
|
11 |
+
|
12 |
+
def load_wav_to_torch(full_path):
|
13 |
+
sampling_rate, data = read(full_path)
|
14 |
+
if data.dtype == np.int32:
|
15 |
+
norm_fix = 2 ** 31
|
16 |
+
elif data.dtype == np.int16:
|
17 |
+
norm_fix = 2 ** 15
|
18 |
+
elif data.dtype == np.float16 or data.dtype == np.float32:
|
19 |
+
norm_fix = 1.
|
20 |
+
else:
|
21 |
+
raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
|
22 |
+
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
|
23 |
+
|
24 |
+
|
25 |
+
def load_audio(audiopath, sampling_rate):
|
26 |
+
if audiopath[-4:] == '.wav':
|
27 |
+
audio, lsr = load_wav_to_torch(audiopath)
|
28 |
+
elif audiopath[-4:] == '.mp3':
|
29 |
+
# https://github.com/neonbjb/pyfastmp3decoder - Definitely worth it.
|
30 |
+
from pyfastmp3decoder.mp3decoder import load_mp3
|
31 |
+
audio, lsr = load_mp3(audiopath, sampling_rate)
|
32 |
+
audio = torch.FloatTensor(audio)
|
33 |
+
|
34 |
+
# Remove any channel data.
|
35 |
+
if len(audio.shape) > 1:
|
36 |
+
if audio.shape[0] < 5:
|
37 |
+
audio = audio[0]
|
38 |
+
else:
|
39 |
+
assert audio.shape[1] < 5
|
40 |
+
audio = audio[:, 0]
|
41 |
+
|
42 |
+
if lsr != sampling_rate:
|
43 |
+
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
|
44 |
+
|
45 |
+
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
|
46 |
+
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
|
47 |
+
if torch.any(audio > 2) or not torch.any(audio < 0):
|
48 |
+
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
|
49 |
+
audio.clip_(-1, 1)
|
50 |
+
|
51 |
+
return audio.unsqueeze(0)
|
52 |
+
|
53 |
+
|
54 |
+
TACOTRON_MEL_MAX = 2.3143386840820312
|
55 |
+
TACOTRON_MEL_MIN = -11.512925148010254
|
56 |
+
|
57 |
+
|
58 |
+
def denormalize_tacotron_mel(norm_mel):
|
59 |
+
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
|
60 |
+
|
61 |
+
|
62 |
+
def normalize_tacotron_mel(mel):
|
63 |
+
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
64 |
+
|
65 |
+
|
66 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
67 |
+
"""
|
68 |
+
PARAMS
|
69 |
+
------
|
70 |
+
C: compression factor
|
71 |
+
"""
|
72 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
73 |
+
|
74 |
+
|
75 |
+
def dynamic_range_decompression(x, C=1):
|
76 |
+
"""
|
77 |
+
PARAMS
|
78 |
+
------
|
79 |
+
C: compression factor used to compress
|
80 |
+
"""
|
81 |
+
return torch.exp(x) / C
|
82 |
+
|
83 |
+
|
84 |
+
def get_voices():
|
85 |
+
subs = os.listdir('voices')
|
86 |
+
voices = {}
|
87 |
+
for sub in subs:
|
88 |
+
subj = os.path.join('voices', sub)
|
89 |
+
if os.path.isdir(subj):
|
90 |
+
voices[sub] = glob(f'{subj}/*.wav')
|
91 |
+
return voices
|
92 |
+
|
93 |
+
|
94 |
+
class TacotronSTFT(torch.nn.Module):
|
95 |
+
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
|
96 |
+
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
|
97 |
+
mel_fmax=8000.0):
|
98 |
+
super(TacotronSTFT, self).__init__()
|
99 |
+
self.n_mel_channels = n_mel_channels
|
100 |
+
self.sampling_rate = sampling_rate
|
101 |
+
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
102 |
+
from librosa.filters import mel as librosa_mel_fn
|
103 |
+
mel_basis = librosa_mel_fn(
|
104 |
+
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
|
105 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
106 |
+
self.register_buffer('mel_basis', mel_basis)
|
107 |
+
|
108 |
+
def spectral_normalize(self, magnitudes):
|
109 |
+
output = dynamic_range_compression(magnitudes)
|
110 |
+
return output
|
111 |
+
|
112 |
+
def spectral_de_normalize(self, magnitudes):
|
113 |
+
output = dynamic_range_decompression(magnitudes)
|
114 |
+
return output
|
115 |
+
|
116 |
+
def mel_spectrogram(self, y):
|
117 |
+
"""Computes mel-spectrograms from a batch of waves
|
118 |
+
PARAMS
|
119 |
+
------
|
120 |
+
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
121 |
+
|
122 |
+
RETURNS
|
123 |
+
-------
|
124 |
+
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
125 |
+
"""
|
126 |
+
assert(torch.min(y.data) >= -10)
|
127 |
+
assert(torch.max(y.data) <= 10)
|
128 |
+
y = torch.clip(y, min=-1, max=1)
|
129 |
+
|
130 |
+
magnitudes, phases = self.stft_fn.transform(y)
|
131 |
+
magnitudes = magnitudes.data
|
132 |
+
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
133 |
+
mel_output = self.spectral_normalize(mel_output)
|
134 |
+
return mel_output
|
135 |
+
|
136 |
+
|
137 |
+
def wav_to_univnet_mel(wav, do_normalization=False):
|
138 |
+
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
|
139 |
+
stft = stft.cuda()
|
140 |
+
mel = stft.mel_spectrogram(wav)
|
141 |
+
if do_normalization:
|
142 |
+
mel = normalize_tacotron_mel(mel)
|
143 |
+
return mel
|
utils/diffusion.py
ADDED
@@ -0,0 +1,1250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself:
|
3 |
+
|
4 |
+
This code started out as a PyTorch port of Ho et al's diffusion models:
|
5 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
|
6 |
+
|
7 |
+
Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import enum
|
11 |
+
import math
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch as th
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
20 |
+
"""
|
21 |
+
Compute the KL divergence between two gaussians.
|
22 |
+
|
23 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
24 |
+
scalars, among other use cases.
|
25 |
+
"""
|
26 |
+
tensor = None
|
27 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
28 |
+
if isinstance(obj, th.Tensor):
|
29 |
+
tensor = obj
|
30 |
+
break
|
31 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
32 |
+
|
33 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
34 |
+
# Tensors, but it does not work for th.exp().
|
35 |
+
logvar1, logvar2 = [
|
36 |
+
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
|
37 |
+
for x in (logvar1, logvar2)
|
38 |
+
]
|
39 |
+
|
40 |
+
return 0.5 * (
|
41 |
+
-1.0
|
42 |
+
+ logvar2
|
43 |
+
- logvar1
|
44 |
+
+ th.exp(logvar1 - logvar2)
|
45 |
+
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
def approx_standard_normal_cdf(x):
|
50 |
+
"""
|
51 |
+
A fast approximation of the cumulative distribution function of the
|
52 |
+
standard normal.
|
53 |
+
"""
|
54 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
55 |
+
|
56 |
+
|
57 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
58 |
+
"""
|
59 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
60 |
+
given image.
|
61 |
+
|
62 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
63 |
+
rescaled to the range [-1, 1].
|
64 |
+
:param means: the Gaussian mean Tensor.
|
65 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
66 |
+
:return: a tensor like x of log probabilities (in nats).
|
67 |
+
"""
|
68 |
+
assert x.shape == means.shape == log_scales.shape
|
69 |
+
centered_x = x - means
|
70 |
+
inv_stdv = th.exp(-log_scales)
|
71 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
72 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
73 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
74 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
75 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
76 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
77 |
+
cdf_delta = cdf_plus - cdf_min
|
78 |
+
log_probs = th.where(
|
79 |
+
x < -0.999,
|
80 |
+
log_cdf_plus,
|
81 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
82 |
+
)
|
83 |
+
assert log_probs.shape == x.shape
|
84 |
+
return log_probs
|
85 |
+
|
86 |
+
|
87 |
+
def mean_flat(tensor):
|
88 |
+
"""
|
89 |
+
Take the mean over all non-batch dimensions.
|
90 |
+
"""
|
91 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
92 |
+
|
93 |
+
|
94 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
95 |
+
"""
|
96 |
+
Get a pre-defined beta schedule for the given name.
|
97 |
+
|
98 |
+
The beta schedule library consists of beta schedules which remain similar
|
99 |
+
in the limit of num_diffusion_timesteps.
|
100 |
+
Beta schedules may be added, but should not be removed or changed once
|
101 |
+
they are committed to maintain backwards compatibility.
|
102 |
+
"""
|
103 |
+
if schedule_name == "linear":
|
104 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
105 |
+
# diffusion steps.
|
106 |
+
scale = 1000 / num_diffusion_timesteps
|
107 |
+
beta_start = scale * 0.0001
|
108 |
+
beta_end = scale * 0.02
|
109 |
+
return np.linspace(
|
110 |
+
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
|
111 |
+
)
|
112 |
+
elif schedule_name == "cosine":
|
113 |
+
return betas_for_alpha_bar(
|
114 |
+
num_diffusion_timesteps,
|
115 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
119 |
+
|
120 |
+
|
121 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
122 |
+
"""
|
123 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
124 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
125 |
+
|
126 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
127 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
128 |
+
produces the cumulative product of (1-beta) up to that
|
129 |
+
part of the diffusion process.
|
130 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
131 |
+
prevent singularities.
|
132 |
+
"""
|
133 |
+
betas = []
|
134 |
+
for i in range(num_diffusion_timesteps):
|
135 |
+
t1 = i / num_diffusion_timesteps
|
136 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
137 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
138 |
+
return np.array(betas)
|
139 |
+
|
140 |
+
|
141 |
+
class ModelMeanType(enum.Enum):
|
142 |
+
"""
|
143 |
+
Which type of output the model predicts.
|
144 |
+
"""
|
145 |
+
|
146 |
+
PREVIOUS_X = 'previous_x' # the model predicts x_{t-1}
|
147 |
+
START_X = 'start_x' # the model predicts x_0
|
148 |
+
EPSILON = 'epsilon' # the model predicts epsilon
|
149 |
+
|
150 |
+
|
151 |
+
class ModelVarType(enum.Enum):
|
152 |
+
"""
|
153 |
+
What is used as the model's output variance.
|
154 |
+
|
155 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
156 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
157 |
+
"""
|
158 |
+
|
159 |
+
LEARNED = 'learned'
|
160 |
+
FIXED_SMALL = 'fixed_small'
|
161 |
+
FIXED_LARGE = 'fixed_large'
|
162 |
+
LEARNED_RANGE = 'learned_range'
|
163 |
+
|
164 |
+
|
165 |
+
class LossType(enum.Enum):
|
166 |
+
MSE = 'mse' # use raw MSE loss (and KL when learning variances)
|
167 |
+
RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances)
|
168 |
+
KL = 'kl' # use the variational lower-bound
|
169 |
+
RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB
|
170 |
+
|
171 |
+
def is_vb(self):
|
172 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
173 |
+
|
174 |
+
|
175 |
+
class GaussianDiffusion:
|
176 |
+
"""
|
177 |
+
Utilities for training and sampling diffusion models.
|
178 |
+
|
179 |
+
Ported directly from here, and then adapted over time to further experimentation.
|
180 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
181 |
+
|
182 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
183 |
+
starting at T and going to 1.
|
184 |
+
:param model_mean_type: a ModelMeanType determining what the model outputs.
|
185 |
+
:param model_var_type: a ModelVarType determining how variance is output.
|
186 |
+
:param loss_type: a LossType determining the loss function to use.
|
187 |
+
:param rescale_timesteps: if True, pass floating point timesteps into the
|
188 |
+
model so that they are always scaled like in the
|
189 |
+
original paper (0 to 1000).
|
190 |
+
"""
|
191 |
+
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
*,
|
195 |
+
betas,
|
196 |
+
model_mean_type,
|
197 |
+
model_var_type,
|
198 |
+
loss_type,
|
199 |
+
rescale_timesteps=False,
|
200 |
+
conditioning_free=False,
|
201 |
+
conditioning_free_k=1,
|
202 |
+
ramp_conditioning_free=True,
|
203 |
+
):
|
204 |
+
self.model_mean_type = ModelMeanType(model_mean_type)
|
205 |
+
self.model_var_type = ModelVarType(model_var_type)
|
206 |
+
self.loss_type = LossType(loss_type)
|
207 |
+
self.rescale_timesteps = rescale_timesteps
|
208 |
+
self.conditioning_free = conditioning_free
|
209 |
+
self.conditioning_free_k = conditioning_free_k
|
210 |
+
self.ramp_conditioning_free = ramp_conditioning_free
|
211 |
+
|
212 |
+
# Use float64 for accuracy.
|
213 |
+
betas = np.array(betas, dtype=np.float64)
|
214 |
+
self.betas = betas
|
215 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
216 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
217 |
+
|
218 |
+
self.num_timesteps = int(betas.shape[0])
|
219 |
+
|
220 |
+
alphas = 1.0 - betas
|
221 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
222 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
223 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
224 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
225 |
+
|
226 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
227 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
228 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
229 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
230 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
231 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
232 |
+
|
233 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
234 |
+
self.posterior_variance = (
|
235 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
236 |
+
)
|
237 |
+
# log calculation clipped because the posterior variance is 0 at the
|
238 |
+
# beginning of the diffusion chain.
|
239 |
+
self.posterior_log_variance_clipped = np.log(
|
240 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
241 |
+
)
|
242 |
+
self.posterior_mean_coef1 = (
|
243 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
244 |
+
)
|
245 |
+
self.posterior_mean_coef2 = (
|
246 |
+
(1.0 - self.alphas_cumprod_prev)
|
247 |
+
* np.sqrt(alphas)
|
248 |
+
/ (1.0 - self.alphas_cumprod)
|
249 |
+
)
|
250 |
+
|
251 |
+
def q_mean_variance(self, x_start, t):
|
252 |
+
"""
|
253 |
+
Get the distribution q(x_t | x_0).
|
254 |
+
|
255 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
256 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
257 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
258 |
+
"""
|
259 |
+
mean = (
|
260 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
261 |
+
)
|
262 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
263 |
+
log_variance = _extract_into_tensor(
|
264 |
+
self.log_one_minus_alphas_cumprod, t, x_start.shape
|
265 |
+
)
|
266 |
+
return mean, variance, log_variance
|
267 |
+
|
268 |
+
def q_sample(self, x_start, t, noise=None):
|
269 |
+
"""
|
270 |
+
Diffuse the data for a given number of diffusion steps.
|
271 |
+
|
272 |
+
In other words, sample from q(x_t | x_0).
|
273 |
+
|
274 |
+
:param x_start: the initial data batch.
|
275 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
276 |
+
:param noise: if specified, the split-out normal noise.
|
277 |
+
:return: A noisy version of x_start.
|
278 |
+
"""
|
279 |
+
if noise is None:
|
280 |
+
noise = th.randn_like(x_start)
|
281 |
+
assert noise.shape == x_start.shape
|
282 |
+
return (
|
283 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
284 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
285 |
+
* noise
|
286 |
+
)
|
287 |
+
|
288 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
289 |
+
"""
|
290 |
+
Compute the mean and variance of the diffusion posterior:
|
291 |
+
|
292 |
+
q(x_{t-1} | x_t, x_0)
|
293 |
+
|
294 |
+
"""
|
295 |
+
assert x_start.shape == x_t.shape
|
296 |
+
posterior_mean = (
|
297 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
298 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
299 |
+
)
|
300 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
301 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
302 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
303 |
+
)
|
304 |
+
assert (
|
305 |
+
posterior_mean.shape[0]
|
306 |
+
== posterior_variance.shape[0]
|
307 |
+
== posterior_log_variance_clipped.shape[0]
|
308 |
+
== x_start.shape[0]
|
309 |
+
)
|
310 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
311 |
+
|
312 |
+
def p_mean_variance(
|
313 |
+
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
|
314 |
+
):
|
315 |
+
"""
|
316 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
317 |
+
the initial x, x_0.
|
318 |
+
|
319 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
320 |
+
as input.
|
321 |
+
:param x: the [N x C x ...] tensor at time t.
|
322 |
+
:param t: a 1-D Tensor of timesteps.
|
323 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
324 |
+
:param denoised_fn: if not None, a function which applies to the
|
325 |
+
x_start prediction before it is used to sample. Applies before
|
326 |
+
clip_denoised.
|
327 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
328 |
+
pass to the model. This can be used for conditioning.
|
329 |
+
:return: a dict with the following keys:
|
330 |
+
- 'mean': the model mean output.
|
331 |
+
- 'variance': the model variance output.
|
332 |
+
- 'log_variance': the log of 'variance'.
|
333 |
+
- 'pred_xstart': the prediction for x_0.
|
334 |
+
"""
|
335 |
+
if model_kwargs is None:
|
336 |
+
model_kwargs = {}
|
337 |
+
|
338 |
+
B, C = x.shape[:2]
|
339 |
+
assert t.shape == (B,)
|
340 |
+
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
|
341 |
+
if self.conditioning_free:
|
342 |
+
model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
|
343 |
+
|
344 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
345 |
+
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
346 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
347 |
+
if self.conditioning_free:
|
348 |
+
model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
|
349 |
+
if self.model_var_type == ModelVarType.LEARNED:
|
350 |
+
model_log_variance = model_var_values
|
351 |
+
model_variance = th.exp(model_log_variance)
|
352 |
+
else:
|
353 |
+
min_log = _extract_into_tensor(
|
354 |
+
self.posterior_log_variance_clipped, t, x.shape
|
355 |
+
)
|
356 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
357 |
+
# The model_var_values is [-1, 1] for [min_var, max_var].
|
358 |
+
frac = (model_var_values + 1) / 2
|
359 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
360 |
+
model_variance = th.exp(model_log_variance)
|
361 |
+
else:
|
362 |
+
model_variance, model_log_variance = {
|
363 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
364 |
+
# to get a better decoder log likelihood.
|
365 |
+
ModelVarType.FIXED_LARGE: (
|
366 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
367 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
368 |
+
),
|
369 |
+
ModelVarType.FIXED_SMALL: (
|
370 |
+
self.posterior_variance,
|
371 |
+
self.posterior_log_variance_clipped,
|
372 |
+
),
|
373 |
+
}[self.model_var_type]
|
374 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
375 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
376 |
+
|
377 |
+
if self.conditioning_free:
|
378 |
+
if self.ramp_conditioning_free:
|
379 |
+
assert t.shape[0] == 1 # This should only be used in inference.
|
380 |
+
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
|
381 |
+
else:
|
382 |
+
cfk = self.conditioning_free_k
|
383 |
+
model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
|
384 |
+
|
385 |
+
def process_xstart(x):
|
386 |
+
if denoised_fn is not None:
|
387 |
+
x = denoised_fn(x)
|
388 |
+
if clip_denoised:
|
389 |
+
return x.clamp(-1, 1)
|
390 |
+
return x
|
391 |
+
|
392 |
+
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
393 |
+
pred_xstart = process_xstart(
|
394 |
+
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
|
395 |
+
)
|
396 |
+
model_mean = model_output
|
397 |
+
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
|
398 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
399 |
+
pred_xstart = process_xstart(model_output)
|
400 |
+
else:
|
401 |
+
pred_xstart = process_xstart(
|
402 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
403 |
+
)
|
404 |
+
model_mean, _, _ = self.q_posterior_mean_variance(
|
405 |
+
x_start=pred_xstart, x_t=x, t=t
|
406 |
+
)
|
407 |
+
else:
|
408 |
+
raise NotImplementedError(self.model_mean_type)
|
409 |
+
|
410 |
+
assert (
|
411 |
+
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
412 |
+
)
|
413 |
+
return {
|
414 |
+
"mean": model_mean,
|
415 |
+
"variance": model_variance,
|
416 |
+
"log_variance": model_log_variance,
|
417 |
+
"pred_xstart": pred_xstart,
|
418 |
+
}
|
419 |
+
|
420 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
421 |
+
assert x_t.shape == eps.shape
|
422 |
+
return (
|
423 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
424 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
425 |
+
)
|
426 |
+
|
427 |
+
def _predict_xstart_from_xprev(self, x_t, t, xprev):
|
428 |
+
assert x_t.shape == xprev.shape
|
429 |
+
return ( # (xprev - coef2*x_t) / coef1
|
430 |
+
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
|
431 |
+
- _extract_into_tensor(
|
432 |
+
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
|
433 |
+
)
|
434 |
+
* x_t
|
435 |
+
)
|
436 |
+
|
437 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
438 |
+
return (
|
439 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
440 |
+
- pred_xstart
|
441 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
442 |
+
|
443 |
+
def _scale_timesteps(self, t):
|
444 |
+
if self.rescale_timesteps:
|
445 |
+
return t.float() * (1000.0 / self.num_timesteps)
|
446 |
+
return t
|
447 |
+
|
448 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
449 |
+
"""
|
450 |
+
Compute the mean for the previous step, given a function cond_fn that
|
451 |
+
computes the gradient of a conditional log probability with respect to
|
452 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
453 |
+
condition on y.
|
454 |
+
|
455 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
456 |
+
"""
|
457 |
+
gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
|
458 |
+
new_mean = (
|
459 |
+
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
460 |
+
)
|
461 |
+
return new_mean
|
462 |
+
|
463 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
464 |
+
"""
|
465 |
+
Compute what the p_mean_variance output would have been, should the
|
466 |
+
model's score function be conditioned by cond_fn.
|
467 |
+
|
468 |
+
See condition_mean() for details on cond_fn.
|
469 |
+
|
470 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
471 |
+
from Song et al (2020).
|
472 |
+
"""
|
473 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
474 |
+
|
475 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
476 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
|
477 |
+
x, self._scale_timesteps(t), **model_kwargs
|
478 |
+
)
|
479 |
+
|
480 |
+
out = p_mean_var.copy()
|
481 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
482 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(
|
483 |
+
x_start=out["pred_xstart"], x_t=x, t=t
|
484 |
+
)
|
485 |
+
return out
|
486 |
+
|
487 |
+
def p_sample(
|
488 |
+
self,
|
489 |
+
model,
|
490 |
+
x,
|
491 |
+
t,
|
492 |
+
clip_denoised=True,
|
493 |
+
denoised_fn=None,
|
494 |
+
cond_fn=None,
|
495 |
+
model_kwargs=None,
|
496 |
+
):
|
497 |
+
"""
|
498 |
+
Sample x_{t-1} from the model at the given timestep.
|
499 |
+
|
500 |
+
:param model: the model to sample from.
|
501 |
+
:param x: the current tensor at x_{t-1}.
|
502 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
503 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
504 |
+
:param denoised_fn: if not None, a function which applies to the
|
505 |
+
x_start prediction before it is used to sample.
|
506 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
507 |
+
similarly to the model.
|
508 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
509 |
+
pass to the model. This can be used for conditioning.
|
510 |
+
:return: a dict containing the following keys:
|
511 |
+
- 'sample': a random sample from the model.
|
512 |
+
- 'pred_xstart': a prediction of x_0.
|
513 |
+
"""
|
514 |
+
out = self.p_mean_variance(
|
515 |
+
model,
|
516 |
+
x,
|
517 |
+
t,
|
518 |
+
clip_denoised=clip_denoised,
|
519 |
+
denoised_fn=denoised_fn,
|
520 |
+
model_kwargs=model_kwargs,
|
521 |
+
)
|
522 |
+
noise = th.randn_like(x)
|
523 |
+
nonzero_mask = (
|
524 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
525 |
+
) # no noise when t == 0
|
526 |
+
if cond_fn is not None:
|
527 |
+
out["mean"] = self.condition_mean(
|
528 |
+
cond_fn, out, x, t, model_kwargs=model_kwargs
|
529 |
+
)
|
530 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
531 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
532 |
+
|
533 |
+
def p_sample_loop(
|
534 |
+
self,
|
535 |
+
model,
|
536 |
+
shape,
|
537 |
+
noise=None,
|
538 |
+
clip_denoised=True,
|
539 |
+
denoised_fn=None,
|
540 |
+
cond_fn=None,
|
541 |
+
model_kwargs=None,
|
542 |
+
device=None,
|
543 |
+
progress=False,
|
544 |
+
):
|
545 |
+
"""
|
546 |
+
Generate samples from the model.
|
547 |
+
|
548 |
+
:param model: the model module.
|
549 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
550 |
+
:param noise: if specified, the noise from the encoder to sample.
|
551 |
+
Should be of the same shape as `shape`.
|
552 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
553 |
+
:param denoised_fn: if not None, a function which applies to the
|
554 |
+
x_start prediction before it is used to sample.
|
555 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
556 |
+
similarly to the model.
|
557 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
558 |
+
pass to the model. This can be used for conditioning.
|
559 |
+
:param device: if specified, the device to create the samples on.
|
560 |
+
If not specified, use a model parameter's device.
|
561 |
+
:param progress: if True, show a tqdm progress bar.
|
562 |
+
:return: a non-differentiable batch of samples.
|
563 |
+
"""
|
564 |
+
final = None
|
565 |
+
for sample in self.p_sample_loop_progressive(
|
566 |
+
model,
|
567 |
+
shape,
|
568 |
+
noise=noise,
|
569 |
+
clip_denoised=clip_denoised,
|
570 |
+
denoised_fn=denoised_fn,
|
571 |
+
cond_fn=cond_fn,
|
572 |
+
model_kwargs=model_kwargs,
|
573 |
+
device=device,
|
574 |
+
progress=progress,
|
575 |
+
):
|
576 |
+
final = sample
|
577 |
+
return final["sample"]
|
578 |
+
|
579 |
+
def p_sample_loop_progressive(
|
580 |
+
self,
|
581 |
+
model,
|
582 |
+
shape,
|
583 |
+
noise=None,
|
584 |
+
clip_denoised=True,
|
585 |
+
denoised_fn=None,
|
586 |
+
cond_fn=None,
|
587 |
+
model_kwargs=None,
|
588 |
+
device=None,
|
589 |
+
progress=False,
|
590 |
+
):
|
591 |
+
"""
|
592 |
+
Generate samples from the model and yield intermediate samples from
|
593 |
+
each timestep of diffusion.
|
594 |
+
|
595 |
+
Arguments are the same as p_sample_loop().
|
596 |
+
Returns a generator over dicts, where each dict is the return value of
|
597 |
+
p_sample().
|
598 |
+
"""
|
599 |
+
if device is None:
|
600 |
+
device = next(model.parameters()).device
|
601 |
+
assert isinstance(shape, (tuple, list))
|
602 |
+
if noise is not None:
|
603 |
+
img = noise
|
604 |
+
else:
|
605 |
+
img = th.randn(*shape, device=device)
|
606 |
+
indices = list(range(self.num_timesteps))[::-1]
|
607 |
+
|
608 |
+
for i in tqdm(indices):
|
609 |
+
t = th.tensor([i] * shape[0], device=device)
|
610 |
+
with th.no_grad():
|
611 |
+
out = self.p_sample(
|
612 |
+
model,
|
613 |
+
img,
|
614 |
+
t,
|
615 |
+
clip_denoised=clip_denoised,
|
616 |
+
denoised_fn=denoised_fn,
|
617 |
+
cond_fn=cond_fn,
|
618 |
+
model_kwargs=model_kwargs,
|
619 |
+
)
|
620 |
+
yield out
|
621 |
+
img = out["sample"]
|
622 |
+
|
623 |
+
def ddim_sample(
|
624 |
+
self,
|
625 |
+
model,
|
626 |
+
x,
|
627 |
+
t,
|
628 |
+
clip_denoised=True,
|
629 |
+
denoised_fn=None,
|
630 |
+
cond_fn=None,
|
631 |
+
model_kwargs=None,
|
632 |
+
eta=0.0,
|
633 |
+
):
|
634 |
+
"""
|
635 |
+
Sample x_{t-1} from the model using DDIM.
|
636 |
+
|
637 |
+
Same usage as p_sample().
|
638 |
+
"""
|
639 |
+
out = self.p_mean_variance(
|
640 |
+
model,
|
641 |
+
x,
|
642 |
+
t,
|
643 |
+
clip_denoised=clip_denoised,
|
644 |
+
denoised_fn=denoised_fn,
|
645 |
+
model_kwargs=model_kwargs,
|
646 |
+
)
|
647 |
+
if cond_fn is not None:
|
648 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
649 |
+
|
650 |
+
# Usually our model outputs epsilon, but we re-derive it
|
651 |
+
# in case we used x_start or x_prev prediction.
|
652 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
653 |
+
|
654 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
655 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
656 |
+
sigma = (
|
657 |
+
eta
|
658 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
659 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
660 |
+
)
|
661 |
+
# Equation 12.
|
662 |
+
noise = th.randn_like(x)
|
663 |
+
mean_pred = (
|
664 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
665 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
666 |
+
)
|
667 |
+
nonzero_mask = (
|
668 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
669 |
+
) # no noise when t == 0
|
670 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
671 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
672 |
+
|
673 |
+
def ddim_reverse_sample(
|
674 |
+
self,
|
675 |
+
model,
|
676 |
+
x,
|
677 |
+
t,
|
678 |
+
clip_denoised=True,
|
679 |
+
denoised_fn=None,
|
680 |
+
model_kwargs=None,
|
681 |
+
eta=0.0,
|
682 |
+
):
|
683 |
+
"""
|
684 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
685 |
+
"""
|
686 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
687 |
+
out = self.p_mean_variance(
|
688 |
+
model,
|
689 |
+
x,
|
690 |
+
t,
|
691 |
+
clip_denoised=clip_denoised,
|
692 |
+
denoised_fn=denoised_fn,
|
693 |
+
model_kwargs=model_kwargs,
|
694 |
+
)
|
695 |
+
# Usually our model outputs epsilon, but we re-derive it
|
696 |
+
# in case we used x_start or x_prev prediction.
|
697 |
+
eps = (
|
698 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
699 |
+
- out["pred_xstart"]
|
700 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
701 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
702 |
+
|
703 |
+
# Equation 12. reversed
|
704 |
+
mean_pred = (
|
705 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_next)
|
706 |
+
+ th.sqrt(1 - alpha_bar_next) * eps
|
707 |
+
)
|
708 |
+
|
709 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
710 |
+
|
711 |
+
def ddim_sample_loop(
|
712 |
+
self,
|
713 |
+
model,
|
714 |
+
shape,
|
715 |
+
noise=None,
|
716 |
+
clip_denoised=True,
|
717 |
+
denoised_fn=None,
|
718 |
+
cond_fn=None,
|
719 |
+
model_kwargs=None,
|
720 |
+
device=None,
|
721 |
+
progress=False,
|
722 |
+
eta=0.0,
|
723 |
+
):
|
724 |
+
"""
|
725 |
+
Generate samples from the model using DDIM.
|
726 |
+
|
727 |
+
Same usage as p_sample_loop().
|
728 |
+
"""
|
729 |
+
final = None
|
730 |
+
for sample in self.ddim_sample_loop_progressive(
|
731 |
+
model,
|
732 |
+
shape,
|
733 |
+
noise=noise,
|
734 |
+
clip_denoised=clip_denoised,
|
735 |
+
denoised_fn=denoised_fn,
|
736 |
+
cond_fn=cond_fn,
|
737 |
+
model_kwargs=model_kwargs,
|
738 |
+
device=device,
|
739 |
+
progress=progress,
|
740 |
+
eta=eta,
|
741 |
+
):
|
742 |
+
final = sample
|
743 |
+
return final["sample"]
|
744 |
+
|
745 |
+
def ddim_sample_loop_progressive(
|
746 |
+
self,
|
747 |
+
model,
|
748 |
+
shape,
|
749 |
+
noise=None,
|
750 |
+
clip_denoised=True,
|
751 |
+
denoised_fn=None,
|
752 |
+
cond_fn=None,
|
753 |
+
model_kwargs=None,
|
754 |
+
device=None,
|
755 |
+
progress=False,
|
756 |
+
eta=0.0,
|
757 |
+
):
|
758 |
+
"""
|
759 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
760 |
+
each timestep of DDIM.
|
761 |
+
|
762 |
+
Same usage as p_sample_loop_progressive().
|
763 |
+
"""
|
764 |
+
if device is None:
|
765 |
+
device = next(model.parameters()).device
|
766 |
+
assert isinstance(shape, (tuple, list))
|
767 |
+
if noise is not None:
|
768 |
+
img = noise
|
769 |
+
else:
|
770 |
+
img = th.randn(*shape, device=device)
|
771 |
+
indices = list(range(self.num_timesteps))[::-1]
|
772 |
+
|
773 |
+
if progress:
|
774 |
+
# Lazy import so that we don't depend on tqdm.
|
775 |
+
from tqdm.auto import tqdm
|
776 |
+
|
777 |
+
indices = tqdm(indices)
|
778 |
+
|
779 |
+
for i in indices:
|
780 |
+
t = th.tensor([i] * shape[0], device=device)
|
781 |
+
with th.no_grad():
|
782 |
+
out = self.ddim_sample(
|
783 |
+
model,
|
784 |
+
img,
|
785 |
+
t,
|
786 |
+
clip_denoised=clip_denoised,
|
787 |
+
denoised_fn=denoised_fn,
|
788 |
+
cond_fn=cond_fn,
|
789 |
+
model_kwargs=model_kwargs,
|
790 |
+
eta=eta,
|
791 |
+
)
|
792 |
+
yield out
|
793 |
+
img = out["sample"]
|
794 |
+
|
795 |
+
def _vb_terms_bpd(
|
796 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
797 |
+
):
|
798 |
+
"""
|
799 |
+
Get a term for the variational lower-bound.
|
800 |
+
|
801 |
+
The resulting units are bits (rather than nats, as one might expect).
|
802 |
+
This allows for comparison to other papers.
|
803 |
+
|
804 |
+
:return: a dict with the following keys:
|
805 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
806 |
+
- 'pred_xstart': the x_0 predictions.
|
807 |
+
"""
|
808 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
809 |
+
x_start=x_start, x_t=x_t, t=t
|
810 |
+
)
|
811 |
+
out = self.p_mean_variance(
|
812 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
813 |
+
)
|
814 |
+
kl = normal_kl(
|
815 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
816 |
+
)
|
817 |
+
kl = mean_flat(kl) / np.log(2.0)
|
818 |
+
|
819 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
820 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
821 |
+
)
|
822 |
+
assert decoder_nll.shape == x_start.shape
|
823 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
824 |
+
|
825 |
+
# At the first timestep return the decoder NLL,
|
826 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
827 |
+
output = th.where((t == 0), decoder_nll, kl)
|
828 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
829 |
+
|
830 |
+
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
831 |
+
"""
|
832 |
+
Compute training losses for a single timestep.
|
833 |
+
|
834 |
+
:param model: the model to evaluate loss on.
|
835 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
836 |
+
:param t: a batch of timestep indices.
|
837 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
838 |
+
pass to the model. This can be used for conditioning.
|
839 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
840 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
841 |
+
Some mean or variance settings may also have other keys.
|
842 |
+
"""
|
843 |
+
if model_kwargs is None:
|
844 |
+
model_kwargs = {}
|
845 |
+
if noise is None:
|
846 |
+
noise = th.randn_like(x_start)
|
847 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
848 |
+
|
849 |
+
terms = {}
|
850 |
+
|
851 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
852 |
+
# TODO: support multiple model outputs for this mode.
|
853 |
+
terms["loss"] = self._vb_terms_bpd(
|
854 |
+
model=model,
|
855 |
+
x_start=x_start,
|
856 |
+
x_t=x_t,
|
857 |
+
t=t,
|
858 |
+
clip_denoised=False,
|
859 |
+
model_kwargs=model_kwargs,
|
860 |
+
)["output"]
|
861 |
+
if self.loss_type == LossType.RESCALED_KL:
|
862 |
+
terms["loss"] *= self.num_timesteps
|
863 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
864 |
+
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
|
865 |
+
if isinstance(model_outputs, tuple):
|
866 |
+
model_output = model_outputs[0]
|
867 |
+
terms['extra_outputs'] = model_outputs[1:]
|
868 |
+
else:
|
869 |
+
model_output = model_outputs
|
870 |
+
|
871 |
+
if self.model_var_type in [
|
872 |
+
ModelVarType.LEARNED,
|
873 |
+
ModelVarType.LEARNED_RANGE,
|
874 |
+
]:
|
875 |
+
B, C = x_t.shape[:2]
|
876 |
+
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
877 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
878 |
+
# Learn the variance using the variational bound, but don't let
|
879 |
+
# it affect our mean prediction.
|
880 |
+
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
881 |
+
terms["vb"] = self._vb_terms_bpd(
|
882 |
+
model=lambda *args, r=frozen_out: r,
|
883 |
+
x_start=x_start,
|
884 |
+
x_t=x_t,
|
885 |
+
t=t,
|
886 |
+
clip_denoised=False,
|
887 |
+
)["output"]
|
888 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
889 |
+
# Divide by 1000 for equivalence with initial implementation.
|
890 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
891 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
892 |
+
|
893 |
+
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
894 |
+
target = self.q_posterior_mean_variance(
|
895 |
+
x_start=x_start, x_t=x_t, t=t
|
896 |
+
)[0]
|
897 |
+
x_start_pred = torch.zeros(x_start) # Not supported.
|
898 |
+
elif self.model_mean_type == ModelMeanType.START_X:
|
899 |
+
target = x_start
|
900 |
+
x_start_pred = model_output
|
901 |
+
elif self.model_mean_type == ModelMeanType.EPSILON:
|
902 |
+
target = noise
|
903 |
+
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
|
904 |
+
else:
|
905 |
+
raise NotImplementedError(self.model_mean_type)
|
906 |
+
assert model_output.shape == target.shape == x_start.shape
|
907 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
908 |
+
terms["x_start_predicted"] = x_start_pred
|
909 |
+
if "vb" in terms:
|
910 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
911 |
+
else:
|
912 |
+
terms["loss"] = terms["mse"]
|
913 |
+
else:
|
914 |
+
raise NotImplementedError(self.loss_type)
|
915 |
+
|
916 |
+
return terms
|
917 |
+
|
918 |
+
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
|
919 |
+
"""
|
920 |
+
Compute training losses for a single timestep.
|
921 |
+
|
922 |
+
:param model: the model to evaluate loss on.
|
923 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
924 |
+
:param t: a batch of timestep indices.
|
925 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
926 |
+
pass to the model. This can be used for conditioning.
|
927 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
928 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
929 |
+
Some mean or variance settings may also have other keys.
|
930 |
+
"""
|
931 |
+
if model_kwargs is None:
|
932 |
+
model_kwargs = {}
|
933 |
+
if noise is None:
|
934 |
+
noise = th.randn_like(x_start)
|
935 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
936 |
+
terms = {}
|
937 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
938 |
+
assert False # not currently supported for this type of diffusion.
|
939 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
940 |
+
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
|
941 |
+
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
|
942 |
+
model_output = terms[gd_out_key]
|
943 |
+
if self.model_var_type in [
|
944 |
+
ModelVarType.LEARNED,
|
945 |
+
ModelVarType.LEARNED_RANGE,
|
946 |
+
]:
|
947 |
+
B, C = x_t.shape[:2]
|
948 |
+
assert model_output.shape == (B, C, 2, *x_t.shape[2:])
|
949 |
+
model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1]
|
950 |
+
# Learn the variance using the variational bound, but don't let
|
951 |
+
# it affect our mean prediction.
|
952 |
+
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
953 |
+
terms["vb"] = self._vb_terms_bpd(
|
954 |
+
model=lambda *args, r=frozen_out: r,
|
955 |
+
x_start=x_start,
|
956 |
+
x_t=x_t,
|
957 |
+
t=t,
|
958 |
+
clip_denoised=False,
|
959 |
+
)["output"]
|
960 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
961 |
+
# Divide by 1000 for equivalence with initial implementation.
|
962 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
963 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
964 |
+
|
965 |
+
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
966 |
+
target = self.q_posterior_mean_variance(
|
967 |
+
x_start=x_start, x_t=x_t, t=t
|
968 |
+
)[0]
|
969 |
+
x_start_pred = torch.zeros(x_start) # Not supported.
|
970 |
+
elif self.model_mean_type == ModelMeanType.START_X:
|
971 |
+
target = x_start
|
972 |
+
x_start_pred = model_output
|
973 |
+
elif self.model_mean_type == ModelMeanType.EPSILON:
|
974 |
+
target = noise
|
975 |
+
x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
|
976 |
+
else:
|
977 |
+
raise NotImplementedError(self.model_mean_type)
|
978 |
+
assert model_output.shape == target.shape == x_start.shape
|
979 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
980 |
+
terms["x_start_predicted"] = x_start_pred
|
981 |
+
if "vb" in terms:
|
982 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
983 |
+
else:
|
984 |
+
terms["loss"] = terms["mse"]
|
985 |
+
else:
|
986 |
+
raise NotImplementedError(self.loss_type)
|
987 |
+
|
988 |
+
return terms
|
989 |
+
|
990 |
+
def _prior_bpd(self, x_start):
|
991 |
+
"""
|
992 |
+
Get the prior KL term for the variational lower-bound, measured in
|
993 |
+
bits-per-dim.
|
994 |
+
|
995 |
+
This term can't be optimized, as it only depends on the encoder.
|
996 |
+
|
997 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
998 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
999 |
+
"""
|
1000 |
+
batch_size = x_start.shape[0]
|
1001 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
1002 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
1003 |
+
kl_prior = normal_kl(
|
1004 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
1005 |
+
)
|
1006 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
1007 |
+
|
1008 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
1009 |
+
"""
|
1010 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
1011 |
+
as well as other related quantities.
|
1012 |
+
|
1013 |
+
:param model: the model to evaluate loss on.
|
1014 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
1015 |
+
:param clip_denoised: if True, clip denoised samples.
|
1016 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
1017 |
+
pass to the model. This can be used for conditioning.
|
1018 |
+
|
1019 |
+
:return: a dict containing the following keys:
|
1020 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
1021 |
+
- prior_bpd: the prior term in the lower-bound.
|
1022 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
1023 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
1024 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
1025 |
+
"""
|
1026 |
+
device = x_start.device
|
1027 |
+
batch_size = x_start.shape[0]
|
1028 |
+
|
1029 |
+
vb = []
|
1030 |
+
xstart_mse = []
|
1031 |
+
mse = []
|
1032 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
1033 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
1034 |
+
noise = th.randn_like(x_start)
|
1035 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
1036 |
+
# Calculate VLB term at the current timestep
|
1037 |
+
with th.no_grad():
|
1038 |
+
out = self._vb_terms_bpd(
|
1039 |
+
model,
|
1040 |
+
x_start=x_start,
|
1041 |
+
x_t=x_t,
|
1042 |
+
t=t_batch,
|
1043 |
+
clip_denoised=clip_denoised,
|
1044 |
+
model_kwargs=model_kwargs,
|
1045 |
+
)
|
1046 |
+
vb.append(out["output"])
|
1047 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
1048 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
1049 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
1050 |
+
|
1051 |
+
vb = th.stack(vb, dim=1)
|
1052 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
1053 |
+
mse = th.stack(mse, dim=1)
|
1054 |
+
|
1055 |
+
prior_bpd = self._prior_bpd(x_start)
|
1056 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
1057 |
+
return {
|
1058 |
+
"total_bpd": total_bpd,
|
1059 |
+
"prior_bpd": prior_bpd,
|
1060 |
+
"vb": vb,
|
1061 |
+
"xstart_mse": xstart_mse,
|
1062 |
+
"mse": mse,
|
1063 |
+
}
|
1064 |
+
|
1065 |
+
|
1066 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
1067 |
+
"""
|
1068 |
+
Get a pre-defined beta schedule for the given name.
|
1069 |
+
|
1070 |
+
The beta schedule library consists of beta schedules which remain similar
|
1071 |
+
in the limit of num_diffusion_timesteps.
|
1072 |
+
Beta schedules may be added, but should not be removed or changed once
|
1073 |
+
they are committed to maintain backwards compatibility.
|
1074 |
+
"""
|
1075 |
+
if schedule_name == "linear":
|
1076 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
1077 |
+
# diffusion steps.
|
1078 |
+
scale = 1000 / num_diffusion_timesteps
|
1079 |
+
beta_start = scale * 0.0001
|
1080 |
+
beta_end = scale * 0.02
|
1081 |
+
return np.linspace(
|
1082 |
+
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
|
1083 |
+
)
|
1084 |
+
elif schedule_name == "cosine":
|
1085 |
+
return betas_for_alpha_bar(
|
1086 |
+
num_diffusion_timesteps,
|
1087 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
1088 |
+
)
|
1089 |
+
else:
|
1090 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
1091 |
+
|
1092 |
+
|
1093 |
+
class SpacedDiffusion(GaussianDiffusion):
|
1094 |
+
"""
|
1095 |
+
A diffusion process which can skip steps in a base diffusion process.
|
1096 |
+
|
1097 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
1098 |
+
original diffusion process to retain.
|
1099 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
1100 |
+
"""
|
1101 |
+
|
1102 |
+
def __init__(self, use_timesteps, **kwargs):
|
1103 |
+
self.use_timesteps = set(use_timesteps)
|
1104 |
+
self.timestep_map = []
|
1105 |
+
self.original_num_steps = len(kwargs["betas"])
|
1106 |
+
|
1107 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
1108 |
+
last_alpha_cumprod = 1.0
|
1109 |
+
new_betas = []
|
1110 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
1111 |
+
if i in self.use_timesteps:
|
1112 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
1113 |
+
last_alpha_cumprod = alpha_cumprod
|
1114 |
+
self.timestep_map.append(i)
|
1115 |
+
kwargs["betas"] = np.array(new_betas)
|
1116 |
+
super().__init__(**kwargs)
|
1117 |
+
|
1118 |
+
def p_mean_variance(
|
1119 |
+
self, model, *args, **kwargs
|
1120 |
+
): # pylint: disable=signature-differs
|
1121 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
1122 |
+
|
1123 |
+
def training_losses(
|
1124 |
+
self, model, *args, **kwargs
|
1125 |
+
): # pylint: disable=signature-differs
|
1126 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
1127 |
+
|
1128 |
+
def autoregressive_training_losses(
|
1129 |
+
self, model, *args, **kwargs
|
1130 |
+
): # pylint: disable=signature-differs
|
1131 |
+
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
|
1132 |
+
|
1133 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
1134 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
1135 |
+
|
1136 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
1137 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
1138 |
+
|
1139 |
+
def _wrap_model(self, model, autoregressive=False):
|
1140 |
+
if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
|
1141 |
+
return model
|
1142 |
+
mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
|
1143 |
+
return mod(
|
1144 |
+
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
def _scale_timesteps(self, t):
|
1148 |
+
# Scaling is done by the wrapped model.
|
1149 |
+
return t
|
1150 |
+
|
1151 |
+
|
1152 |
+
def space_timesteps(num_timesteps, section_counts):
|
1153 |
+
"""
|
1154 |
+
Create a list of timesteps to use from an original diffusion process,
|
1155 |
+
given the number of timesteps we want to take from equally-sized portions
|
1156 |
+
of the original process.
|
1157 |
+
|
1158 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
1159 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
1160 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
1161 |
+
|
1162 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
1163 |
+
from the DDIM paper is used, and only one section is allowed.
|
1164 |
+
|
1165 |
+
:param num_timesteps: the number of diffusion steps in the original
|
1166 |
+
process to divide up.
|
1167 |
+
:param section_counts: either a list of numbers, or a string containing
|
1168 |
+
comma-separated numbers, indicating the step count
|
1169 |
+
per section. As a special case, use "ddimN" where N
|
1170 |
+
is a number of steps to use the striding from the
|
1171 |
+
DDIM paper.
|
1172 |
+
:return: a set of diffusion steps from the original process to use.
|
1173 |
+
"""
|
1174 |
+
if isinstance(section_counts, str):
|
1175 |
+
if section_counts.startswith("ddim"):
|
1176 |
+
desired_count = int(section_counts[len("ddim") :])
|
1177 |
+
for i in range(1, num_timesteps):
|
1178 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
1179 |
+
return set(range(0, num_timesteps, i))
|
1180 |
+
raise ValueError(
|
1181 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
1182 |
+
)
|
1183 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
1184 |
+
size_per = num_timesteps // len(section_counts)
|
1185 |
+
extra = num_timesteps % len(section_counts)
|
1186 |
+
start_idx = 0
|
1187 |
+
all_steps = []
|
1188 |
+
for i, section_count in enumerate(section_counts):
|
1189 |
+
size = size_per + (1 if i < extra else 0)
|
1190 |
+
if size < section_count:
|
1191 |
+
raise ValueError(
|
1192 |
+
f"cannot divide section of {size} steps into {section_count}"
|
1193 |
+
)
|
1194 |
+
if section_count <= 1:
|
1195 |
+
frac_stride = 1
|
1196 |
+
else:
|
1197 |
+
frac_stride = (size - 1) / (section_count - 1)
|
1198 |
+
cur_idx = 0.0
|
1199 |
+
taken_steps = []
|
1200 |
+
for _ in range(section_count):
|
1201 |
+
taken_steps.append(start_idx + round(cur_idx))
|
1202 |
+
cur_idx += frac_stride
|
1203 |
+
all_steps += taken_steps
|
1204 |
+
start_idx += size
|
1205 |
+
return set(all_steps)
|
1206 |
+
|
1207 |
+
|
1208 |
+
class _WrappedModel:
|
1209 |
+
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
1210 |
+
self.model = model
|
1211 |
+
self.timestep_map = timestep_map
|
1212 |
+
self.rescale_timesteps = rescale_timesteps
|
1213 |
+
self.original_num_steps = original_num_steps
|
1214 |
+
|
1215 |
+
def __call__(self, x, ts, **kwargs):
|
1216 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
1217 |
+
new_ts = map_tensor[ts]
|
1218 |
+
if self.rescale_timesteps:
|
1219 |
+
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
1220 |
+
return self.model(x, new_ts, **kwargs)
|
1221 |
+
|
1222 |
+
|
1223 |
+
class _WrappedAutoregressiveModel:
|
1224 |
+
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
1225 |
+
self.model = model
|
1226 |
+
self.timestep_map = timestep_map
|
1227 |
+
self.rescale_timesteps = rescale_timesteps
|
1228 |
+
self.original_num_steps = original_num_steps
|
1229 |
+
|
1230 |
+
def __call__(self, x, x0, ts, **kwargs):
|
1231 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
1232 |
+
new_ts = map_tensor[ts]
|
1233 |
+
if self.rescale_timesteps:
|
1234 |
+
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
1235 |
+
return self.model(x, x0, new_ts, **kwargs)
|
1236 |
+
|
1237 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
1238 |
+
"""
|
1239 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
1240 |
+
|
1241 |
+
:param arr: the 1-D numpy array.
|
1242 |
+
:param timesteps: a tensor of indices into the array to extract.
|
1243 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
1244 |
+
dimension equal to the length of timesteps.
|
1245 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
1246 |
+
"""
|
1247 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
1248 |
+
while len(res.shape) < len(broadcast_shape):
|
1249 |
+
res = res[..., None]
|
1250 |
+
return res.expand(broadcast_shape)
|
utils/stft.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BSD 3-Clause License
|
3 |
+
|
4 |
+
Copyright (c) 2017, Prem Seetharaman
|
5 |
+
All rights reserved.
|
6 |
+
|
7 |
+
* Redistribution and use in source and binary forms, with or without
|
8 |
+
modification, are permitted provided that the following conditions are met:
|
9 |
+
|
10 |
+
* Redistributions of source code must retain the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer.
|
12 |
+
|
13 |
+
* Redistributions in binary form must reproduce the above copyright notice, this
|
14 |
+
list of conditions and the following disclaimer in the
|
15 |
+
documentation and/or other materials provided with the distribution.
|
16 |
+
|
17 |
+
* Neither the name of the copyright holder nor the names of its
|
18 |
+
contributors may be used to endorse or promote products derived from this
|
19 |
+
software without specific prior written permission.
|
20 |
+
|
21 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
22 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
23 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
25 |
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
26 |
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
27 |
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
28 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
29 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
30 |
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31 |
+
"""
|
32 |
+
|
33 |
+
import torch
|
34 |
+
import numpy as np
|
35 |
+
import torch.nn.functional as F
|
36 |
+
from torch.autograd import Variable
|
37 |
+
from scipy.signal import get_window
|
38 |
+
from librosa.util import pad_center, tiny
|
39 |
+
import librosa.util as librosa_util
|
40 |
+
|
41 |
+
|
42 |
+
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
|
43 |
+
n_fft=800, dtype=np.float32, norm=None):
|
44 |
+
"""
|
45 |
+
# from librosa 0.6
|
46 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
47 |
+
|
48 |
+
This is used to estimate modulation effects induced by windowing
|
49 |
+
observations in short-time fourier transforms.
|
50 |
+
|
51 |
+
Parameters
|
52 |
+
----------
|
53 |
+
window : string, tuple, number, callable, or list-like
|
54 |
+
Window specification, as in `get_window`
|
55 |
+
|
56 |
+
n_frames : int > 0
|
57 |
+
The number of analysis frames
|
58 |
+
|
59 |
+
hop_length : int > 0
|
60 |
+
The number of samples to advance between frames
|
61 |
+
|
62 |
+
win_length : [optional]
|
63 |
+
The length of the window function. By default, this matches `n_fft`.
|
64 |
+
|
65 |
+
n_fft : int > 0
|
66 |
+
The length of each analysis frame.
|
67 |
+
|
68 |
+
dtype : np.dtype
|
69 |
+
The data type of the output
|
70 |
+
|
71 |
+
Returns
|
72 |
+
-------
|
73 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
74 |
+
The sum-squared envelope of the window function
|
75 |
+
"""
|
76 |
+
if win_length is None:
|
77 |
+
win_length = n_fft
|
78 |
+
|
79 |
+
n = n_fft + hop_length * (n_frames - 1)
|
80 |
+
x = np.zeros(n, dtype=dtype)
|
81 |
+
|
82 |
+
# Compute the squared window at the desired length
|
83 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
84 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
|
85 |
+
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
86 |
+
|
87 |
+
# Fill the envelope
|
88 |
+
for i in range(n_frames):
|
89 |
+
sample = i * hop_length
|
90 |
+
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
91 |
+
return x
|
92 |
+
|
93 |
+
|
94 |
+
class STFT(torch.nn.Module):
|
95 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
96 |
+
def __init__(self, filter_length=800, hop_length=200, win_length=800,
|
97 |
+
window='hann'):
|
98 |
+
super(STFT, self).__init__()
|
99 |
+
self.filter_length = filter_length
|
100 |
+
self.hop_length = hop_length
|
101 |
+
self.win_length = win_length
|
102 |
+
self.window = window
|
103 |
+
self.forward_transform = None
|
104 |
+
scale = self.filter_length / self.hop_length
|
105 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
106 |
+
|
107 |
+
cutoff = int((self.filter_length / 2 + 1))
|
108 |
+
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
|
109 |
+
np.imag(fourier_basis[:cutoff, :])])
|
110 |
+
|
111 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
112 |
+
inverse_basis = torch.FloatTensor(
|
113 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
|
114 |
+
|
115 |
+
if window is not None:
|
116 |
+
assert(filter_length >= win_length)
|
117 |
+
# get window and zero center pad it to filter_length
|
118 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
119 |
+
fft_window = pad_center(fft_window, filter_length)
|
120 |
+
fft_window = torch.from_numpy(fft_window).float()
|
121 |
+
|
122 |
+
# window the bases
|
123 |
+
forward_basis *= fft_window
|
124 |
+
inverse_basis *= fft_window
|
125 |
+
|
126 |
+
self.register_buffer('forward_basis', forward_basis.float())
|
127 |
+
self.register_buffer('inverse_basis', inverse_basis.float())
|
128 |
+
|
129 |
+
def transform(self, input_data):
|
130 |
+
num_batches = input_data.size(0)
|
131 |
+
num_samples = input_data.size(1)
|
132 |
+
|
133 |
+
self.num_samples = num_samples
|
134 |
+
|
135 |
+
# similar to librosa, reflect-pad the input
|
136 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
137 |
+
input_data = F.pad(
|
138 |
+
input_data.unsqueeze(1),
|
139 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
140 |
+
mode='reflect')
|
141 |
+
input_data = input_data.squeeze(1)
|
142 |
+
|
143 |
+
forward_transform = F.conv1d(
|
144 |
+
input_data,
|
145 |
+
Variable(self.forward_basis, requires_grad=False),
|
146 |
+
stride=self.hop_length,
|
147 |
+
padding=0)
|
148 |
+
|
149 |
+
cutoff = int((self.filter_length / 2) + 1)
|
150 |
+
real_part = forward_transform[:, :cutoff, :]
|
151 |
+
imag_part = forward_transform[:, cutoff:, :]
|
152 |
+
|
153 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
154 |
+
phase = torch.autograd.Variable(
|
155 |
+
torch.atan2(imag_part.data, real_part.data))
|
156 |
+
|
157 |
+
return magnitude, phase
|
158 |
+
|
159 |
+
def inverse(self, magnitude, phase):
|
160 |
+
recombine_magnitude_phase = torch.cat(
|
161 |
+
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
|
162 |
+
|
163 |
+
inverse_transform = F.conv_transpose1d(
|
164 |
+
recombine_magnitude_phase,
|
165 |
+
Variable(self.inverse_basis, requires_grad=False),
|
166 |
+
stride=self.hop_length,
|
167 |
+
padding=0)
|
168 |
+
|
169 |
+
if self.window is not None:
|
170 |
+
window_sum = window_sumsquare(
|
171 |
+
self.window, magnitude.size(-1), hop_length=self.hop_length,
|
172 |
+
win_length=self.win_length, n_fft=self.filter_length,
|
173 |
+
dtype=np.float32)
|
174 |
+
# remove modulation effects
|
175 |
+
approx_nonzero_indices = torch.from_numpy(
|
176 |
+
np.where(window_sum > tiny(window_sum))[0])
|
177 |
+
window_sum = torch.autograd.Variable(
|
178 |
+
torch.from_numpy(window_sum), requires_grad=False)
|
179 |
+
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
|
180 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
181 |
+
|
182 |
+
# scale by hop ratio
|
183 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
184 |
+
|
185 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
|
186 |
+
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
|
187 |
+
|
188 |
+
return inverse_transform
|
189 |
+
|
190 |
+
def forward(self, input_data):
|
191 |
+
self.magnitude, self.phase = self.transform(input_data)
|
192 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
193 |
+
return reconstruction
|
utils/tokenizer.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import inflect
|
4 |
+
import torch
|
5 |
+
from tokenizers import Tokenizer
|
6 |
+
|
7 |
+
|
8 |
+
# Regular expression matching whitespace:
|
9 |
+
from unidecode import unidecode
|
10 |
+
|
11 |
+
_whitespace_re = re.compile(r'\s+')
|
12 |
+
|
13 |
+
|
14 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
15 |
+
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
16 |
+
('mrs', 'misess'),
|
17 |
+
('mr', 'mister'),
|
18 |
+
('dr', 'doctor'),
|
19 |
+
('st', 'saint'),
|
20 |
+
('co', 'company'),
|
21 |
+
('jr', 'junior'),
|
22 |
+
('maj', 'major'),
|
23 |
+
('gen', 'general'),
|
24 |
+
('drs', 'doctors'),
|
25 |
+
('rev', 'reverend'),
|
26 |
+
('lt', 'lieutenant'),
|
27 |
+
('hon', 'honorable'),
|
28 |
+
('sgt', 'sergeant'),
|
29 |
+
('capt', 'captain'),
|
30 |
+
('esq', 'esquire'),
|
31 |
+
('ltd', 'limited'),
|
32 |
+
('col', 'colonel'),
|
33 |
+
('ft', 'fort'),
|
34 |
+
]]
|
35 |
+
|
36 |
+
|
37 |
+
def expand_abbreviations(text):
|
38 |
+
for regex, replacement in _abbreviations:
|
39 |
+
text = re.sub(regex, replacement, text)
|
40 |
+
return text
|
41 |
+
|
42 |
+
|
43 |
+
_inflect = inflect.engine()
|
44 |
+
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
45 |
+
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
46 |
+
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
47 |
+
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
48 |
+
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
49 |
+
_number_re = re.compile(r'[0-9]+')
|
50 |
+
|
51 |
+
|
52 |
+
def _remove_commas(m):
|
53 |
+
return m.group(1).replace(',', '')
|
54 |
+
|
55 |
+
|
56 |
+
def _expand_decimal_point(m):
|
57 |
+
return m.group(1).replace('.', ' point ')
|
58 |
+
|
59 |
+
|
60 |
+
def _expand_dollars(m):
|
61 |
+
match = m.group(1)
|
62 |
+
parts = match.split('.')
|
63 |
+
if len(parts) > 2:
|
64 |
+
return match + ' dollars' # Unexpected format
|
65 |
+
dollars = int(parts[0]) if parts[0] else 0
|
66 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
67 |
+
if dollars and cents:
|
68 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
69 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
70 |
+
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
71 |
+
elif dollars:
|
72 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
73 |
+
return '%s %s' % (dollars, dollar_unit)
|
74 |
+
elif cents:
|
75 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
76 |
+
return '%s %s' % (cents, cent_unit)
|
77 |
+
else:
|
78 |
+
return 'zero dollars'
|
79 |
+
|
80 |
+
|
81 |
+
def _expand_ordinal(m):
|
82 |
+
return _inflect.number_to_words(m.group(0))
|
83 |
+
|
84 |
+
|
85 |
+
def _expand_number(m):
|
86 |
+
num = int(m.group(0))
|
87 |
+
if num > 1000 and num < 3000:
|
88 |
+
if num == 2000:
|
89 |
+
return 'two thousand'
|
90 |
+
elif num > 2000 and num < 2010:
|
91 |
+
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
92 |
+
elif num % 100 == 0:
|
93 |
+
return _inflect.number_to_words(num // 100) + ' hundred'
|
94 |
+
else:
|
95 |
+
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
96 |
+
else:
|
97 |
+
return _inflect.number_to_words(num, andword='')
|
98 |
+
|
99 |
+
|
100 |
+
def normalize_numbers(text):
|
101 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
102 |
+
text = re.sub(_pounds_re, r'\1 pounds', text)
|
103 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
104 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
105 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
106 |
+
text = re.sub(_number_re, _expand_number, text)
|
107 |
+
return text
|
108 |
+
|
109 |
+
|
110 |
+
def expand_numbers(text):
|
111 |
+
return normalize_numbers(text)
|
112 |
+
|
113 |
+
|
114 |
+
def lowercase(text):
|
115 |
+
return text.lower()
|
116 |
+
|
117 |
+
|
118 |
+
def collapse_whitespace(text):
|
119 |
+
return re.sub(_whitespace_re, ' ', text)
|
120 |
+
|
121 |
+
|
122 |
+
def convert_to_ascii(text):
|
123 |
+
return unidecode(text)
|
124 |
+
|
125 |
+
|
126 |
+
def basic_cleaners(text):
|
127 |
+
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
128 |
+
text = lowercase(text)
|
129 |
+
text = collapse_whitespace(text)
|
130 |
+
return text
|
131 |
+
|
132 |
+
|
133 |
+
def transliteration_cleaners(text):
|
134 |
+
'''Pipeline for non-English text that transliterates to ASCII.'''
|
135 |
+
text = convert_to_ascii(text)
|
136 |
+
text = lowercase(text)
|
137 |
+
text = collapse_whitespace(text)
|
138 |
+
return text
|
139 |
+
|
140 |
+
|
141 |
+
def english_cleaners(text):
|
142 |
+
'''Pipeline for English text, including number and abbreviation expansion.'''
|
143 |
+
text = convert_to_ascii(text)
|
144 |
+
text = lowercase(text)
|
145 |
+
text = expand_numbers(text)
|
146 |
+
text = expand_abbreviations(text)
|
147 |
+
text = collapse_whitespace(text)
|
148 |
+
text = text.replace('"', '')
|
149 |
+
return text
|
150 |
+
|
151 |
+
def lev_distance(s1, s2):
|
152 |
+
if len(s1) > len(s2):
|
153 |
+
s1, s2 = s2, s1
|
154 |
+
|
155 |
+
distances = range(len(s1) + 1)
|
156 |
+
for i2, c2 in enumerate(s2):
|
157 |
+
distances_ = [i2 + 1]
|
158 |
+
for i1, c1 in enumerate(s1):
|
159 |
+
if c1 == c2:
|
160 |
+
distances_.append(distances[i1])
|
161 |
+
else:
|
162 |
+
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
|
163 |
+
distances = distances_
|
164 |
+
return distances[-1]
|
165 |
+
|
166 |
+
class VoiceBpeTokenizer:
|
167 |
+
def __init__(self, vocab_file='data/tokenizer.json'):
|
168 |
+
if vocab_file is not None:
|
169 |
+
self.tokenizer = Tokenizer.from_file(vocab_file)
|
170 |
+
|
171 |
+
def preprocess_text(self, txt):
|
172 |
+
txt = english_cleaners(txt)
|
173 |
+
return txt
|
174 |
+
|
175 |
+
def encode(self, txt):
|
176 |
+
txt = self.preprocess_text(txt)
|
177 |
+
txt = txt.replace(' ', '[SPACE]')
|
178 |
+
return self.tokenizer.encode(txt).ids
|
179 |
+
|
180 |
+
def decode(self, seq):
|
181 |
+
if isinstance(seq, torch.Tensor):
|
182 |
+
seq = seq.cpu().numpy()
|
183 |
+
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
|
184 |
+
txt = txt.replace('[SPACE]', ' ')
|
185 |
+
txt = txt.replace('[STOP]', '')
|
186 |
+
txt = txt.replace('[UNK]', '')
|
187 |
+
return txt
|
utils/typical_sampling.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import LogitsWarper
|
3 |
+
|
4 |
+
|
5 |
+
class TypicalLogitsWarper(LogitsWarper):
|
6 |
+
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
7 |
+
self.filter_value = filter_value
|
8 |
+
self.mass = mass
|
9 |
+
self.min_tokens_to_keep = min_tokens_to_keep
|
10 |
+
|
11 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
12 |
+
# calculate entropy
|
13 |
+
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
14 |
+
p = torch.exp(normalized)
|
15 |
+
ent = -(normalized * p).nansum(-1, keepdim=True)
|
16 |
+
|
17 |
+
# shift and sort
|
18 |
+
shifted_scores = torch.abs((-normalized) - ent)
|
19 |
+
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
|
20 |
+
sorted_logits = scores.gather(-1, sorted_indices)
|
21 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
22 |
+
|
23 |
+
# Remove tokens with cumulative mass above the threshold
|
24 |
+
last_ind = (cumulative_probs < self.mass).sum(dim=1)
|
25 |
+
last_ind[last_ind < 0] = 0
|
26 |
+
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
|
27 |
+
if self.min_tokens_to_keep > 1:
|
28 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
29 |
+
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
30 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
31 |
+
|
32 |
+
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
33 |
+
return scores
|