Spaces:
Runtime error
Runtime error
Duplicate from haoheliu/audioldm-text-to-audio-generation
Browse filesCo-authored-by: haoheliu <[email protected]>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +34 -0
- .gitignore +4 -0
- LICENSE +13 -0
- README.md +14 -0
- app.py +236 -0
- audioldm/__init__.py +3 -0
- audioldm/audio/__init__.py +0 -0
- audioldm/audio/audio_processing.py +100 -0
- audioldm/audio/stft.py +180 -0
- audioldm/audio/tools.py +33 -0
- audioldm/clap/__init__.py +0 -0
- audioldm/clap/encoders.py +169 -0
- audioldm/clap/open_clip/__init__.py +25 -0
- audioldm/clap/open_clip/bert.py +40 -0
- audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- audioldm/clap/open_clip/factory.py +277 -0
- audioldm/clap/open_clip/feature_fusion.py +192 -0
- audioldm/clap/open_clip/htsat.py +1308 -0
- audioldm/clap/open_clip/linear_probe.py +66 -0
- audioldm/clap/open_clip/loss.py +398 -0
- audioldm/clap/open_clip/model.py +936 -0
- audioldm/clap/open_clip/model_configs/HTSAT-base.json +23 -0
- audioldm/clap/open_clip/model_configs/HTSAT-large.json +23 -0
- audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +23 -0
- audioldm/clap/open_clip/model_configs/HTSAT-tiny.json +23 -0
- audioldm/clap/open_clip/model_configs/PANN-10.json +23 -0
- audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json +23 -0
- audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +23 -0
- audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +23 -0
- audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json +23 -0
- audioldm/clap/open_clip/model_configs/PANN-14.json +23 -0
- audioldm/clap/open_clip/model_configs/PANN-6.json +23 -0
- audioldm/clap/open_clip/model_configs/RN101-quickgelu.json +22 -0
- audioldm/clap/open_clip/model_configs/RN101.json +21 -0
- audioldm/clap/open_clip/model_configs/RN50-quickgelu.json +22 -0
- audioldm/clap/open_clip/model_configs/RN50.json +21 -0
- audioldm/clap/open_clip/model_configs/RN50x16.json +21 -0
- audioldm/clap/open_clip/model_configs/RN50x4.json +21 -0
- audioldm/clap/open_clip/model_configs/ViT-B-16.json +16 -0
- audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
- audioldm/clap/open_clip/model_configs/ViT-B-32.json +16 -0
- audioldm/clap/open_clip/model_configs/ViT-L-14.json +16 -0
- audioldm/clap/open_clip/openai.py +156 -0
- audioldm/clap/open_clip/pann_model.py +703 -0
- audioldm/clap/open_clip/pretrained.py +167 -0
- audioldm/clap/open_clip/timm_model.py +112 -0
- audioldm/clap/open_clip/tokenizer.py +197 -0
- audioldm/clap/open_clip/transform.py +45 -0
- audioldm/clap/open_clip/utils.py +361 -0
- audioldm/clap/open_clip/version.py +1 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
__pycache__
|
3 |
+
test.py
|
4 |
+
flagged
|
LICENSE
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
“Commons Clause” License Condition v1.0
|
2 |
+
|
3 |
+
The Software is provided to you by the Licensor under the License, as defined below, subject to the following condition.
|
4 |
+
|
5 |
+
Without limiting other conditions in the License, the grant of rights under the License will not include, and the License does not grant to you, the right to Sell the Software.
|
6 |
+
|
7 |
+
For purposes of the foregoing, “Sell” means practicing any or all of the rights granted to you under the License to provide to third parties, for a fee or other consideration (including without limitation fees for hosting or consulting/ support services related to the Software), a product or service whose value derives, entirely or substantially, from the functionality of the Software. Any license notice or attribution required by the License must also include this Commons Clause License Condition notice.
|
8 |
+
|
9 |
+
Software: AudioLDM (including all related model and software)
|
10 |
+
|
11 |
+
License: Apache 2.0
|
12 |
+
|
13 |
+
Licensor: Haohe Liu
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Audioldm Text To Audio Generation
|
3 |
+
emoji: 🔊
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.16.2
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: bigscience-openrail-m
|
11 |
+
duplicated_from: haoheliu/audioldm-text-to-audio-generation
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from audioldm import text_to_audio, build_model
|
4 |
+
from share_btn import community_icon_html, loading_icon_html, share_js
|
5 |
+
|
6 |
+
model_id="haoheliu/AudioLDM-S-Full"
|
7 |
+
|
8 |
+
audioldm = build_model()
|
9 |
+
# audioldm=None
|
10 |
+
|
11 |
+
# def predict(input, history=[]):
|
12 |
+
# # tokenize the new input sentence
|
13 |
+
# new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
|
14 |
+
|
15 |
+
# # append the new user input tokens to the chat history
|
16 |
+
# bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
|
17 |
+
|
18 |
+
# # generate a response
|
19 |
+
# history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
|
20 |
+
|
21 |
+
# # convert the tokens to text, and then split the responses into lines
|
22 |
+
# response = tokenizer.decode(history[0]).split("<|endoftext|>")
|
23 |
+
# response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
|
24 |
+
# return response, history
|
25 |
+
|
26 |
+
def text2audio(text, duration, guidance_scale, random_seed, n_candidates):
|
27 |
+
# print(text, length, guidance_scale)
|
28 |
+
waveform = text_to_audio(audioldm, text, random_seed, duration=duration, guidance_scale=guidance_scale, n_candidate_gen_per_text=int(n_candidates)) # [bs, 1, samples]
|
29 |
+
waveform = [gr.make_waveform((16000, wave[0])) for wave in waveform]
|
30 |
+
# waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
|
31 |
+
if(len(waveform) == 1):
|
32 |
+
waveform = waveform[0]
|
33 |
+
return waveform,gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
34 |
+
|
35 |
+
# iface = gr.Interface(fn=text2audio, inputs=[
|
36 |
+
# gr.Textbox(value="A man is speaking in a huge room", max_lines=1),
|
37 |
+
# gr.Slider(2.5, 10, value=5, step=2.5),
|
38 |
+
# gr.Slider(0, 5, value=2.5, step=0.5),
|
39 |
+
# gr.Number(value=42)
|
40 |
+
# ], outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")],
|
41 |
+
# allow_flagging="never"
|
42 |
+
# )
|
43 |
+
# iface.launch(share=True)
|
44 |
+
|
45 |
+
css = """
|
46 |
+
.gradio-container {
|
47 |
+
font-family: 'IBM Plex Sans', sans-serif;
|
48 |
+
}
|
49 |
+
.gr-button {
|
50 |
+
color: white;
|
51 |
+
border-color: black;
|
52 |
+
background: black;
|
53 |
+
}
|
54 |
+
input[type='range'] {
|
55 |
+
accent-color: black;
|
56 |
+
}
|
57 |
+
.dark input[type='range'] {
|
58 |
+
accent-color: #dfdfdf;
|
59 |
+
}
|
60 |
+
.container {
|
61 |
+
max-width: 730px;
|
62 |
+
margin: auto;
|
63 |
+
padding-top: 1.5rem;
|
64 |
+
}
|
65 |
+
#gallery {
|
66 |
+
min-height: 22rem;
|
67 |
+
margin-bottom: 15px;
|
68 |
+
margin-left: auto;
|
69 |
+
margin-right: auto;
|
70 |
+
border-bottom-right-radius: .5rem !important;
|
71 |
+
border-bottom-left-radius: .5rem !important;
|
72 |
+
}
|
73 |
+
#gallery>div>.h-full {
|
74 |
+
min-height: 20rem;
|
75 |
+
}
|
76 |
+
.details:hover {
|
77 |
+
text-decoration: underline;
|
78 |
+
}
|
79 |
+
.gr-button {
|
80 |
+
white-space: nowrap;
|
81 |
+
}
|
82 |
+
.gr-button:focus {
|
83 |
+
border-color: rgb(147 197 253 / var(--tw-border-opacity));
|
84 |
+
outline: none;
|
85 |
+
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
|
86 |
+
--tw-border-opacity: 1;
|
87 |
+
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
|
88 |
+
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
|
89 |
+
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
|
90 |
+
--tw-ring-opacity: .5;
|
91 |
+
}
|
92 |
+
#advanced-btn {
|
93 |
+
font-size: .7rem !important;
|
94 |
+
line-height: 19px;
|
95 |
+
margin-top: 12px;
|
96 |
+
margin-bottom: 12px;
|
97 |
+
padding: 2px 8px;
|
98 |
+
border-radius: 14px !important;
|
99 |
+
}
|
100 |
+
#advanced-options {
|
101 |
+
display: none;
|
102 |
+
margin-bottom: 20px;
|
103 |
+
}
|
104 |
+
.footer {
|
105 |
+
margin-bottom: 45px;
|
106 |
+
margin-top: 35px;
|
107 |
+
text-align: center;
|
108 |
+
border-bottom: 1px solid #e5e5e5;
|
109 |
+
}
|
110 |
+
.footer>p {
|
111 |
+
font-size: .8rem;
|
112 |
+
display: inline-block;
|
113 |
+
padding: 0 10px;
|
114 |
+
transform: translateY(10px);
|
115 |
+
background: white;
|
116 |
+
}
|
117 |
+
.dark .footer {
|
118 |
+
border-color: #303030;
|
119 |
+
}
|
120 |
+
.dark .footer>p {
|
121 |
+
background: #0b0f19;
|
122 |
+
}
|
123 |
+
.acknowledgments h4{
|
124 |
+
margin: 1.25em 0 .25em 0;
|
125 |
+
font-weight: bold;
|
126 |
+
font-size: 115%;
|
127 |
+
}
|
128 |
+
.animate-spin {
|
129 |
+
animation: spin 1s linear infinite;
|
130 |
+
}
|
131 |
+
@keyframes spin {
|
132 |
+
from {
|
133 |
+
transform: rotate(0deg);
|
134 |
+
}
|
135 |
+
to {
|
136 |
+
transform: rotate(360deg);
|
137 |
+
}
|
138 |
+
}
|
139 |
+
#share-btn-container {
|
140 |
+
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
|
141 |
+
margin-top: 10px;
|
142 |
+
margin-left: auto;
|
143 |
+
}
|
144 |
+
#share-btn {
|
145 |
+
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
|
146 |
+
}
|
147 |
+
#share-btn * {
|
148 |
+
all: unset;
|
149 |
+
}
|
150 |
+
#share-btn-container div:nth-child(-n+2){
|
151 |
+
width: auto !important;
|
152 |
+
min-height: 0px !important;
|
153 |
+
}
|
154 |
+
#share-btn-container .wrap {
|
155 |
+
display: none !important;
|
156 |
+
}
|
157 |
+
|
158 |
+
.gr-form{
|
159 |
+
flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
|
160 |
+
}
|
161 |
+
#prompt-container{
|
162 |
+
gap: 0;
|
163 |
+
}
|
164 |
+
#prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
|
165 |
+
#component-16{border-top-width: 1px!important;margin-top: 1em}
|
166 |
+
.image_duplication{position: absolute; width: 100px; left: 50px}
|
167 |
+
"""
|
168 |
+
iface = gr.Blocks(css=css)
|
169 |
+
|
170 |
+
with iface:
|
171 |
+
gr.HTML(
|
172 |
+
"""
|
173 |
+
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
174 |
+
<div
|
175 |
+
style="
|
176 |
+
display: inline-flex;
|
177 |
+
align-items: center;
|
178 |
+
gap: 0.8rem;
|
179 |
+
font-size: 1.75rem;
|
180 |
+
"
|
181 |
+
>
|
182 |
+
<h1 style="font-weight: 900; margin-bottom: 7px;">
|
183 |
+
AudioLDM: Text-to-Audio Generation with Latent Diffusion Models
|
184 |
+
</h1>
|
185 |
+
</div>
|
186 |
+
<p style="margin-bottom: 10px; font-size: 94%">
|
187 |
+
<a href="https://arxiv.org/abs/2301.12503">[Paper]</a> <a href="https://audioldm.github.io/">[Project page]</a>
|
188 |
+
</p>
|
189 |
+
</div>
|
190 |
+
"""
|
191 |
+
)
|
192 |
+
with gr.Group():
|
193 |
+
with gr.Box():
|
194 |
+
############# Input
|
195 |
+
textbox = gr.Textbox(value="A hammer is hitting a wooden surface", max_lines=1, label="Input your text here. Please ensure it is descriptive and of moderate length.")
|
196 |
+
|
197 |
+
with gr.Accordion("Click to modify detailed configurations", open=False):
|
198 |
+
seed = gr.Number(value=42, label="Change this value (any integer number) will lead to a different generation result.")
|
199 |
+
duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
|
200 |
+
guidance_scale = gr.Slider(0, 5, value=2.5, step=0.5, label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)")
|
201 |
+
n_candidates = gr.Slider(1, 5, value=3, step=1, label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation")
|
202 |
+
############# Output
|
203 |
+
# outputs=gr.Audio(label="Output", type="numpy")
|
204 |
+
outputs=gr.Video(label="Output")
|
205 |
+
with gr.Group(elem_id="container-advanced-btns"):
|
206 |
+
# advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
|
207 |
+
with gr.Group(elem_id="share-btn-container"):
|
208 |
+
community_icon = gr.HTML(community_icon_html, visible=False)
|
209 |
+
loading_icon = gr.HTML(loading_icon_html, visible=False)
|
210 |
+
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
211 |
+
# outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")]
|
212 |
+
|
213 |
+
btn = gr.Button("Submit").style(full_width=True)
|
214 |
+
btn.click(text2audio, inputs=[textbox, duration, guidance_scale, seed, n_candidates], outputs=[outputs, community_icon, loading_icon, share_button]) # , share_button, community_icon, loading_icon
|
215 |
+
share_button.click(None, [], [], _js=share_js)
|
216 |
+
gr.HTML('''
|
217 |
+
<hr>
|
218 |
+
<div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
|
219 |
+
<p>Model by <a href="https://twitter.com/LiuHaohe" style="text-decoration: underline;" target="_blank">Haohe Liu</a>
|
220 |
+
</p>
|
221 |
+
</div>
|
222 |
+
''')
|
223 |
+
|
224 |
+
with gr.Accordion("Additional information", open=False):
|
225 |
+
gr.HTML(
|
226 |
+
"""
|
227 |
+
<div class="acknowledgments">
|
228 |
+
<p> We build the model with data from <a href="http://research.google.com/audioset/">AudioSet</a>, <a href="https://freesound.org/">Freesound</a> and <a href="https://sound-effects.bbcrewind.co.uk/">BBC Sound Effect library</a>. We share this demo based on the <a href="https://assets.publishing.service.gov.uk/government/uploads/system/uploads/attachment_data/file/375954/Research.pdf">UK copyright exception</a> of data for academic research. </p>
|
229 |
+
<p>This demo is strictly for research demo purpose only. For commercial use please <a href="[email protected]">contact us</a>.</p>
|
230 |
+
</div>
|
231 |
+
"""
|
232 |
+
)
|
233 |
+
|
234 |
+
iface.queue(concurrency_count = 2)
|
235 |
+
iface.launch(debug=True)
|
236 |
+
# iface.launch(debug=True, share=True)
|
audioldm/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .ldm import LatentDiffusion
|
2 |
+
from .utils import seed_everything
|
3 |
+
from .pipeline import *
|
audioldm/audio/__init__.py
ADDED
File without changes
|
audioldm/audio/audio_processing.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import librosa.util as librosa_util
|
4 |
+
from scipy.signal import get_window
|
5 |
+
|
6 |
+
|
7 |
+
def window_sumsquare(
|
8 |
+
window,
|
9 |
+
n_frames,
|
10 |
+
hop_length,
|
11 |
+
win_length,
|
12 |
+
n_fft,
|
13 |
+
dtype=np.float32,
|
14 |
+
norm=None,
|
15 |
+
):
|
16 |
+
"""
|
17 |
+
# from librosa 0.6
|
18 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
19 |
+
|
20 |
+
This is used to estimate modulation effects induced by windowing
|
21 |
+
observations in short-time fourier transforms.
|
22 |
+
|
23 |
+
Parameters
|
24 |
+
----------
|
25 |
+
window : string, tuple, number, callable, or list-like
|
26 |
+
Window specification, as in `get_window`
|
27 |
+
|
28 |
+
n_frames : int > 0
|
29 |
+
The number of analysis frames
|
30 |
+
|
31 |
+
hop_length : int > 0
|
32 |
+
The number of samples to advance between frames
|
33 |
+
|
34 |
+
win_length : [optional]
|
35 |
+
The length of the window function. By default, this matches `n_fft`.
|
36 |
+
|
37 |
+
n_fft : int > 0
|
38 |
+
The length of each analysis frame.
|
39 |
+
|
40 |
+
dtype : np.dtype
|
41 |
+
The data type of the output
|
42 |
+
|
43 |
+
Returns
|
44 |
+
-------
|
45 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
46 |
+
The sum-squared envelope of the window function
|
47 |
+
"""
|
48 |
+
if win_length is None:
|
49 |
+
win_length = n_fft
|
50 |
+
|
51 |
+
n = n_fft + hop_length * (n_frames - 1)
|
52 |
+
x = np.zeros(n, dtype=dtype)
|
53 |
+
|
54 |
+
# Compute the squared window at the desired length
|
55 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
56 |
+
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
|
57 |
+
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
58 |
+
|
59 |
+
# Fill the envelope
|
60 |
+
for i in range(n_frames):
|
61 |
+
sample = i * hop_length
|
62 |
+
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
def griffin_lim(magnitudes, stft_fn, n_iters=30):
|
67 |
+
"""
|
68 |
+
PARAMS
|
69 |
+
------
|
70 |
+
magnitudes: spectrogram magnitudes
|
71 |
+
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
|
72 |
+
"""
|
73 |
+
|
74 |
+
angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
|
75 |
+
angles = angles.astype(np.float32)
|
76 |
+
angles = torch.autograd.Variable(torch.from_numpy(angles))
|
77 |
+
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
78 |
+
|
79 |
+
for i in range(n_iters):
|
80 |
+
_, angles = stft_fn.transform(signal)
|
81 |
+
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
|
82 |
+
return signal
|
83 |
+
|
84 |
+
|
85 |
+
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
|
86 |
+
"""
|
87 |
+
PARAMS
|
88 |
+
------
|
89 |
+
C: compression factor
|
90 |
+
"""
|
91 |
+
return normalize_fun(torch.clamp(x, min=clip_val) * C)
|
92 |
+
|
93 |
+
|
94 |
+
def dynamic_range_decompression(x, C=1):
|
95 |
+
"""
|
96 |
+
PARAMS
|
97 |
+
------
|
98 |
+
C: compression factor used to compress
|
99 |
+
"""
|
100 |
+
return torch.exp(x) / C
|
audioldm/audio/stft.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy.signal import get_window
|
5 |
+
from librosa.util import pad_center, tiny
|
6 |
+
from librosa.filters import mel as librosa_mel_fn
|
7 |
+
|
8 |
+
from audioldm.audio.audio_processing import (
|
9 |
+
dynamic_range_compression,
|
10 |
+
dynamic_range_decompression,
|
11 |
+
window_sumsquare,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
class STFT(torch.nn.Module):
|
16 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
17 |
+
|
18 |
+
def __init__(self, filter_length, hop_length, win_length, window="hann"):
|
19 |
+
super(STFT, self).__init__()
|
20 |
+
self.filter_length = filter_length
|
21 |
+
self.hop_length = hop_length
|
22 |
+
self.win_length = win_length
|
23 |
+
self.window = window
|
24 |
+
self.forward_transform = None
|
25 |
+
scale = self.filter_length / self.hop_length
|
26 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
27 |
+
|
28 |
+
cutoff = int((self.filter_length / 2 + 1))
|
29 |
+
fourier_basis = np.vstack(
|
30 |
+
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
31 |
+
)
|
32 |
+
|
33 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
34 |
+
inverse_basis = torch.FloatTensor(
|
35 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
36 |
+
)
|
37 |
+
|
38 |
+
if window is not None:
|
39 |
+
assert filter_length >= win_length
|
40 |
+
# get window and zero center pad it to filter_length
|
41 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
42 |
+
fft_window = pad_center(fft_window, filter_length)
|
43 |
+
fft_window = torch.from_numpy(fft_window).float()
|
44 |
+
|
45 |
+
# window the bases
|
46 |
+
forward_basis *= fft_window
|
47 |
+
inverse_basis *= fft_window
|
48 |
+
|
49 |
+
self.register_buffer("forward_basis", forward_basis.float())
|
50 |
+
self.register_buffer("inverse_basis", inverse_basis.float())
|
51 |
+
|
52 |
+
def transform(self, input_data):
|
53 |
+
num_batches = input_data.size(0)
|
54 |
+
num_samples = input_data.size(1)
|
55 |
+
|
56 |
+
self.num_samples = num_samples
|
57 |
+
|
58 |
+
# similar to librosa, reflect-pad the input
|
59 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
60 |
+
input_data = F.pad(
|
61 |
+
input_data.unsqueeze(1),
|
62 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
63 |
+
mode="reflect",
|
64 |
+
)
|
65 |
+
input_data = input_data.squeeze(1)
|
66 |
+
|
67 |
+
forward_transform = F.conv1d(
|
68 |
+
input_data,
|
69 |
+
torch.autograd.Variable(self.forward_basis, requires_grad=False),
|
70 |
+
stride=self.hop_length,
|
71 |
+
padding=0,
|
72 |
+
).cpu()
|
73 |
+
|
74 |
+
cutoff = int((self.filter_length / 2) + 1)
|
75 |
+
real_part = forward_transform[:, :cutoff, :]
|
76 |
+
imag_part = forward_transform[:, cutoff:, :]
|
77 |
+
|
78 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
79 |
+
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
|
80 |
+
|
81 |
+
return magnitude, phase
|
82 |
+
|
83 |
+
def inverse(self, magnitude, phase):
|
84 |
+
recombine_magnitude_phase = torch.cat(
|
85 |
+
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
86 |
+
)
|
87 |
+
|
88 |
+
inverse_transform = F.conv_transpose1d(
|
89 |
+
recombine_magnitude_phase,
|
90 |
+
torch.autograd.Variable(self.inverse_basis, requires_grad=False),
|
91 |
+
stride=self.hop_length,
|
92 |
+
padding=0,
|
93 |
+
)
|
94 |
+
|
95 |
+
if self.window is not None:
|
96 |
+
window_sum = window_sumsquare(
|
97 |
+
self.window,
|
98 |
+
magnitude.size(-1),
|
99 |
+
hop_length=self.hop_length,
|
100 |
+
win_length=self.win_length,
|
101 |
+
n_fft=self.filter_length,
|
102 |
+
dtype=np.float32,
|
103 |
+
)
|
104 |
+
# remove modulation effects
|
105 |
+
approx_nonzero_indices = torch.from_numpy(
|
106 |
+
np.where(window_sum > tiny(window_sum))[0]
|
107 |
+
)
|
108 |
+
window_sum = torch.autograd.Variable(
|
109 |
+
torch.from_numpy(window_sum), requires_grad=False
|
110 |
+
)
|
111 |
+
window_sum = window_sum
|
112 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
113 |
+
approx_nonzero_indices
|
114 |
+
]
|
115 |
+
|
116 |
+
# scale by hop ratio
|
117 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
118 |
+
|
119 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
|
120 |
+
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
|
121 |
+
|
122 |
+
return inverse_transform
|
123 |
+
|
124 |
+
def forward(self, input_data):
|
125 |
+
self.magnitude, self.phase = self.transform(input_data)
|
126 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
127 |
+
return reconstruction
|
128 |
+
|
129 |
+
|
130 |
+
class TacotronSTFT(torch.nn.Module):
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
filter_length,
|
134 |
+
hop_length,
|
135 |
+
win_length,
|
136 |
+
n_mel_channels,
|
137 |
+
sampling_rate,
|
138 |
+
mel_fmin,
|
139 |
+
mel_fmax,
|
140 |
+
):
|
141 |
+
super(TacotronSTFT, self).__init__()
|
142 |
+
self.n_mel_channels = n_mel_channels
|
143 |
+
self.sampling_rate = sampling_rate
|
144 |
+
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
145 |
+
mel_basis = librosa_mel_fn(
|
146 |
+
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
|
147 |
+
)
|
148 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
149 |
+
self.register_buffer("mel_basis", mel_basis)
|
150 |
+
|
151 |
+
def spectral_normalize(self, magnitudes, normalize_fun):
|
152 |
+
output = dynamic_range_compression(magnitudes, normalize_fun)
|
153 |
+
return output
|
154 |
+
|
155 |
+
def spectral_de_normalize(self, magnitudes):
|
156 |
+
output = dynamic_range_decompression(magnitudes)
|
157 |
+
return output
|
158 |
+
|
159 |
+
def mel_spectrogram(self, y, normalize_fun=torch.log):
|
160 |
+
"""Computes mel-spectrograms from a batch of waves
|
161 |
+
PARAMS
|
162 |
+
------
|
163 |
+
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
164 |
+
|
165 |
+
RETURNS
|
166 |
+
-------
|
167 |
+
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
168 |
+
"""
|
169 |
+
assert torch.min(y.data) >= -1, torch.min(y.data)
|
170 |
+
assert torch.max(y.data) <= 1, torch.max(y.data)
|
171 |
+
|
172 |
+
magnitudes, phases = self.stft_fn.transform(y)
|
173 |
+
magnitudes = magnitudes.data
|
174 |
+
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
175 |
+
mel_output = self.spectral_normalize(mel_output, normalize_fun)
|
176 |
+
energy = torch.norm(magnitudes, dim=1)
|
177 |
+
|
178 |
+
log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
|
179 |
+
|
180 |
+
return mel_output, log_magnitudes, energy
|
audioldm/audio/tools.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def get_mel_from_wav(audio, _stft):
|
6 |
+
audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
|
7 |
+
audio = torch.autograd.Variable(audio, requires_grad=False)
|
8 |
+
melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
|
9 |
+
melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
|
10 |
+
log_magnitudes_stft = (
|
11 |
+
torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
|
12 |
+
)
|
13 |
+
energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
|
14 |
+
return melspec, log_magnitudes_stft, energy
|
15 |
+
|
16 |
+
|
17 |
+
# def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
|
18 |
+
# mel = torch.stack([mel])
|
19 |
+
# mel_decompress = _stft.spectral_de_normalize(mel)
|
20 |
+
# mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
|
21 |
+
# spec_from_mel_scaling = 1000
|
22 |
+
# spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
|
23 |
+
# spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
|
24 |
+
# spec_from_mel = spec_from_mel * spec_from_mel_scaling
|
25 |
+
|
26 |
+
# audio = griffin_lim(
|
27 |
+
# torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
|
28 |
+
# )
|
29 |
+
|
30 |
+
# audio = audio.squeeze()
|
31 |
+
# audio = audio.cpu().numpy()
|
32 |
+
# audio_path = out_filename
|
33 |
+
# write(audio_path, _stft.sampling_rate, audio)
|
audioldm/clap/__init__.py
ADDED
File without changes
|
audioldm/clap/encoders.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from audioldm.clap.open_clip import create_model
|
4 |
+
from audioldm.clap.training.data import get_audio_features
|
5 |
+
import torchaudio
|
6 |
+
from transformers import RobertaTokenizer
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
pretrained_path="",
|
14 |
+
key="class",
|
15 |
+
sampling_rate=16000,
|
16 |
+
embed_mode="audio",
|
17 |
+
unconditional_prob=0.1,
|
18 |
+
random_mute=False,
|
19 |
+
max_random_mute_portion=0.5,
|
20 |
+
training_mode=True,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.key = key
|
25 |
+
self.device = "cpu"
|
26 |
+
self.precision = "fp32"
|
27 |
+
self.amodel = "HTSAT-tiny" # or 'PANN-14'
|
28 |
+
self.tmodel = "roberta" # the best text encoder in our training
|
29 |
+
self.enable_fusion = False # False if you do not want to use the fusion model
|
30 |
+
self.fusion_type = "aff_2d"
|
31 |
+
self.pretrained = pretrained_path
|
32 |
+
self.embed_mode = embed_mode
|
33 |
+
self.embed_mode_orig = embed_mode
|
34 |
+
self.sampling_rate = sampling_rate
|
35 |
+
self.unconditional_prob = unconditional_prob
|
36 |
+
self.random_mute = random_mute
|
37 |
+
self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
|
38 |
+
self.max_random_mute_portion = max_random_mute_portion
|
39 |
+
self.training_mode = training_mode
|
40 |
+
self.model, self.model_cfg = create_model(
|
41 |
+
self.amodel,
|
42 |
+
self.tmodel,
|
43 |
+
self.pretrained,
|
44 |
+
precision=self.precision,
|
45 |
+
device=self.device,
|
46 |
+
enable_fusion=self.enable_fusion,
|
47 |
+
fusion_type=self.fusion_type,
|
48 |
+
)
|
49 |
+
for p in self.model.parameters():
|
50 |
+
p.requires_grad = False
|
51 |
+
|
52 |
+
self.model.eval()
|
53 |
+
|
54 |
+
def get_unconditional_condition(self, batchsize):
|
55 |
+
self.unconditional_token = self.model.get_text_embedding(
|
56 |
+
self.tokenizer(["", ""])
|
57 |
+
)[0:1]
|
58 |
+
return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
|
59 |
+
|
60 |
+
def batch_to_list(self, batch):
|
61 |
+
ret = []
|
62 |
+
for i in range(batch.size(0)):
|
63 |
+
ret.append(batch[i])
|
64 |
+
return ret
|
65 |
+
|
66 |
+
def make_decision(self, probability):
|
67 |
+
if float(torch.rand(1)) < probability:
|
68 |
+
return True
|
69 |
+
else:
|
70 |
+
return False
|
71 |
+
|
72 |
+
def random_uniform(self, start, end):
|
73 |
+
val = torch.rand(1).item()
|
74 |
+
return start + (end - start) * val
|
75 |
+
|
76 |
+
def _random_mute(self, waveform):
|
77 |
+
# waveform: [bs, t-steps]
|
78 |
+
t_steps = waveform.size(-1)
|
79 |
+
for i in range(waveform.size(0)):
|
80 |
+
mute_size = int(
|
81 |
+
self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
|
82 |
+
)
|
83 |
+
mute_start = int(self.random_uniform(0, t_steps - mute_size))
|
84 |
+
waveform[i, mute_start : mute_start + mute_size] = 0
|
85 |
+
return waveform
|
86 |
+
|
87 |
+
def cos_similarity(self, waveform, text):
|
88 |
+
# waveform: [bs, t_steps]
|
89 |
+
with torch.no_grad():
|
90 |
+
self.embed_mode = "audio"
|
91 |
+
audio_emb = self(waveform.cuda())
|
92 |
+
self.embed_mode = "text"
|
93 |
+
text_emb = self(text)
|
94 |
+
similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
|
95 |
+
return similarity.squeeze()
|
96 |
+
|
97 |
+
def forward(self, batch, key=None):
|
98 |
+
# If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
|
99 |
+
# If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
|
100 |
+
if self.model.training == True and not self.training_mode:
|
101 |
+
print(
|
102 |
+
"The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
|
103 |
+
)
|
104 |
+
self.model, self.model_cfg = create_model(
|
105 |
+
self.amodel,
|
106 |
+
self.tmodel,
|
107 |
+
self.pretrained,
|
108 |
+
precision=self.precision,
|
109 |
+
device="cuda",
|
110 |
+
enable_fusion=self.enable_fusion,
|
111 |
+
fusion_type=self.fusion_type,
|
112 |
+
)
|
113 |
+
for p in self.model.parameters():
|
114 |
+
p.requires_grad = False
|
115 |
+
self.model.eval()
|
116 |
+
|
117 |
+
# the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
|
118 |
+
if self.embed_mode == "audio":
|
119 |
+
with torch.no_grad():
|
120 |
+
audio_dict_list = []
|
121 |
+
assert (
|
122 |
+
self.sampling_rate == 16000
|
123 |
+
), "We only support 16000 sampling rate"
|
124 |
+
if self.random_mute:
|
125 |
+
batch = self._random_mute(batch)
|
126 |
+
# batch: [bs, 1, t-samples]
|
127 |
+
batch = torchaudio.functional.resample(
|
128 |
+
batch, orig_freq=self.sampling_rate, new_freq=48000
|
129 |
+
)
|
130 |
+
for waveform in self.batch_to_list(batch):
|
131 |
+
audio_dict = {}
|
132 |
+
audio_dict = get_audio_features(
|
133 |
+
audio_dict,
|
134 |
+
waveform,
|
135 |
+
480000,
|
136 |
+
data_truncating="fusion",
|
137 |
+
data_filling="repeatpad",
|
138 |
+
audio_cfg=self.model_cfg["audio_cfg"],
|
139 |
+
)
|
140 |
+
audio_dict_list.append(audio_dict)
|
141 |
+
# [bs, 512]
|
142 |
+
embed = self.model.get_audio_embedding(audio_dict_list)
|
143 |
+
elif self.embed_mode == "text":
|
144 |
+
with torch.no_grad():
|
145 |
+
# the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
|
146 |
+
text_data = self.tokenizer(batch)
|
147 |
+
embed = self.model.get_text_embedding(text_data)
|
148 |
+
|
149 |
+
embed = embed.unsqueeze(1)
|
150 |
+
self.unconditional_token = self.model.get_text_embedding(
|
151 |
+
self.tokenizer(["", ""])
|
152 |
+
)[0:1]
|
153 |
+
|
154 |
+
for i in range(embed.size(0)):
|
155 |
+
if self.make_decision(self.unconditional_prob):
|
156 |
+
embed[i] = self.unconditional_token
|
157 |
+
|
158 |
+
# [bs, 1, 512]
|
159 |
+
return embed.detach()
|
160 |
+
|
161 |
+
def tokenizer(self, text):
|
162 |
+
result = self.tokenize(
|
163 |
+
text,
|
164 |
+
padding="max_length",
|
165 |
+
truncation=True,
|
166 |
+
max_length=77,
|
167 |
+
return_tensors="pt",
|
168 |
+
)
|
169 |
+
return {k: v.squeeze(0) for k, v in result.items()}
|
audioldm/clap/open_clip/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .factory import (
|
2 |
+
list_models,
|
3 |
+
create_model,
|
4 |
+
create_model_and_transforms,
|
5 |
+
add_model_config,
|
6 |
+
)
|
7 |
+
from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
|
8 |
+
from .model import (
|
9 |
+
CLAP,
|
10 |
+
CLAPTextCfg,
|
11 |
+
CLAPVisionCfg,
|
12 |
+
CLAPAudioCfp,
|
13 |
+
convert_weights_to_fp16,
|
14 |
+
trace_model,
|
15 |
+
)
|
16 |
+
from .openai import load_openai_model, list_openai_models
|
17 |
+
from .pretrained import (
|
18 |
+
list_pretrained,
|
19 |
+
list_pretrained_tag_models,
|
20 |
+
list_pretrained_model_tags,
|
21 |
+
get_pretrained_url,
|
22 |
+
download_pretrained,
|
23 |
+
)
|
24 |
+
from .tokenizer import SimpleTokenizer, tokenize
|
25 |
+
from .transform import image_transform
|
audioldm/clap/open_clip/bert.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertTokenizer, BertModel
|
2 |
+
|
3 |
+
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
4 |
+
model = BertModel.from_pretrained("bert-base-uncased")
|
5 |
+
text = "Replace me by any text you'd like."
|
6 |
+
|
7 |
+
|
8 |
+
def bert_embeddings(text):
|
9 |
+
# text = "Replace me by any text you'd like."
|
10 |
+
encoded_input = tokenizer(text, return_tensors="pt")
|
11 |
+
output = model(**encoded_input)
|
12 |
+
return output
|
13 |
+
|
14 |
+
|
15 |
+
from transformers import RobertaTokenizer, RobertaModel
|
16 |
+
|
17 |
+
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
18 |
+
model = RobertaModel.from_pretrained("roberta-base")
|
19 |
+
text = "Replace me by any text you'd like."
|
20 |
+
|
21 |
+
|
22 |
+
def Roberta_embeddings(text):
|
23 |
+
# text = "Replace me by any text you'd like."
|
24 |
+
encoded_input = tokenizer(text, return_tensors="pt")
|
25 |
+
output = model(**encoded_input)
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
from transformers import BartTokenizer, BartModel
|
30 |
+
|
31 |
+
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
32 |
+
model = BartModel.from_pretrained("facebook/bart-base")
|
33 |
+
text = "Replace me by any text you'd like."
|
34 |
+
|
35 |
+
|
36 |
+
def bart_embeddings(text):
|
37 |
+
# text = "Replace me by any text you'd like."
|
38 |
+
encoded_input = tokenizer(text, return_tensors="pt")
|
39 |
+
output = model(**encoded_input)
|
40 |
+
return output
|
audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
audioldm/clap/open_clip/factory.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .model import CLAP, convert_weights_to_fp16
|
12 |
+
from .openai import load_openai_model
|
13 |
+
from .pretrained import get_pretrained_url, download_pretrained
|
14 |
+
from .transform import image_transform
|
15 |
+
|
16 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
17 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
18 |
+
|
19 |
+
|
20 |
+
def _natural_key(string_):
|
21 |
+
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
|
22 |
+
|
23 |
+
|
24 |
+
def _rescan_model_configs():
|
25 |
+
global _MODEL_CONFIGS
|
26 |
+
|
27 |
+
config_ext = (".json",)
|
28 |
+
config_files = []
|
29 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
30 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
31 |
+
config_files.append(config_path)
|
32 |
+
elif config_path.is_dir():
|
33 |
+
for ext in config_ext:
|
34 |
+
config_files.extend(config_path.glob(f"*{ext}"))
|
35 |
+
|
36 |
+
for cf in config_files:
|
37 |
+
if os.path.basename(cf)[0] == ".":
|
38 |
+
continue # Ignore hidden files
|
39 |
+
|
40 |
+
with open(cf, "r") as f:
|
41 |
+
model_cfg = json.load(f)
|
42 |
+
if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
|
43 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
44 |
+
|
45 |
+
_MODEL_CONFIGS = {
|
46 |
+
k: v
|
47 |
+
for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
_rescan_model_configs() # initial populate of model config registry
|
52 |
+
|
53 |
+
|
54 |
+
def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
|
55 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
56 |
+
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
57 |
+
state_dict = checkpoint["state_dict"]
|
58 |
+
else:
|
59 |
+
state_dict = checkpoint
|
60 |
+
if skip_params:
|
61 |
+
if next(iter(state_dict.items()))[0].startswith("module"):
|
62 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
63 |
+
# for k in state_dict:
|
64 |
+
# if k.startswith('transformer'):
|
65 |
+
# v = state_dict.pop(k)
|
66 |
+
# state_dict['text_branch.' + k[12:]] = v
|
67 |
+
return state_dict
|
68 |
+
|
69 |
+
|
70 |
+
def create_model(
|
71 |
+
amodel_name: str,
|
72 |
+
tmodel_name: str,
|
73 |
+
pretrained: str = "",
|
74 |
+
precision: str = "fp32",
|
75 |
+
device: torch.device = torch.device("cpu"),
|
76 |
+
jit: bool = False,
|
77 |
+
force_quick_gelu: bool = False,
|
78 |
+
openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
|
79 |
+
skip_params=True,
|
80 |
+
pretrained_audio: str = "",
|
81 |
+
pretrained_text: str = "",
|
82 |
+
enable_fusion: bool = False,
|
83 |
+
fusion_type: str = "None"
|
84 |
+
# pretrained_image: bool = False,
|
85 |
+
):
|
86 |
+
amodel_name = amodel_name.replace(
|
87 |
+
"/", "-"
|
88 |
+
) # for callers using old naming with / in ViT names
|
89 |
+
pretrained_orig = pretrained
|
90 |
+
pretrained = pretrained.lower()
|
91 |
+
if pretrained == "openai":
|
92 |
+
if amodel_name in _MODEL_CONFIGS:
|
93 |
+
logging.info(f"Loading {amodel_name} model config.")
|
94 |
+
model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
|
95 |
+
else:
|
96 |
+
logging.error(
|
97 |
+
f"Model config for {amodel_name} not found; available models {list_models()}."
|
98 |
+
)
|
99 |
+
raise RuntimeError(f"Model config for {amodel_name} not found.")
|
100 |
+
|
101 |
+
logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
|
102 |
+
# Hard Code in model name
|
103 |
+
model_cfg["text_cfg"]["model_type"] = tmodel_name
|
104 |
+
model = load_openai_model(
|
105 |
+
"ViT-B-16",
|
106 |
+
model_cfg,
|
107 |
+
device=device,
|
108 |
+
jit=jit,
|
109 |
+
cache_dir=openai_model_cache_dir,
|
110 |
+
enable_fusion=enable_fusion,
|
111 |
+
fusion_type=fusion_type,
|
112 |
+
)
|
113 |
+
# See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
|
114 |
+
if precision == "amp" or precision == "fp32":
|
115 |
+
model = model.float()
|
116 |
+
else:
|
117 |
+
if amodel_name in _MODEL_CONFIGS:
|
118 |
+
logging.info(f"Loading {amodel_name} model config.")
|
119 |
+
model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
|
120 |
+
else:
|
121 |
+
logging.error(
|
122 |
+
f"Model config for {amodel_name} not found; available models {list_models()}."
|
123 |
+
)
|
124 |
+
raise RuntimeError(f"Model config for {amodel_name} not found.")
|
125 |
+
|
126 |
+
if force_quick_gelu:
|
127 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
128 |
+
model_cfg["quick_gelu"] = True
|
129 |
+
|
130 |
+
# if pretrained_image:
|
131 |
+
# if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
|
132 |
+
# # pretrained weight loading for timm models set via vision_cfg
|
133 |
+
# model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
134 |
+
# else:
|
135 |
+
# assert False, 'pretrained image towers currently only supported for timm models'
|
136 |
+
model_cfg["text_cfg"]["model_type"] = tmodel_name
|
137 |
+
model_cfg["enable_fusion"] = enable_fusion
|
138 |
+
model_cfg["fusion_type"] = fusion_type
|
139 |
+
model = CLAP(**model_cfg)
|
140 |
+
|
141 |
+
if pretrained:
|
142 |
+
checkpoint_path = ""
|
143 |
+
url = get_pretrained_url(amodel_name, pretrained)
|
144 |
+
if url:
|
145 |
+
checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
|
146 |
+
elif os.path.exists(pretrained_orig):
|
147 |
+
checkpoint_path = pretrained_orig
|
148 |
+
if checkpoint_path:
|
149 |
+
logging.info(
|
150 |
+
f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
|
151 |
+
)
|
152 |
+
ckpt = load_state_dict(checkpoint_path, skip_params=True)
|
153 |
+
model.load_state_dict(ckpt)
|
154 |
+
param_names = [n for n, p in model.named_parameters()]
|
155 |
+
# for n in param_names:
|
156 |
+
# print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
|
157 |
+
else:
|
158 |
+
logging.warning(
|
159 |
+
f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
|
160 |
+
)
|
161 |
+
raise RuntimeError(
|
162 |
+
f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
|
163 |
+
)
|
164 |
+
|
165 |
+
if pretrained_audio:
|
166 |
+
if amodel_name.startswith("PANN"):
|
167 |
+
if "Cnn14_mAP" in pretrained_audio: # official checkpoint
|
168 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
169 |
+
audio_ckpt = audio_ckpt["model"]
|
170 |
+
keys = list(audio_ckpt.keys())
|
171 |
+
for key in keys:
|
172 |
+
if (
|
173 |
+
"spectrogram_extractor" not in key
|
174 |
+
and "logmel_extractor" not in key
|
175 |
+
):
|
176 |
+
v = audio_ckpt.pop(key)
|
177 |
+
audio_ckpt["audio_branch." + key] = v
|
178 |
+
elif os.path.basename(pretrained_audio).startswith(
|
179 |
+
"PANN"
|
180 |
+
): # checkpoint trained via HTSAT codebase
|
181 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
182 |
+
audio_ckpt = audio_ckpt["state_dict"]
|
183 |
+
keys = list(audio_ckpt.keys())
|
184 |
+
for key in keys:
|
185 |
+
if key.startswith("sed_model"):
|
186 |
+
v = audio_ckpt.pop(key)
|
187 |
+
audio_ckpt["audio_branch." + key[10:]] = v
|
188 |
+
elif os.path.basename(pretrained_audio).startswith(
|
189 |
+
"finetuned"
|
190 |
+
): # checkpoint trained via linear probe codebase
|
191 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
192 |
+
else:
|
193 |
+
raise ValueError("Unknown audio checkpoint")
|
194 |
+
elif amodel_name.startswith("HTSAT"):
|
195 |
+
if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
|
196 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
197 |
+
audio_ckpt = audio_ckpt["state_dict"]
|
198 |
+
keys = list(audio_ckpt.keys())
|
199 |
+
for key in keys:
|
200 |
+
if key.startswith("sed_model") and (
|
201 |
+
"spectrogram_extractor" not in key
|
202 |
+
and "logmel_extractor" not in key
|
203 |
+
):
|
204 |
+
v = audio_ckpt.pop(key)
|
205 |
+
audio_ckpt["audio_branch." + key[10:]] = v
|
206 |
+
elif os.path.basename(pretrained_audio).startswith(
|
207 |
+
"HTSAT"
|
208 |
+
): # checkpoint trained via HTSAT codebase
|
209 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
210 |
+
audio_ckpt = audio_ckpt["state_dict"]
|
211 |
+
keys = list(audio_ckpt.keys())
|
212 |
+
for key in keys:
|
213 |
+
if key.startswith("sed_model"):
|
214 |
+
v = audio_ckpt.pop(key)
|
215 |
+
audio_ckpt["audio_branch." + key[10:]] = v
|
216 |
+
elif os.path.basename(pretrained_audio).startswith(
|
217 |
+
"finetuned"
|
218 |
+
): # checkpoint trained via linear probe codebase
|
219 |
+
audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
|
220 |
+
else:
|
221 |
+
raise ValueError("Unknown audio checkpoint")
|
222 |
+
else:
|
223 |
+
raise f"this audio encoder pretrained checkpoint is not support"
|
224 |
+
|
225 |
+
model.load_state_dict(audio_ckpt, strict=False)
|
226 |
+
logging.info(
|
227 |
+
f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
|
228 |
+
)
|
229 |
+
param_names = [n for n, p in model.named_parameters()]
|
230 |
+
for n in param_names:
|
231 |
+
print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
|
232 |
+
|
233 |
+
model.to(device=device)
|
234 |
+
if precision == "fp16":
|
235 |
+
assert device.type != "cpu"
|
236 |
+
convert_weights_to_fp16(model)
|
237 |
+
|
238 |
+
if jit:
|
239 |
+
model = torch.jit.script(model)
|
240 |
+
|
241 |
+
return model, model_cfg
|
242 |
+
|
243 |
+
|
244 |
+
def create_model_and_transforms(
|
245 |
+
model_name: str,
|
246 |
+
pretrained: str = "",
|
247 |
+
precision: str = "fp32",
|
248 |
+
device: torch.device = torch.device("cpu"),
|
249 |
+
jit: bool = False,
|
250 |
+
force_quick_gelu: bool = False,
|
251 |
+
# pretrained_image: bool = False,
|
252 |
+
):
|
253 |
+
model = create_model(
|
254 |
+
model_name,
|
255 |
+
pretrained,
|
256 |
+
precision,
|
257 |
+
device,
|
258 |
+
jit,
|
259 |
+
force_quick_gelu=force_quick_gelu,
|
260 |
+
# pretrained_image=pretrained_image
|
261 |
+
)
|
262 |
+
preprocess_train = image_transform(model.visual.image_size, is_train=True)
|
263 |
+
preprocess_val = image_transform(model.visual.image_size, is_train=False)
|
264 |
+
return model, preprocess_train, preprocess_val
|
265 |
+
|
266 |
+
|
267 |
+
def list_models():
|
268 |
+
"""enumerate available model architectures based on config files"""
|
269 |
+
return list(_MODEL_CONFIGS.keys())
|
270 |
+
|
271 |
+
|
272 |
+
def add_model_config(path):
|
273 |
+
"""add model config path or file and update registry"""
|
274 |
+
if not isinstance(path, Path):
|
275 |
+
path = Path(path)
|
276 |
+
_MODEL_CONFIG_PATHS.append(path)
|
277 |
+
_rescan_model_configs()
|
audioldm/clap/open_clip/feature_fusion.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Feature Fusion for Varible-Length Data Processing
|
3 |
+
AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
|
4 |
+
According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
|
11 |
+
class DAF(nn.Module):
|
12 |
+
"""
|
13 |
+
直接相加 DirectAddFuse
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
super(DAF, self).__init__()
|
18 |
+
|
19 |
+
def forward(self, x, residual):
|
20 |
+
return x + residual
|
21 |
+
|
22 |
+
|
23 |
+
class iAFF(nn.Module):
|
24 |
+
"""
|
25 |
+
多特征融合 iAFF
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, channels=64, r=4, type="2D"):
|
29 |
+
super(iAFF, self).__init__()
|
30 |
+
inter_channels = int(channels // r)
|
31 |
+
|
32 |
+
if type == "1D":
|
33 |
+
# 本地注意力
|
34 |
+
self.local_att = nn.Sequential(
|
35 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
36 |
+
nn.BatchNorm1d(inter_channels),
|
37 |
+
nn.ReLU(inplace=True),
|
38 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
39 |
+
nn.BatchNorm1d(channels),
|
40 |
+
)
|
41 |
+
|
42 |
+
# 全局注意力
|
43 |
+
self.global_att = nn.Sequential(
|
44 |
+
nn.AdaptiveAvgPool1d(1),
|
45 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
46 |
+
nn.BatchNorm1d(inter_channels),
|
47 |
+
nn.ReLU(inplace=True),
|
48 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
49 |
+
nn.BatchNorm1d(channels),
|
50 |
+
)
|
51 |
+
|
52 |
+
# 第二次本地注意力
|
53 |
+
self.local_att2 = nn.Sequential(
|
54 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
55 |
+
nn.BatchNorm1d(inter_channels),
|
56 |
+
nn.ReLU(inplace=True),
|
57 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
58 |
+
nn.BatchNorm1d(channels),
|
59 |
+
)
|
60 |
+
# 第二次全局注意力
|
61 |
+
self.global_att2 = nn.Sequential(
|
62 |
+
nn.AdaptiveAvgPool1d(1),
|
63 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
64 |
+
nn.BatchNorm1d(inter_channels),
|
65 |
+
nn.ReLU(inplace=True),
|
66 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
67 |
+
nn.BatchNorm1d(channels),
|
68 |
+
)
|
69 |
+
elif type == "2D":
|
70 |
+
# 本地注意力
|
71 |
+
self.local_att = nn.Sequential(
|
72 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
73 |
+
nn.BatchNorm2d(inter_channels),
|
74 |
+
nn.ReLU(inplace=True),
|
75 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
76 |
+
nn.BatchNorm2d(channels),
|
77 |
+
)
|
78 |
+
|
79 |
+
# 全局注意力
|
80 |
+
self.global_att = nn.Sequential(
|
81 |
+
nn.AdaptiveAvgPool2d(1),
|
82 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
83 |
+
nn.BatchNorm2d(inter_channels),
|
84 |
+
nn.ReLU(inplace=True),
|
85 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
86 |
+
nn.BatchNorm2d(channels),
|
87 |
+
)
|
88 |
+
|
89 |
+
# 第二次本地注意力
|
90 |
+
self.local_att2 = nn.Sequential(
|
91 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
92 |
+
nn.BatchNorm2d(inter_channels),
|
93 |
+
nn.ReLU(inplace=True),
|
94 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
95 |
+
nn.BatchNorm2d(channels),
|
96 |
+
)
|
97 |
+
# 第二次全局注意力
|
98 |
+
self.global_att2 = nn.Sequential(
|
99 |
+
nn.AdaptiveAvgPool2d(1),
|
100 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
101 |
+
nn.BatchNorm2d(inter_channels),
|
102 |
+
nn.ReLU(inplace=True),
|
103 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
104 |
+
nn.BatchNorm2d(channels),
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
raise f"the type is not supported"
|
108 |
+
|
109 |
+
self.sigmoid = nn.Sigmoid()
|
110 |
+
|
111 |
+
def forward(self, x, residual):
|
112 |
+
flag = False
|
113 |
+
xa = x + residual
|
114 |
+
if xa.size(0) == 1:
|
115 |
+
xa = torch.cat([xa, xa], dim=0)
|
116 |
+
flag = True
|
117 |
+
xl = self.local_att(xa)
|
118 |
+
xg = self.global_att(xa)
|
119 |
+
xlg = xl + xg
|
120 |
+
wei = self.sigmoid(xlg)
|
121 |
+
xi = x * wei + residual * (1 - wei)
|
122 |
+
|
123 |
+
xl2 = self.local_att2(xi)
|
124 |
+
xg2 = self.global_att(xi)
|
125 |
+
xlg2 = xl2 + xg2
|
126 |
+
wei2 = self.sigmoid(xlg2)
|
127 |
+
xo = x * wei2 + residual * (1 - wei2)
|
128 |
+
if flag:
|
129 |
+
xo = xo[0].unsqueeze(0)
|
130 |
+
return xo
|
131 |
+
|
132 |
+
|
133 |
+
class AFF(nn.Module):
|
134 |
+
"""
|
135 |
+
多特征融合 AFF
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(self, channels=64, r=4, type="2D"):
|
139 |
+
super(AFF, self).__init__()
|
140 |
+
inter_channels = int(channels // r)
|
141 |
+
|
142 |
+
if type == "1D":
|
143 |
+
self.local_att = nn.Sequential(
|
144 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
145 |
+
nn.BatchNorm1d(inter_channels),
|
146 |
+
nn.ReLU(inplace=True),
|
147 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
148 |
+
nn.BatchNorm1d(channels),
|
149 |
+
)
|
150 |
+
self.global_att = nn.Sequential(
|
151 |
+
nn.AdaptiveAvgPool1d(1),
|
152 |
+
nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
153 |
+
nn.BatchNorm1d(inter_channels),
|
154 |
+
nn.ReLU(inplace=True),
|
155 |
+
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
156 |
+
nn.BatchNorm1d(channels),
|
157 |
+
)
|
158 |
+
elif type == "2D":
|
159 |
+
self.local_att = nn.Sequential(
|
160 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
161 |
+
nn.BatchNorm2d(inter_channels),
|
162 |
+
nn.ReLU(inplace=True),
|
163 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
164 |
+
nn.BatchNorm2d(channels),
|
165 |
+
)
|
166 |
+
self.global_att = nn.Sequential(
|
167 |
+
nn.AdaptiveAvgPool2d(1),
|
168 |
+
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
|
169 |
+
nn.BatchNorm2d(inter_channels),
|
170 |
+
nn.ReLU(inplace=True),
|
171 |
+
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
172 |
+
nn.BatchNorm2d(channels),
|
173 |
+
)
|
174 |
+
else:
|
175 |
+
raise f"the type is not supported."
|
176 |
+
|
177 |
+
self.sigmoid = nn.Sigmoid()
|
178 |
+
|
179 |
+
def forward(self, x, residual):
|
180 |
+
flag = False
|
181 |
+
xa = x + residual
|
182 |
+
if xa.size(0) == 1:
|
183 |
+
xa = torch.cat([xa, xa], dim=0)
|
184 |
+
flag = True
|
185 |
+
xl = self.local_att(xa)
|
186 |
+
xg = self.global_att(xa)
|
187 |
+
xlg = xl + xg
|
188 |
+
wei = self.sigmoid(xlg)
|
189 |
+
xo = 2 * x * wei + 2 * residual * (1 - wei)
|
190 |
+
if flag:
|
191 |
+
xo = xo[0].unsqueeze(0)
|
192 |
+
return xo
|
audioldm/clap/open_clip/htsat.py
ADDED
@@ -0,0 +1,1308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ke Chen
|
2 | |
3 |
+
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
|
4 |
+
# Some layers designed on the model
|
5 |
+
# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
|
6 |
+
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from itertools import repeat
|
12 |
+
import collections.abc
|
13 |
+
import math
|
14 |
+
import warnings
|
15 |
+
|
16 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
17 |
+
import torch.utils.checkpoint as checkpoint
|
18 |
+
|
19 |
+
import random
|
20 |
+
|
21 |
+
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
|
22 |
+
from torchlibrosa.augmentation import SpecAugmentation
|
23 |
+
|
24 |
+
from itertools import repeat
|
25 |
+
from .utils import do_mixup, interpolate
|
26 |
+
|
27 |
+
from .feature_fusion import iAFF, AFF, DAF
|
28 |
+
|
29 |
+
# from PyTorch internals
|
30 |
+
def _ntuple(n):
|
31 |
+
def parse(x):
|
32 |
+
if isinstance(x, collections.abc.Iterable):
|
33 |
+
return x
|
34 |
+
return tuple(repeat(x, n))
|
35 |
+
|
36 |
+
return parse
|
37 |
+
|
38 |
+
|
39 |
+
to_1tuple = _ntuple(1)
|
40 |
+
to_2tuple = _ntuple(2)
|
41 |
+
to_3tuple = _ntuple(3)
|
42 |
+
to_4tuple = _ntuple(4)
|
43 |
+
to_ntuple = _ntuple
|
44 |
+
|
45 |
+
|
46 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
47 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
48 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
49 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
50 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
51 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
52 |
+
'survival rate' as the argument.
|
53 |
+
"""
|
54 |
+
if drop_prob == 0.0 or not training:
|
55 |
+
return x
|
56 |
+
keep_prob = 1 - drop_prob
|
57 |
+
shape = (x.shape[0],) + (1,) * (
|
58 |
+
x.ndim - 1
|
59 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
60 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
61 |
+
random_tensor.floor_() # binarize
|
62 |
+
output = x.div(keep_prob) * random_tensor
|
63 |
+
return output
|
64 |
+
|
65 |
+
|
66 |
+
class DropPath(nn.Module):
|
67 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
68 |
+
|
69 |
+
def __init__(self, drop_prob=None):
|
70 |
+
super(DropPath, self).__init__()
|
71 |
+
self.drop_prob = drop_prob
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
return drop_path(x, self.drop_prob, self.training)
|
75 |
+
|
76 |
+
|
77 |
+
class PatchEmbed(nn.Module):
|
78 |
+
"""2D Image to Patch Embedding"""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
img_size=224,
|
83 |
+
patch_size=16,
|
84 |
+
in_chans=3,
|
85 |
+
embed_dim=768,
|
86 |
+
norm_layer=None,
|
87 |
+
flatten=True,
|
88 |
+
patch_stride=16,
|
89 |
+
enable_fusion=False,
|
90 |
+
fusion_type="None",
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
img_size = to_2tuple(img_size)
|
94 |
+
patch_size = to_2tuple(patch_size)
|
95 |
+
patch_stride = to_2tuple(patch_stride)
|
96 |
+
self.img_size = img_size
|
97 |
+
self.patch_size = patch_size
|
98 |
+
self.patch_stride = patch_stride
|
99 |
+
self.grid_size = (
|
100 |
+
img_size[0] // patch_stride[0],
|
101 |
+
img_size[1] // patch_stride[1],
|
102 |
+
)
|
103 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
104 |
+
self.flatten = flatten
|
105 |
+
self.in_chans = in_chans
|
106 |
+
self.embed_dim = embed_dim
|
107 |
+
|
108 |
+
self.enable_fusion = enable_fusion
|
109 |
+
self.fusion_type = fusion_type
|
110 |
+
|
111 |
+
padding = (
|
112 |
+
(patch_size[0] - patch_stride[0]) // 2,
|
113 |
+
(patch_size[1] - patch_stride[1]) // 2,
|
114 |
+
)
|
115 |
+
|
116 |
+
if (self.enable_fusion) and (self.fusion_type == "channel_map"):
|
117 |
+
self.proj = nn.Conv2d(
|
118 |
+
in_chans * 4,
|
119 |
+
embed_dim,
|
120 |
+
kernel_size=patch_size,
|
121 |
+
stride=patch_stride,
|
122 |
+
padding=padding,
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
self.proj = nn.Conv2d(
|
126 |
+
in_chans,
|
127 |
+
embed_dim,
|
128 |
+
kernel_size=patch_size,
|
129 |
+
stride=patch_stride,
|
130 |
+
padding=padding,
|
131 |
+
)
|
132 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
133 |
+
|
134 |
+
if (self.enable_fusion) and (
|
135 |
+
self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
|
136 |
+
):
|
137 |
+
self.mel_conv2d = nn.Conv2d(
|
138 |
+
in_chans,
|
139 |
+
embed_dim,
|
140 |
+
kernel_size=(patch_size[0], patch_size[1] * 3),
|
141 |
+
stride=(patch_stride[0], patch_stride[1] * 3),
|
142 |
+
padding=padding,
|
143 |
+
)
|
144 |
+
if self.fusion_type == "daf_2d":
|
145 |
+
self.fusion_model = DAF()
|
146 |
+
elif self.fusion_type == "aff_2d":
|
147 |
+
self.fusion_model = AFF(channels=embed_dim, type="2D")
|
148 |
+
elif self.fusion_type == "iaff_2d":
|
149 |
+
self.fusion_model = iAFF(channels=embed_dim, type="2D")
|
150 |
+
|
151 |
+
def forward(self, x, longer_idx=None):
|
152 |
+
if (self.enable_fusion) and (
|
153 |
+
self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
|
154 |
+
):
|
155 |
+
global_x = x[:, 0:1, :, :]
|
156 |
+
|
157 |
+
# global processing
|
158 |
+
B, C, H, W = global_x.shape
|
159 |
+
assert (
|
160 |
+
H == self.img_size[0] and W == self.img_size[1]
|
161 |
+
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
162 |
+
global_x = self.proj(global_x)
|
163 |
+
TW = global_x.size(-1)
|
164 |
+
if len(longer_idx) > 0:
|
165 |
+
# local processing
|
166 |
+
local_x = x[longer_idx, 1:, :, :].contiguous()
|
167 |
+
B, C, H, W = local_x.shape
|
168 |
+
local_x = local_x.view(B * C, 1, H, W)
|
169 |
+
local_x = self.mel_conv2d(local_x)
|
170 |
+
local_x = local_x.view(
|
171 |
+
B, C, local_x.size(1), local_x.size(2), local_x.size(3)
|
172 |
+
)
|
173 |
+
local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
|
174 |
+
TB, TC, TH, _ = local_x.size()
|
175 |
+
if local_x.size(-1) < TW:
|
176 |
+
local_x = torch.cat(
|
177 |
+
[
|
178 |
+
local_x,
|
179 |
+
torch.zeros(
|
180 |
+
(TB, TC, TH, TW - local_x.size(-1)),
|
181 |
+
device=global_x.device,
|
182 |
+
),
|
183 |
+
],
|
184 |
+
dim=-1,
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
local_x = local_x[:, :, :, :TW]
|
188 |
+
|
189 |
+
global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
|
190 |
+
x = global_x
|
191 |
+
else:
|
192 |
+
B, C, H, W = x.shape
|
193 |
+
assert (
|
194 |
+
H == self.img_size[0] and W == self.img_size[1]
|
195 |
+
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
196 |
+
x = self.proj(x)
|
197 |
+
|
198 |
+
if self.flatten:
|
199 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
200 |
+
x = self.norm(x)
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
class Mlp(nn.Module):
|
205 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
206 |
+
|
207 |
+
def __init__(
|
208 |
+
self,
|
209 |
+
in_features,
|
210 |
+
hidden_features=None,
|
211 |
+
out_features=None,
|
212 |
+
act_layer=nn.GELU,
|
213 |
+
drop=0.0,
|
214 |
+
):
|
215 |
+
super().__init__()
|
216 |
+
out_features = out_features or in_features
|
217 |
+
hidden_features = hidden_features or in_features
|
218 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
219 |
+
self.act = act_layer()
|
220 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
221 |
+
self.drop = nn.Dropout(drop)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
x = self.fc1(x)
|
225 |
+
x = self.act(x)
|
226 |
+
x = self.drop(x)
|
227 |
+
x = self.fc2(x)
|
228 |
+
x = self.drop(x)
|
229 |
+
return x
|
230 |
+
|
231 |
+
|
232 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
233 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
234 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
235 |
+
def norm_cdf(x):
|
236 |
+
# Computes standard normal cumulative distribution function
|
237 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
238 |
+
|
239 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
240 |
+
warnings.warn(
|
241 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
242 |
+
"The distribution of values may be incorrect.",
|
243 |
+
stacklevel=2,
|
244 |
+
)
|
245 |
+
|
246 |
+
with torch.no_grad():
|
247 |
+
# Values are generated by using a truncated uniform distribution and
|
248 |
+
# then using the inverse CDF for the normal distribution.
|
249 |
+
# Get upper and lower cdf values
|
250 |
+
l = norm_cdf((a - mean) / std)
|
251 |
+
u = norm_cdf((b - mean) / std)
|
252 |
+
|
253 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
254 |
+
# [2l-1, 2u-1].
|
255 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
256 |
+
|
257 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
258 |
+
# standard normal
|
259 |
+
tensor.erfinv_()
|
260 |
+
|
261 |
+
# Transform to proper mean, std
|
262 |
+
tensor.mul_(std * math.sqrt(2.0))
|
263 |
+
tensor.add_(mean)
|
264 |
+
|
265 |
+
# Clamp to ensure it's in the proper range
|
266 |
+
tensor.clamp_(min=a, max=b)
|
267 |
+
return tensor
|
268 |
+
|
269 |
+
|
270 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
271 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
272 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
273 |
+
normal distribution. The values are effectively drawn from the
|
274 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
275 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
276 |
+
the bounds. The method used for generating the random values works
|
277 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
278 |
+
Args:
|
279 |
+
tensor: an n-dimensional `torch.Tensor`
|
280 |
+
mean: the mean of the normal distribution
|
281 |
+
std: the standard deviation of the normal distribution
|
282 |
+
a: the minimum cutoff value
|
283 |
+
b: the maximum cutoff value
|
284 |
+
Examples:
|
285 |
+
>>> w = torch.empty(3, 5)
|
286 |
+
>>> nn.init.trunc_normal_(w)
|
287 |
+
"""
|
288 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
289 |
+
|
290 |
+
|
291 |
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
292 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
293 |
+
if mode == "fan_in":
|
294 |
+
denom = fan_in
|
295 |
+
elif mode == "fan_out":
|
296 |
+
denom = fan_out
|
297 |
+
elif mode == "fan_avg":
|
298 |
+
denom = (fan_in + fan_out) / 2
|
299 |
+
|
300 |
+
variance = scale / denom
|
301 |
+
|
302 |
+
if distribution == "truncated_normal":
|
303 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
304 |
+
trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
305 |
+
elif distribution == "normal":
|
306 |
+
tensor.normal_(std=math.sqrt(variance))
|
307 |
+
elif distribution == "uniform":
|
308 |
+
bound = math.sqrt(3 * variance)
|
309 |
+
tensor.uniform_(-bound, bound)
|
310 |
+
else:
|
311 |
+
raise ValueError(f"invalid distribution {distribution}")
|
312 |
+
|
313 |
+
|
314 |
+
def lecun_normal_(tensor):
|
315 |
+
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
316 |
+
|
317 |
+
|
318 |
+
def window_partition(x, window_size):
|
319 |
+
"""
|
320 |
+
Args:
|
321 |
+
x: (B, H, W, C)
|
322 |
+
window_size (int): window size
|
323 |
+
Returns:
|
324 |
+
windows: (num_windows*B, window_size, window_size, C)
|
325 |
+
"""
|
326 |
+
B, H, W, C = x.shape
|
327 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
328 |
+
windows = (
|
329 |
+
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
330 |
+
)
|
331 |
+
return windows
|
332 |
+
|
333 |
+
|
334 |
+
def window_reverse(windows, window_size, H, W):
|
335 |
+
"""
|
336 |
+
Args:
|
337 |
+
windows: (num_windows*B, window_size, window_size, C)
|
338 |
+
window_size (int): Window size
|
339 |
+
H (int): Height of image
|
340 |
+
W (int): Width of image
|
341 |
+
Returns:
|
342 |
+
x: (B, H, W, C)
|
343 |
+
"""
|
344 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
345 |
+
x = windows.view(
|
346 |
+
B, H // window_size, W // window_size, window_size, window_size, -1
|
347 |
+
)
|
348 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
349 |
+
return x
|
350 |
+
|
351 |
+
|
352 |
+
class WindowAttention(nn.Module):
|
353 |
+
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
354 |
+
It supports both of shifted and non-shifted window.
|
355 |
+
Args:
|
356 |
+
dim (int): Number of input channels.
|
357 |
+
window_size (tuple[int]): The height and width of the window.
|
358 |
+
num_heads (int): Number of attention heads.
|
359 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
360 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
361 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
362 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
363 |
+
"""
|
364 |
+
|
365 |
+
def __init__(
|
366 |
+
self,
|
367 |
+
dim,
|
368 |
+
window_size,
|
369 |
+
num_heads,
|
370 |
+
qkv_bias=True,
|
371 |
+
qk_scale=None,
|
372 |
+
attn_drop=0.0,
|
373 |
+
proj_drop=0.0,
|
374 |
+
):
|
375 |
+
|
376 |
+
super().__init__()
|
377 |
+
self.dim = dim
|
378 |
+
self.window_size = window_size # Wh, Ww
|
379 |
+
self.num_heads = num_heads
|
380 |
+
head_dim = dim // num_heads
|
381 |
+
self.scale = qk_scale or head_dim**-0.5
|
382 |
+
|
383 |
+
# define a parameter table of relative position bias
|
384 |
+
self.relative_position_bias_table = nn.Parameter(
|
385 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
386 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
387 |
+
|
388 |
+
# get pair-wise relative position index for each token inside the window
|
389 |
+
coords_h = torch.arange(self.window_size[0])
|
390 |
+
coords_w = torch.arange(self.window_size[1])
|
391 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
392 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
393 |
+
relative_coords = (
|
394 |
+
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
395 |
+
) # 2, Wh*Ww, Wh*Ww
|
396 |
+
relative_coords = relative_coords.permute(
|
397 |
+
1, 2, 0
|
398 |
+
).contiguous() # Wh*Ww, Wh*Ww, 2
|
399 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
400 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
401 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
402 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
403 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
404 |
+
|
405 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
406 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
407 |
+
self.proj = nn.Linear(dim, dim)
|
408 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
409 |
+
|
410 |
+
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
411 |
+
self.softmax = nn.Softmax(dim=-1)
|
412 |
+
|
413 |
+
def forward(self, x, mask=None):
|
414 |
+
"""
|
415 |
+
Args:
|
416 |
+
x: input features with shape of (num_windows*B, N, C)
|
417 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
418 |
+
"""
|
419 |
+
B_, N, C = x.shape
|
420 |
+
qkv = (
|
421 |
+
self.qkv(x)
|
422 |
+
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
|
423 |
+
.permute(2, 0, 3, 1, 4)
|
424 |
+
)
|
425 |
+
q, k, v = (
|
426 |
+
qkv[0],
|
427 |
+
qkv[1],
|
428 |
+
qkv[2],
|
429 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
430 |
+
|
431 |
+
q = q * self.scale
|
432 |
+
attn = q @ k.transpose(-2, -1)
|
433 |
+
|
434 |
+
relative_position_bias = self.relative_position_bias_table[
|
435 |
+
self.relative_position_index.view(-1)
|
436 |
+
].view(
|
437 |
+
self.window_size[0] * self.window_size[1],
|
438 |
+
self.window_size[0] * self.window_size[1],
|
439 |
+
-1,
|
440 |
+
) # Wh*Ww,Wh*Ww,nH
|
441 |
+
relative_position_bias = relative_position_bias.permute(
|
442 |
+
2, 0, 1
|
443 |
+
).contiguous() # nH, Wh*Ww, Wh*Ww
|
444 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
445 |
+
|
446 |
+
if mask is not None:
|
447 |
+
nW = mask.shape[0]
|
448 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
|
449 |
+
1
|
450 |
+
).unsqueeze(0)
|
451 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
452 |
+
attn = self.softmax(attn)
|
453 |
+
else:
|
454 |
+
attn = self.softmax(attn)
|
455 |
+
|
456 |
+
attn = self.attn_drop(attn)
|
457 |
+
|
458 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
459 |
+
x = self.proj(x)
|
460 |
+
x = self.proj_drop(x)
|
461 |
+
return x, attn
|
462 |
+
|
463 |
+
def extra_repr(self):
|
464 |
+
return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
|
465 |
+
|
466 |
+
|
467 |
+
# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
|
468 |
+
class SwinTransformerBlock(nn.Module):
|
469 |
+
r"""Swin Transformer Block.
|
470 |
+
Args:
|
471 |
+
dim (int): Number of input channels.
|
472 |
+
input_resolution (tuple[int]): Input resulotion.
|
473 |
+
num_heads (int): Number of attention heads.
|
474 |
+
window_size (int): Window size.
|
475 |
+
shift_size (int): Shift size for SW-MSA.
|
476 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
477 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
478 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
479 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
480 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
481 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
482 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
483 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
484 |
+
"""
|
485 |
+
|
486 |
+
def __init__(
|
487 |
+
self,
|
488 |
+
dim,
|
489 |
+
input_resolution,
|
490 |
+
num_heads,
|
491 |
+
window_size=7,
|
492 |
+
shift_size=0,
|
493 |
+
mlp_ratio=4.0,
|
494 |
+
qkv_bias=True,
|
495 |
+
qk_scale=None,
|
496 |
+
drop=0.0,
|
497 |
+
attn_drop=0.0,
|
498 |
+
drop_path=0.0,
|
499 |
+
act_layer=nn.GELU,
|
500 |
+
norm_layer=nn.LayerNorm,
|
501 |
+
norm_before_mlp="ln",
|
502 |
+
):
|
503 |
+
super().__init__()
|
504 |
+
self.dim = dim
|
505 |
+
self.input_resolution = input_resolution
|
506 |
+
self.num_heads = num_heads
|
507 |
+
self.window_size = window_size
|
508 |
+
self.shift_size = shift_size
|
509 |
+
self.mlp_ratio = mlp_ratio
|
510 |
+
self.norm_before_mlp = norm_before_mlp
|
511 |
+
if min(self.input_resolution) <= self.window_size:
|
512 |
+
# if window size is larger than input resolution, we don't partition windows
|
513 |
+
self.shift_size = 0
|
514 |
+
self.window_size = min(self.input_resolution)
|
515 |
+
assert (
|
516 |
+
0 <= self.shift_size < self.window_size
|
517 |
+
), "shift_size must in 0-window_size"
|
518 |
+
|
519 |
+
self.norm1 = norm_layer(dim)
|
520 |
+
self.attn = WindowAttention(
|
521 |
+
dim,
|
522 |
+
window_size=to_2tuple(self.window_size),
|
523 |
+
num_heads=num_heads,
|
524 |
+
qkv_bias=qkv_bias,
|
525 |
+
qk_scale=qk_scale,
|
526 |
+
attn_drop=attn_drop,
|
527 |
+
proj_drop=drop,
|
528 |
+
)
|
529 |
+
|
530 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
531 |
+
if self.norm_before_mlp == "ln":
|
532 |
+
self.norm2 = nn.LayerNorm(dim)
|
533 |
+
elif self.norm_before_mlp == "bn":
|
534 |
+
self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
|
535 |
+
1, 2
|
536 |
+
)
|
537 |
+
else:
|
538 |
+
raise NotImplementedError
|
539 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
540 |
+
self.mlp = Mlp(
|
541 |
+
in_features=dim,
|
542 |
+
hidden_features=mlp_hidden_dim,
|
543 |
+
act_layer=act_layer,
|
544 |
+
drop=drop,
|
545 |
+
)
|
546 |
+
|
547 |
+
if self.shift_size > 0:
|
548 |
+
# calculate attention mask for SW-MSA
|
549 |
+
H, W = self.input_resolution
|
550 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
551 |
+
h_slices = (
|
552 |
+
slice(0, -self.window_size),
|
553 |
+
slice(-self.window_size, -self.shift_size),
|
554 |
+
slice(-self.shift_size, None),
|
555 |
+
)
|
556 |
+
w_slices = (
|
557 |
+
slice(0, -self.window_size),
|
558 |
+
slice(-self.window_size, -self.shift_size),
|
559 |
+
slice(-self.shift_size, None),
|
560 |
+
)
|
561 |
+
cnt = 0
|
562 |
+
for h in h_slices:
|
563 |
+
for w in w_slices:
|
564 |
+
img_mask[:, h, w, :] = cnt
|
565 |
+
cnt += 1
|
566 |
+
|
567 |
+
mask_windows = window_partition(
|
568 |
+
img_mask, self.window_size
|
569 |
+
) # nW, window_size, window_size, 1
|
570 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
571 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
572 |
+
attn_mask = attn_mask.masked_fill(
|
573 |
+
attn_mask != 0, float(-100.0)
|
574 |
+
).masked_fill(attn_mask == 0, float(0.0))
|
575 |
+
else:
|
576 |
+
attn_mask = None
|
577 |
+
|
578 |
+
self.register_buffer("attn_mask", attn_mask)
|
579 |
+
|
580 |
+
def forward(self, x):
|
581 |
+
# pdb.set_trace()
|
582 |
+
H, W = self.input_resolution
|
583 |
+
# print("H: ", H)
|
584 |
+
# print("W: ", W)
|
585 |
+
# pdb.set_trace()
|
586 |
+
B, L, C = x.shape
|
587 |
+
# assert L == H * W, "input feature has wrong size"
|
588 |
+
|
589 |
+
shortcut = x
|
590 |
+
x = self.norm1(x)
|
591 |
+
x = x.view(B, H, W, C)
|
592 |
+
|
593 |
+
# cyclic shift
|
594 |
+
if self.shift_size > 0:
|
595 |
+
shifted_x = torch.roll(
|
596 |
+
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
|
597 |
+
)
|
598 |
+
else:
|
599 |
+
shifted_x = x
|
600 |
+
|
601 |
+
# partition windows
|
602 |
+
x_windows = window_partition(
|
603 |
+
shifted_x, self.window_size
|
604 |
+
) # nW*B, window_size, window_size, C
|
605 |
+
x_windows = x_windows.view(
|
606 |
+
-1, self.window_size * self.window_size, C
|
607 |
+
) # nW*B, window_size*window_size, C
|
608 |
+
|
609 |
+
# W-MSA/SW-MSA
|
610 |
+
attn_windows, attn = self.attn(
|
611 |
+
x_windows, mask=self.attn_mask
|
612 |
+
) # nW*B, window_size*window_size, C
|
613 |
+
|
614 |
+
# merge windows
|
615 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
616 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
617 |
+
|
618 |
+
# reverse cyclic shift
|
619 |
+
if self.shift_size > 0:
|
620 |
+
x = torch.roll(
|
621 |
+
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
|
622 |
+
)
|
623 |
+
else:
|
624 |
+
x = shifted_x
|
625 |
+
x = x.view(B, H * W, C)
|
626 |
+
|
627 |
+
# FFN
|
628 |
+
x = shortcut + self.drop_path(x)
|
629 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
630 |
+
|
631 |
+
return x, attn
|
632 |
+
|
633 |
+
def extra_repr(self):
|
634 |
+
return (
|
635 |
+
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
636 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
637 |
+
)
|
638 |
+
|
639 |
+
|
640 |
+
class PatchMerging(nn.Module):
|
641 |
+
r"""Patch Merging Layer.
|
642 |
+
Args:
|
643 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
644 |
+
dim (int): Number of input channels.
|
645 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
646 |
+
"""
|
647 |
+
|
648 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
649 |
+
super().__init__()
|
650 |
+
self.input_resolution = input_resolution
|
651 |
+
self.dim = dim
|
652 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
653 |
+
self.norm = norm_layer(4 * dim)
|
654 |
+
|
655 |
+
def forward(self, x):
|
656 |
+
"""
|
657 |
+
x: B, H*W, C
|
658 |
+
"""
|
659 |
+
H, W = self.input_resolution
|
660 |
+
B, L, C = x.shape
|
661 |
+
assert L == H * W, "input feature has wrong size"
|
662 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
663 |
+
|
664 |
+
x = x.view(B, H, W, C)
|
665 |
+
|
666 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
667 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
668 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
669 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
670 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
671 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
672 |
+
|
673 |
+
x = self.norm(x)
|
674 |
+
x = self.reduction(x)
|
675 |
+
|
676 |
+
return x
|
677 |
+
|
678 |
+
def extra_repr(self):
|
679 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
680 |
+
|
681 |
+
|
682 |
+
class BasicLayer(nn.Module):
|
683 |
+
"""A basic Swin Transformer layer for one stage.
|
684 |
+
Args:
|
685 |
+
dim (int): Number of input channels.
|
686 |
+
input_resolution (tuple[int]): Input resolution.
|
687 |
+
depth (int): Number of blocks.
|
688 |
+
num_heads (int): Number of attention heads.
|
689 |
+
window_size (int): Local window size.
|
690 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
691 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
692 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
693 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
694 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
695 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
696 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
697 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
698 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
699 |
+
"""
|
700 |
+
|
701 |
+
def __init__(
|
702 |
+
self,
|
703 |
+
dim,
|
704 |
+
input_resolution,
|
705 |
+
depth,
|
706 |
+
num_heads,
|
707 |
+
window_size,
|
708 |
+
mlp_ratio=4.0,
|
709 |
+
qkv_bias=True,
|
710 |
+
qk_scale=None,
|
711 |
+
drop=0.0,
|
712 |
+
attn_drop=0.0,
|
713 |
+
drop_path=0.0,
|
714 |
+
norm_layer=nn.LayerNorm,
|
715 |
+
downsample=None,
|
716 |
+
use_checkpoint=False,
|
717 |
+
norm_before_mlp="ln",
|
718 |
+
):
|
719 |
+
|
720 |
+
super().__init__()
|
721 |
+
self.dim = dim
|
722 |
+
self.input_resolution = input_resolution
|
723 |
+
self.depth = depth
|
724 |
+
self.use_checkpoint = use_checkpoint
|
725 |
+
|
726 |
+
# build blocks
|
727 |
+
self.blocks = nn.ModuleList(
|
728 |
+
[
|
729 |
+
SwinTransformerBlock(
|
730 |
+
dim=dim,
|
731 |
+
input_resolution=input_resolution,
|
732 |
+
num_heads=num_heads,
|
733 |
+
window_size=window_size,
|
734 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
735 |
+
mlp_ratio=mlp_ratio,
|
736 |
+
qkv_bias=qkv_bias,
|
737 |
+
qk_scale=qk_scale,
|
738 |
+
drop=drop,
|
739 |
+
attn_drop=attn_drop,
|
740 |
+
drop_path=drop_path[i]
|
741 |
+
if isinstance(drop_path, list)
|
742 |
+
else drop_path,
|
743 |
+
norm_layer=norm_layer,
|
744 |
+
norm_before_mlp=norm_before_mlp,
|
745 |
+
)
|
746 |
+
for i in range(depth)
|
747 |
+
]
|
748 |
+
)
|
749 |
+
|
750 |
+
# patch merging layer
|
751 |
+
if downsample is not None:
|
752 |
+
self.downsample = downsample(
|
753 |
+
input_resolution, dim=dim, norm_layer=norm_layer
|
754 |
+
)
|
755 |
+
else:
|
756 |
+
self.downsample = None
|
757 |
+
|
758 |
+
def forward(self, x):
|
759 |
+
attns = []
|
760 |
+
for blk in self.blocks:
|
761 |
+
if self.use_checkpoint:
|
762 |
+
x = checkpoint.checkpoint(blk, x)
|
763 |
+
else:
|
764 |
+
x, attn = blk(x)
|
765 |
+
if not self.training:
|
766 |
+
attns.append(attn.unsqueeze(0))
|
767 |
+
if self.downsample is not None:
|
768 |
+
x = self.downsample(x)
|
769 |
+
if not self.training:
|
770 |
+
attn = torch.cat(attns, dim=0)
|
771 |
+
attn = torch.mean(attn, dim=0)
|
772 |
+
return x, attn
|
773 |
+
|
774 |
+
def extra_repr(self):
|
775 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
776 |
+
|
777 |
+
|
778 |
+
# The Core of HTSAT
|
779 |
+
class HTSAT_Swin_Transformer(nn.Module):
|
780 |
+
r"""HTSAT based on the Swin Transformer
|
781 |
+
Args:
|
782 |
+
spec_size (int | tuple(int)): Input Spectrogram size. Default 256
|
783 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
784 |
+
path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
|
785 |
+
in_chans (int): Number of input image channels. Default: 1 (mono)
|
786 |
+
num_classes (int): Number of classes for classification head. Default: 527
|
787 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
788 |
+
depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
|
789 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
790 |
+
window_size (int): Window size. Default: 8
|
791 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
792 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
793 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
794 |
+
drop_rate (float): Dropout rate. Default: 0
|
795 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
796 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
797 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
798 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
799 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
800 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
801 |
+
config (module): The configuration Module from config.py
|
802 |
+
"""
|
803 |
+
|
804 |
+
def __init__(
|
805 |
+
self,
|
806 |
+
spec_size=256,
|
807 |
+
patch_size=4,
|
808 |
+
patch_stride=(4, 4),
|
809 |
+
in_chans=1,
|
810 |
+
num_classes=527,
|
811 |
+
embed_dim=96,
|
812 |
+
depths=[2, 2, 6, 2],
|
813 |
+
num_heads=[4, 8, 16, 32],
|
814 |
+
window_size=8,
|
815 |
+
mlp_ratio=4.0,
|
816 |
+
qkv_bias=True,
|
817 |
+
qk_scale=None,
|
818 |
+
drop_rate=0.0,
|
819 |
+
attn_drop_rate=0.0,
|
820 |
+
drop_path_rate=0.1,
|
821 |
+
norm_layer=nn.LayerNorm,
|
822 |
+
ape=False,
|
823 |
+
patch_norm=True,
|
824 |
+
use_checkpoint=False,
|
825 |
+
norm_before_mlp="ln",
|
826 |
+
config=None,
|
827 |
+
enable_fusion=False,
|
828 |
+
fusion_type="None",
|
829 |
+
**kwargs,
|
830 |
+
):
|
831 |
+
super(HTSAT_Swin_Transformer, self).__init__()
|
832 |
+
|
833 |
+
self.config = config
|
834 |
+
self.spec_size = spec_size
|
835 |
+
self.patch_stride = patch_stride
|
836 |
+
self.patch_size = patch_size
|
837 |
+
self.window_size = window_size
|
838 |
+
self.embed_dim = embed_dim
|
839 |
+
self.depths = depths
|
840 |
+
self.ape = ape
|
841 |
+
self.in_chans = in_chans
|
842 |
+
self.num_classes = num_classes
|
843 |
+
self.num_heads = num_heads
|
844 |
+
self.num_layers = len(self.depths)
|
845 |
+
self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
|
846 |
+
|
847 |
+
self.drop_rate = drop_rate
|
848 |
+
self.attn_drop_rate = attn_drop_rate
|
849 |
+
self.drop_path_rate = drop_path_rate
|
850 |
+
|
851 |
+
self.qkv_bias = qkv_bias
|
852 |
+
self.qk_scale = None
|
853 |
+
|
854 |
+
self.patch_norm = patch_norm
|
855 |
+
self.norm_layer = norm_layer if self.patch_norm else None
|
856 |
+
self.norm_before_mlp = norm_before_mlp
|
857 |
+
self.mlp_ratio = mlp_ratio
|
858 |
+
|
859 |
+
self.use_checkpoint = use_checkpoint
|
860 |
+
|
861 |
+
self.enable_fusion = enable_fusion
|
862 |
+
self.fusion_type = fusion_type
|
863 |
+
|
864 |
+
# process mel-spec ; used only once
|
865 |
+
self.freq_ratio = self.spec_size // self.config.mel_bins
|
866 |
+
window = "hann"
|
867 |
+
center = True
|
868 |
+
pad_mode = "reflect"
|
869 |
+
ref = 1.0
|
870 |
+
amin = 1e-10
|
871 |
+
top_db = None
|
872 |
+
self.interpolate_ratio = 32 # Downsampled ratio
|
873 |
+
# Spectrogram extractor
|
874 |
+
self.spectrogram_extractor = Spectrogram(
|
875 |
+
n_fft=config.window_size,
|
876 |
+
hop_length=config.hop_size,
|
877 |
+
win_length=config.window_size,
|
878 |
+
window=window,
|
879 |
+
center=center,
|
880 |
+
pad_mode=pad_mode,
|
881 |
+
freeze_parameters=True,
|
882 |
+
)
|
883 |
+
# Logmel feature extractor
|
884 |
+
self.logmel_extractor = LogmelFilterBank(
|
885 |
+
sr=config.sample_rate,
|
886 |
+
n_fft=config.window_size,
|
887 |
+
n_mels=config.mel_bins,
|
888 |
+
fmin=config.fmin,
|
889 |
+
fmax=config.fmax,
|
890 |
+
ref=ref,
|
891 |
+
amin=amin,
|
892 |
+
top_db=top_db,
|
893 |
+
freeze_parameters=True,
|
894 |
+
)
|
895 |
+
# Spec augmenter
|
896 |
+
self.spec_augmenter = SpecAugmentation(
|
897 |
+
time_drop_width=64,
|
898 |
+
time_stripes_num=2,
|
899 |
+
freq_drop_width=8,
|
900 |
+
freq_stripes_num=2,
|
901 |
+
) # 2 2
|
902 |
+
self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
|
903 |
+
|
904 |
+
# split spctrogram into non-overlapping patches
|
905 |
+
self.patch_embed = PatchEmbed(
|
906 |
+
img_size=self.spec_size,
|
907 |
+
patch_size=self.patch_size,
|
908 |
+
in_chans=self.in_chans,
|
909 |
+
embed_dim=self.embed_dim,
|
910 |
+
norm_layer=self.norm_layer,
|
911 |
+
patch_stride=patch_stride,
|
912 |
+
enable_fusion=self.enable_fusion,
|
913 |
+
fusion_type=self.fusion_type,
|
914 |
+
)
|
915 |
+
|
916 |
+
num_patches = self.patch_embed.num_patches
|
917 |
+
patches_resolution = self.patch_embed.grid_size
|
918 |
+
self.patches_resolution = patches_resolution
|
919 |
+
|
920 |
+
# absolute position embedding
|
921 |
+
if self.ape:
|
922 |
+
self.absolute_pos_embed = nn.Parameter(
|
923 |
+
torch.zeros(1, num_patches, self.embed_dim)
|
924 |
+
)
|
925 |
+
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
926 |
+
|
927 |
+
self.pos_drop = nn.Dropout(p=self.drop_rate)
|
928 |
+
|
929 |
+
# stochastic depth
|
930 |
+
dpr = [
|
931 |
+
x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
|
932 |
+
] # stochastic depth decay rule
|
933 |
+
|
934 |
+
# build layers
|
935 |
+
self.layers = nn.ModuleList()
|
936 |
+
for i_layer in range(self.num_layers):
|
937 |
+
layer = BasicLayer(
|
938 |
+
dim=int(self.embed_dim * 2**i_layer),
|
939 |
+
input_resolution=(
|
940 |
+
patches_resolution[0] // (2**i_layer),
|
941 |
+
patches_resolution[1] // (2**i_layer),
|
942 |
+
),
|
943 |
+
depth=self.depths[i_layer],
|
944 |
+
num_heads=self.num_heads[i_layer],
|
945 |
+
window_size=self.window_size,
|
946 |
+
mlp_ratio=self.mlp_ratio,
|
947 |
+
qkv_bias=self.qkv_bias,
|
948 |
+
qk_scale=self.qk_scale,
|
949 |
+
drop=self.drop_rate,
|
950 |
+
attn_drop=self.attn_drop_rate,
|
951 |
+
drop_path=dpr[
|
952 |
+
sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
|
953 |
+
],
|
954 |
+
norm_layer=self.norm_layer,
|
955 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
956 |
+
use_checkpoint=use_checkpoint,
|
957 |
+
norm_before_mlp=self.norm_before_mlp,
|
958 |
+
)
|
959 |
+
self.layers.append(layer)
|
960 |
+
|
961 |
+
self.norm = self.norm_layer(self.num_features)
|
962 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
963 |
+
self.maxpool = nn.AdaptiveMaxPool1d(1)
|
964 |
+
|
965 |
+
SF = (
|
966 |
+
self.spec_size
|
967 |
+
// (2 ** (len(self.depths) - 1))
|
968 |
+
// self.patch_stride[0]
|
969 |
+
// self.freq_ratio
|
970 |
+
)
|
971 |
+
self.tscam_conv = nn.Conv2d(
|
972 |
+
in_channels=self.num_features,
|
973 |
+
out_channels=self.num_classes,
|
974 |
+
kernel_size=(SF, 3),
|
975 |
+
padding=(0, 1),
|
976 |
+
)
|
977 |
+
self.head = nn.Linear(num_classes, num_classes)
|
978 |
+
|
979 |
+
if (self.enable_fusion) and (
|
980 |
+
self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
|
981 |
+
):
|
982 |
+
self.mel_conv1d = nn.Sequential(
|
983 |
+
nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
|
984 |
+
nn.BatchNorm1d(64),
|
985 |
+
)
|
986 |
+
if self.fusion_type == "daf_1d":
|
987 |
+
self.fusion_model = DAF()
|
988 |
+
elif self.fusion_type == "aff_1d":
|
989 |
+
self.fusion_model = AFF(channels=64, type="1D")
|
990 |
+
elif self.fusion_type == "iaff_1d":
|
991 |
+
self.fusion_model = iAFF(channels=64, type="1D")
|
992 |
+
|
993 |
+
self.apply(self._init_weights)
|
994 |
+
|
995 |
+
def _init_weights(self, m):
|
996 |
+
if isinstance(m, nn.Linear):
|
997 |
+
trunc_normal_(m.weight, std=0.02)
|
998 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
999 |
+
nn.init.constant_(m.bias, 0)
|
1000 |
+
elif isinstance(m, nn.LayerNorm):
|
1001 |
+
nn.init.constant_(m.bias, 0)
|
1002 |
+
nn.init.constant_(m.weight, 1.0)
|
1003 |
+
|
1004 |
+
@torch.jit.ignore
|
1005 |
+
def no_weight_decay(self):
|
1006 |
+
return {"absolute_pos_embed"}
|
1007 |
+
|
1008 |
+
@torch.jit.ignore
|
1009 |
+
def no_weight_decay_keywords(self):
|
1010 |
+
return {"relative_position_bias_table"}
|
1011 |
+
|
1012 |
+
def forward_features(self, x, longer_idx=None):
|
1013 |
+
# A deprecated optimization for using a hierarchical output from different blocks
|
1014 |
+
|
1015 |
+
frames_num = x.shape[2]
|
1016 |
+
x = self.patch_embed(x, longer_idx=longer_idx)
|
1017 |
+
if self.ape:
|
1018 |
+
x = x + self.absolute_pos_embed
|
1019 |
+
x = self.pos_drop(x)
|
1020 |
+
for i, layer in enumerate(self.layers):
|
1021 |
+
x, attn = layer(x)
|
1022 |
+
# for x
|
1023 |
+
x = self.norm(x)
|
1024 |
+
B, N, C = x.shape
|
1025 |
+
SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
|
1026 |
+
ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
|
1027 |
+
x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
|
1028 |
+
B, C, F, T = x.shape
|
1029 |
+
# group 2D CNN
|
1030 |
+
c_freq_bin = F // self.freq_ratio
|
1031 |
+
x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
|
1032 |
+
x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
|
1033 |
+
# get latent_output
|
1034 |
+
fine_grained_latent_output = torch.mean(x, dim=2)
|
1035 |
+
fine_grained_latent_output = interpolate(
|
1036 |
+
fine_grained_latent_output.permute(0, 2, 1).contiguous(),
|
1037 |
+
8 * self.patch_stride[1],
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
latent_output = self.avgpool(torch.flatten(x, 2))
|
1041 |
+
latent_output = torch.flatten(latent_output, 1)
|
1042 |
+
|
1043 |
+
# display the attention map, if needed
|
1044 |
+
|
1045 |
+
x = self.tscam_conv(x)
|
1046 |
+
x = torch.flatten(x, 2) # B, C, T
|
1047 |
+
|
1048 |
+
fpx = interpolate(
|
1049 |
+
torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
|
1050 |
+
)
|
1051 |
+
|
1052 |
+
x = self.avgpool(x)
|
1053 |
+
x = torch.flatten(x, 1)
|
1054 |
+
|
1055 |
+
output_dict = {
|
1056 |
+
"framewise_output": fpx, # already sigmoided
|
1057 |
+
"clipwise_output": torch.sigmoid(x),
|
1058 |
+
"fine_grained_embedding": fine_grained_latent_output,
|
1059 |
+
"embedding": latent_output,
|
1060 |
+
}
|
1061 |
+
|
1062 |
+
return output_dict
|
1063 |
+
|
1064 |
+
def crop_wav(self, x, crop_size, spe_pos=None):
|
1065 |
+
time_steps = x.shape[2]
|
1066 |
+
tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
|
1067 |
+
for i in range(len(x)):
|
1068 |
+
if spe_pos is None:
|
1069 |
+
crop_pos = random.randint(0, time_steps - crop_size - 1)
|
1070 |
+
else:
|
1071 |
+
crop_pos = spe_pos
|
1072 |
+
tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
|
1073 |
+
return tx
|
1074 |
+
|
1075 |
+
# Reshape the wavform to a img size, if you want to use the pretrained swin transformer model
|
1076 |
+
def reshape_wav2img(self, x):
|
1077 |
+
B, C, T, F = x.shape
|
1078 |
+
target_T = int(self.spec_size * self.freq_ratio)
|
1079 |
+
target_F = self.spec_size // self.freq_ratio
|
1080 |
+
assert (
|
1081 |
+
T <= target_T and F <= target_F
|
1082 |
+
), "the wav size should less than or equal to the swin input size"
|
1083 |
+
# to avoid bicubic zero error
|
1084 |
+
if T < target_T:
|
1085 |
+
x = nn.functional.interpolate(
|
1086 |
+
x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
|
1087 |
+
)
|
1088 |
+
if F < target_F:
|
1089 |
+
x = nn.functional.interpolate(
|
1090 |
+
x, (x.shape[2], target_F), mode="bicubic", align_corners=True
|
1091 |
+
)
|
1092 |
+
x = x.permute(0, 1, 3, 2).contiguous()
|
1093 |
+
x = x.reshape(
|
1094 |
+
x.shape[0],
|
1095 |
+
x.shape[1],
|
1096 |
+
x.shape[2],
|
1097 |
+
self.freq_ratio,
|
1098 |
+
x.shape[3] // self.freq_ratio,
|
1099 |
+
)
|
1100 |
+
# print(x.shape)
|
1101 |
+
x = x.permute(0, 1, 3, 2, 4).contiguous()
|
1102 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
|
1103 |
+
return x
|
1104 |
+
|
1105 |
+
# Repeat the wavform to a img size, if you want to use the pretrained swin transformer model
|
1106 |
+
def repeat_wat2img(self, x, cur_pos):
|
1107 |
+
B, C, T, F = x.shape
|
1108 |
+
target_T = int(self.spec_size * self.freq_ratio)
|
1109 |
+
target_F = self.spec_size // self.freq_ratio
|
1110 |
+
assert (
|
1111 |
+
T <= target_T and F <= target_F
|
1112 |
+
), "the wav size should less than or equal to the swin input size"
|
1113 |
+
# to avoid bicubic zero error
|
1114 |
+
if T < target_T:
|
1115 |
+
x = nn.functional.interpolate(
|
1116 |
+
x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
|
1117 |
+
)
|
1118 |
+
if F < target_F:
|
1119 |
+
x = nn.functional.interpolate(
|
1120 |
+
x, (x.shape[2], target_F), mode="bicubic", align_corners=True
|
1121 |
+
)
|
1122 |
+
x = x.permute(0, 1, 3, 2).contiguous() # B C F T
|
1123 |
+
x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
|
1124 |
+
x = x.repeat(repeats=(1, 1, 4, 1))
|
1125 |
+
return x
|
1126 |
+
|
1127 |
+
def forward(
|
1128 |
+
self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
|
1129 |
+
): # out_feat_keys: List[str] = None):
|
1130 |
+
|
1131 |
+
if self.enable_fusion and x["longer"].sum() == 0:
|
1132 |
+
# if no audio is longer than 10s, then randomly select one audio to be longer
|
1133 |
+
x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
|
1134 |
+
|
1135 |
+
if not self.enable_fusion:
|
1136 |
+
x = x["waveform"].to(device=device, non_blocking=True)
|
1137 |
+
x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
|
1138 |
+
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
|
1139 |
+
x = x.transpose(1, 3)
|
1140 |
+
x = self.bn0(x)
|
1141 |
+
x = x.transpose(1, 3)
|
1142 |
+
if self.training:
|
1143 |
+
x = self.spec_augmenter(x)
|
1144 |
+
|
1145 |
+
if self.training and mixup_lambda is not None:
|
1146 |
+
x = do_mixup(x, mixup_lambda)
|
1147 |
+
|
1148 |
+
x = self.reshape_wav2img(x)
|
1149 |
+
output_dict = self.forward_features(x)
|
1150 |
+
else:
|
1151 |
+
longer_list = x["longer"].to(device=device, non_blocking=True)
|
1152 |
+
x = x["mel_fusion"].to(device=device, non_blocking=True)
|
1153 |
+
x = x.transpose(1, 3)
|
1154 |
+
x = self.bn0(x)
|
1155 |
+
x = x.transpose(1, 3)
|
1156 |
+
longer_list_idx = torch.where(longer_list)[0]
|
1157 |
+
if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
|
1158 |
+
new_x = x[:, 0:1, :, :].clone().contiguous()
|
1159 |
+
if len(longer_list_idx) > 0:
|
1160 |
+
# local processing
|
1161 |
+
fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
|
1162 |
+
FB, FC, FT, FF = fusion_x_local.size()
|
1163 |
+
fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
|
1164 |
+
fusion_x_local = torch.permute(
|
1165 |
+
fusion_x_local, (0, 2, 1)
|
1166 |
+
).contiguous()
|
1167 |
+
fusion_x_local = self.mel_conv1d(fusion_x_local)
|
1168 |
+
fusion_x_local = fusion_x_local.view(
|
1169 |
+
FB, FC, FF, fusion_x_local.size(-1)
|
1170 |
+
)
|
1171 |
+
fusion_x_local = (
|
1172 |
+
torch.permute(fusion_x_local, (0, 2, 1, 3))
|
1173 |
+
.contiguous()
|
1174 |
+
.flatten(2)
|
1175 |
+
)
|
1176 |
+
if fusion_x_local.size(-1) < FT:
|
1177 |
+
fusion_x_local = torch.cat(
|
1178 |
+
[
|
1179 |
+
fusion_x_local,
|
1180 |
+
torch.zeros(
|
1181 |
+
(FB, FF, FT - fusion_x_local.size(-1)),
|
1182 |
+
device=device,
|
1183 |
+
),
|
1184 |
+
],
|
1185 |
+
dim=-1,
|
1186 |
+
)
|
1187 |
+
else:
|
1188 |
+
fusion_x_local = fusion_x_local[:, :, :FT]
|
1189 |
+
# 1D fusion
|
1190 |
+
new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
|
1191 |
+
new_x[longer_list_idx] = self.fusion_model(
|
1192 |
+
new_x[longer_list_idx], fusion_x_local
|
1193 |
+
)
|
1194 |
+
x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
|
1195 |
+
else:
|
1196 |
+
x = new_x
|
1197 |
+
|
1198 |
+
elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
|
1199 |
+
x = x # no change
|
1200 |
+
|
1201 |
+
if self.training:
|
1202 |
+
x = self.spec_augmenter(x)
|
1203 |
+
if self.training and mixup_lambda is not None:
|
1204 |
+
x = do_mixup(x, mixup_lambda)
|
1205 |
+
|
1206 |
+
x = self.reshape_wav2img(x)
|
1207 |
+
output_dict = self.forward_features(x, longer_idx=longer_list_idx)
|
1208 |
+
|
1209 |
+
# if infer_mode:
|
1210 |
+
# # in infer mode. we need to handle different length audio input
|
1211 |
+
# frame_num = x.shape[2]
|
1212 |
+
# target_T = int(self.spec_size * self.freq_ratio)
|
1213 |
+
# repeat_ratio = math.floor(target_T / frame_num)
|
1214 |
+
# x = x.repeat(repeats=(1,1,repeat_ratio,1))
|
1215 |
+
# x = self.reshape_wav2img(x)
|
1216 |
+
# output_dict = self.forward_features(x)
|
1217 |
+
# else:
|
1218 |
+
# if x.shape[2] > self.freq_ratio * self.spec_size:
|
1219 |
+
# if self.training:
|
1220 |
+
# x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
|
1221 |
+
# x = self.reshape_wav2img(x)
|
1222 |
+
# output_dict = self.forward_features(x)
|
1223 |
+
# else:
|
1224 |
+
# # Change: Hard code here
|
1225 |
+
# overlap_size = (x.shape[2] - 1) // 4
|
1226 |
+
# output_dicts = []
|
1227 |
+
# crop_size = (x.shape[2] - 1) // 2
|
1228 |
+
# for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
|
1229 |
+
# tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
|
1230 |
+
# tx = self.reshape_wav2img(tx)
|
1231 |
+
# output_dicts.append(self.forward_features(tx))
|
1232 |
+
# clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
|
1233 |
+
# framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
|
1234 |
+
# for d in output_dicts:
|
1235 |
+
# clipwise_output += d["clipwise_output"]
|
1236 |
+
# framewise_output += d["framewise_output"]
|
1237 |
+
# clipwise_output = clipwise_output / len(output_dicts)
|
1238 |
+
# framewise_output = framewise_output / len(output_dicts)
|
1239 |
+
# output_dict = {
|
1240 |
+
# 'framewise_output': framewise_output,
|
1241 |
+
# 'clipwise_output': clipwise_output
|
1242 |
+
# }
|
1243 |
+
# else: # this part is typically used, and most easy one
|
1244 |
+
# x = self.reshape_wav2img(x)
|
1245 |
+
# output_dict = self.forward_features(x)
|
1246 |
+
# x = self.head(x)
|
1247 |
+
|
1248 |
+
# We process the data in the dataloader part, in that here we only consider the input_T < fixed_T
|
1249 |
+
|
1250 |
+
return output_dict
|
1251 |
+
|
1252 |
+
|
1253 |
+
def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
|
1254 |
+
try:
|
1255 |
+
|
1256 |
+
assert audio_cfg.model_name in [
|
1257 |
+
"tiny",
|
1258 |
+
"base",
|
1259 |
+
"large",
|
1260 |
+
], "model name for HTS-AT is wrong!"
|
1261 |
+
if audio_cfg.model_name == "tiny":
|
1262 |
+
model = HTSAT_Swin_Transformer(
|
1263 |
+
spec_size=256,
|
1264 |
+
patch_size=4,
|
1265 |
+
patch_stride=(4, 4),
|
1266 |
+
num_classes=audio_cfg.class_num,
|
1267 |
+
embed_dim=96,
|
1268 |
+
depths=[2, 2, 6, 2],
|
1269 |
+
num_heads=[4, 8, 16, 32],
|
1270 |
+
window_size=8,
|
1271 |
+
config=audio_cfg,
|
1272 |
+
enable_fusion=enable_fusion,
|
1273 |
+
fusion_type=fusion_type,
|
1274 |
+
)
|
1275 |
+
elif audio_cfg.model_name == "base":
|
1276 |
+
model = HTSAT_Swin_Transformer(
|
1277 |
+
spec_size=256,
|
1278 |
+
patch_size=4,
|
1279 |
+
patch_stride=(4, 4),
|
1280 |
+
num_classes=audio_cfg.class_num,
|
1281 |
+
embed_dim=128,
|
1282 |
+
depths=[2, 2, 12, 2],
|
1283 |
+
num_heads=[4, 8, 16, 32],
|
1284 |
+
window_size=8,
|
1285 |
+
config=audio_cfg,
|
1286 |
+
enable_fusion=enable_fusion,
|
1287 |
+
fusion_type=fusion_type,
|
1288 |
+
)
|
1289 |
+
elif audio_cfg.model_name == "large":
|
1290 |
+
model = HTSAT_Swin_Transformer(
|
1291 |
+
spec_size=256,
|
1292 |
+
patch_size=4,
|
1293 |
+
patch_stride=(4, 4),
|
1294 |
+
num_classes=audio_cfg.class_num,
|
1295 |
+
embed_dim=256,
|
1296 |
+
depths=[2, 2, 12, 2],
|
1297 |
+
num_heads=[4, 8, 16, 32],
|
1298 |
+
window_size=8,
|
1299 |
+
config=audio_cfg,
|
1300 |
+
enable_fusion=enable_fusion,
|
1301 |
+
fusion_type=fusion_type,
|
1302 |
+
)
|
1303 |
+
|
1304 |
+
return model
|
1305 |
+
except:
|
1306 |
+
raise RuntimeError(
|
1307 |
+
f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
|
1308 |
+
)
|
audioldm/clap/open_clip/linear_probe.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
from .model import MLPLayers
|
5 |
+
|
6 |
+
|
7 |
+
class LinearProbe(nn.Module):
|
8 |
+
def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
|
9 |
+
"""
|
10 |
+
Args:
|
11 |
+
model: nn.Module
|
12 |
+
mlp: bool, if True, then use the MLP layer as the linear probe module
|
13 |
+
freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe
|
14 |
+
in_ch: int, the output channel from CLAP model
|
15 |
+
out_ch: int, the output channel from linear probe (class_num)
|
16 |
+
act: torch.nn.functional, the activation function before the loss function
|
17 |
+
"""
|
18 |
+
super().__init__()
|
19 |
+
in_ch = 512
|
20 |
+
self.clap_model = model
|
21 |
+
self.clap_model.text_branch = None # to save memory
|
22 |
+
self.freeze = freeze
|
23 |
+
if mlp:
|
24 |
+
self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
|
25 |
+
else:
|
26 |
+
self.lp_layer = nn.Linear(in_ch, out_ch)
|
27 |
+
|
28 |
+
if self.freeze:
|
29 |
+
for param in self.clap_model.parameters():
|
30 |
+
param.requires_grad = False
|
31 |
+
|
32 |
+
if act == "None":
|
33 |
+
self.act = None
|
34 |
+
elif act == "relu":
|
35 |
+
self.act = nn.ReLU()
|
36 |
+
elif act == "elu":
|
37 |
+
self.act = nn.ELU()
|
38 |
+
elif act == "prelu":
|
39 |
+
self.act = nn.PReLU(num_parameters=in_ch)
|
40 |
+
elif act == "softmax":
|
41 |
+
self.act = nn.Softmax(dim=-1)
|
42 |
+
elif act == "sigmoid":
|
43 |
+
self.act = nn.Sigmoid()
|
44 |
+
|
45 |
+
def forward(self, x, mix_lambda=None, device=None):
|
46 |
+
"""
|
47 |
+
Args:
|
48 |
+
x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
|
49 |
+
mix_lambda: torch.tensor [batch], the mixup lambda
|
50 |
+
Returns:
|
51 |
+
class_prob: torch.tensor [batch, class_num]
|
52 |
+
|
53 |
+
"""
|
54 |
+
# batchnorm cancel grandient
|
55 |
+
if self.freeze:
|
56 |
+
self.clap_model.eval()
|
57 |
+
|
58 |
+
x = self.clap_model.audio_projection(
|
59 |
+
self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
|
60 |
+
"embedding"
|
61 |
+
]
|
62 |
+
)
|
63 |
+
out = self.lp_layer(x)
|
64 |
+
if self.act is not None:
|
65 |
+
out = self.act(out)
|
66 |
+
return out
|
audioldm/clap/open_clip/loss.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from multiprocessing.sharedctypes import Value
|
2 |
+
import torch
|
3 |
+
import torch.distributed.nn
|
4 |
+
from torch import distributed as dist, nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
|
8 |
+
|
9 |
+
try:
|
10 |
+
import horovod.torch as hvd
|
11 |
+
except ImportError:
|
12 |
+
hvd = None
|
13 |
+
|
14 |
+
|
15 |
+
def gather_features(
|
16 |
+
audio_features,
|
17 |
+
text_features,
|
18 |
+
audio_features_mlp=None,
|
19 |
+
text_features_mlp=None,
|
20 |
+
local_loss=False,
|
21 |
+
gather_with_grad=False,
|
22 |
+
rank=0,
|
23 |
+
world_size=1,
|
24 |
+
use_horovod=False,
|
25 |
+
mlp_loss=False,
|
26 |
+
):
|
27 |
+
if use_horovod:
|
28 |
+
assert hvd is not None, "Please install horovod"
|
29 |
+
if gather_with_grad:
|
30 |
+
all_audio_features = hvd.allgather(audio_features)
|
31 |
+
all_text_features = hvd.allgather(text_features)
|
32 |
+
if mlp_loss:
|
33 |
+
all_audio_features_mlp = hvd.allgather(audio_features_mlp)
|
34 |
+
all_text_features_mlp = hvd.allgather(text_features_mlp)
|
35 |
+
else:
|
36 |
+
with torch.no_grad():
|
37 |
+
all_audio_features = hvd.allgather(audio_features)
|
38 |
+
all_text_features = hvd.allgather(text_features)
|
39 |
+
if mlp_loss:
|
40 |
+
all_audio_features_mlp = hvd.allgather(audio_features_mlp)
|
41 |
+
all_text_features_mlp = hvd.allgather(text_features_mlp)
|
42 |
+
if not local_loss:
|
43 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
44 |
+
gathered_audio_features = list(
|
45 |
+
all_audio_features.chunk(world_size, dim=0)
|
46 |
+
)
|
47 |
+
gathered_text_features = list(
|
48 |
+
all_text_features.chunk(world_size, dim=0)
|
49 |
+
)
|
50 |
+
gathered_audio_features[rank] = audio_features
|
51 |
+
gathered_text_features[rank] = text_features
|
52 |
+
all_audio_features = torch.cat(gathered_audio_features, dim=0)
|
53 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
54 |
+
if mlp_loss:
|
55 |
+
gathered_audio_features_mlp = list(
|
56 |
+
all_audio_features_mlp.chunk(world_size, dim=0)
|
57 |
+
)
|
58 |
+
gathered_text_features_mlp = list(
|
59 |
+
all_text_features_mlp.chunk(world_size, dim=0)
|
60 |
+
)
|
61 |
+
gathered_audio_features_mlp[rank] = audio_features_mlp
|
62 |
+
gathered_text_features_mlp[rank] = text_features_mlp
|
63 |
+
all_audio_features_mlp = torch.cat(
|
64 |
+
gathered_audio_features_mlp, dim=0
|
65 |
+
)
|
66 |
+
all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
|
67 |
+
else:
|
68 |
+
# We gather tensors from all gpus
|
69 |
+
if gather_with_grad:
|
70 |
+
all_audio_features = torch.cat(
|
71 |
+
torch.distributed.nn.all_gather(audio_features), dim=0
|
72 |
+
)
|
73 |
+
all_text_features = torch.cat(
|
74 |
+
torch.distributed.nn.all_gather(text_features), dim=0
|
75 |
+
)
|
76 |
+
if mlp_loss:
|
77 |
+
all_audio_features_mlp = torch.cat(
|
78 |
+
torch.distributed.nn.all_gather(audio_features_mlp), dim=0
|
79 |
+
)
|
80 |
+
all_text_features_mlp = torch.cat(
|
81 |
+
torch.distributed.nn.all_gather(text_features_mlp), dim=0
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
gathered_audio_features = [
|
85 |
+
torch.zeros_like(audio_features) for _ in range(world_size)
|
86 |
+
]
|
87 |
+
gathered_text_features = [
|
88 |
+
torch.zeros_like(text_features) for _ in range(world_size)
|
89 |
+
]
|
90 |
+
dist.all_gather(gathered_audio_features, audio_features)
|
91 |
+
dist.all_gather(gathered_text_features, text_features)
|
92 |
+
if mlp_loss:
|
93 |
+
gathered_audio_features_mlp = [
|
94 |
+
torch.zeros_like(audio_features_mlp) for _ in range(world_size)
|
95 |
+
]
|
96 |
+
gathered_text_features_mlp = [
|
97 |
+
torch.zeros_like(text_features_mlp) for _ in range(world_size)
|
98 |
+
]
|
99 |
+
dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
|
100 |
+
dist.all_gather(gathered_text_features_mlp, text_features_mlp)
|
101 |
+
if not local_loss:
|
102 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
103 |
+
gathered_audio_features[rank] = audio_features
|
104 |
+
gathered_text_features[rank] = text_features
|
105 |
+
if mlp_loss:
|
106 |
+
gathered_audio_features_mlp[rank] = audio_features_mlp
|
107 |
+
gathered_text_features_mlp[rank] = text_features_mlp
|
108 |
+
|
109 |
+
all_audio_features = torch.cat(gathered_audio_features, dim=0)
|
110 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
111 |
+
if mlp_loss:
|
112 |
+
all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
|
113 |
+
all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
|
114 |
+
if mlp_loss:
|
115 |
+
return (
|
116 |
+
all_audio_features,
|
117 |
+
all_text_features,
|
118 |
+
all_audio_features_mlp,
|
119 |
+
all_text_features_mlp,
|
120 |
+
)
|
121 |
+
else:
|
122 |
+
return all_audio_features, all_text_features
|
123 |
+
|
124 |
+
|
125 |
+
class ClipLoss(nn.Module):
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
local_loss=False,
|
129 |
+
gather_with_grad=False,
|
130 |
+
cache_labels=False,
|
131 |
+
rank=0,
|
132 |
+
world_size=1,
|
133 |
+
use_horovod=False,
|
134 |
+
mlp_loss=False,
|
135 |
+
weight_loss_kappa=0,
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
self.local_loss = local_loss
|
139 |
+
self.gather_with_grad = gather_with_grad
|
140 |
+
self.cache_labels = cache_labels
|
141 |
+
self.rank = rank
|
142 |
+
self.world_size = world_size
|
143 |
+
self.use_horovod = use_horovod
|
144 |
+
self.mlp_loss = mlp_loss
|
145 |
+
self.weighted_loss = bool(weight_loss_kappa != 0)
|
146 |
+
self.weight_loss_kappa = weight_loss_kappa
|
147 |
+
# cache state
|
148 |
+
self.prev_num_logits = 0
|
149 |
+
self.labels = {}
|
150 |
+
|
151 |
+
def forward(
|
152 |
+
self,
|
153 |
+
audio_features,
|
154 |
+
text_features,
|
155 |
+
logit_scale_a,
|
156 |
+
logit_scale_t=None,
|
157 |
+
audio_features_mlp=None,
|
158 |
+
text_features_mlp=None,
|
159 |
+
):
|
160 |
+
device = audio_features.device
|
161 |
+
if self.mlp_loss:
|
162 |
+
if self.world_size > 1:
|
163 |
+
(
|
164 |
+
all_audio_features,
|
165 |
+
all_text_features,
|
166 |
+
all_audio_features_mlp,
|
167 |
+
all_text_features_mlp,
|
168 |
+
) = gather_features(
|
169 |
+
audio_features=audio_features,
|
170 |
+
text_features=text_features,
|
171 |
+
audio_features_mlp=audio_features_mlp,
|
172 |
+
text_features_mlp=text_features_mlp,
|
173 |
+
local_loss=self.local_loss,
|
174 |
+
gather_with_grad=self.gather_with_grad,
|
175 |
+
rank=self.rank,
|
176 |
+
world_size=self.world_size,
|
177 |
+
use_horovod=self.use_horovod,
|
178 |
+
mlp_loss=self.mlp_loss,
|
179 |
+
)
|
180 |
+
if self.local_loss:
|
181 |
+
a_logits_per_audio = (
|
182 |
+
logit_scale_a * audio_features @ all_text_features_mlp.T
|
183 |
+
)
|
184 |
+
a_logits_per_text = (
|
185 |
+
logit_scale_a * text_features_mlp @ all_audio_features.T
|
186 |
+
)
|
187 |
+
t_logits_per_audio = (
|
188 |
+
logit_scale_t * audio_features_mlp @ all_text_features.T
|
189 |
+
)
|
190 |
+
t_logits_per_text = (
|
191 |
+
logit_scale_t * text_features @ all_audio_features_mlp.T
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
a_logits_per_audio = (
|
195 |
+
logit_scale_a * all_audio_features @ all_text_features_mlp.T
|
196 |
+
)
|
197 |
+
a_logits_per_text = a_logits_per_audio.T
|
198 |
+
t_logits_per_audio = (
|
199 |
+
logit_scale_t * all_audio_features_mlp @ all_text_features.T
|
200 |
+
)
|
201 |
+
t_logits_per_text = t_logits_per_audio.T
|
202 |
+
else:
|
203 |
+
a_logits_per_audio = (
|
204 |
+
logit_scale_a * audio_features @ text_features_mlp.T
|
205 |
+
)
|
206 |
+
a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
|
207 |
+
t_logits_per_audio = (
|
208 |
+
logit_scale_t * audio_features_mlp @ text_features.T
|
209 |
+
)
|
210 |
+
t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
|
211 |
+
|
212 |
+
# calculated ground-truth and cache if enabled
|
213 |
+
num_logits = a_logits_per_audio.shape[0]
|
214 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
215 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
216 |
+
if self.world_size > 1 and self.local_loss:
|
217 |
+
labels = labels + num_logits * self.rank
|
218 |
+
if self.cache_labels:
|
219 |
+
self.labels[device] = labels
|
220 |
+
self.prev_num_logits = num_logits
|
221 |
+
else:
|
222 |
+
labels = self.labels[device]
|
223 |
+
|
224 |
+
if not self.weighted_loss:
|
225 |
+
total_loss = (
|
226 |
+
F.cross_entropy(a_logits_per_audio, labels)
|
227 |
+
+ F.cross_entropy(a_logits_per_text, labels)
|
228 |
+
+ F.cross_entropy(t_logits_per_audio, labels)
|
229 |
+
+ F.cross_entropy(t_logits_per_text, labels)
|
230 |
+
) / 4
|
231 |
+
else:
|
232 |
+
audio_weight = (audio_features @ audio_features.T).detach()
|
233 |
+
audio_weight = (
|
234 |
+
torch.exp(
|
235 |
+
torch.sum(audio_weight, axis=1)
|
236 |
+
/ (self.weight_loss_kappa * len(audio_weight))
|
237 |
+
)
|
238 |
+
).detach()
|
239 |
+
text_weight = (text_features @ text_features.T).detach()
|
240 |
+
text_weight = (
|
241 |
+
torch.exp(
|
242 |
+
torch.sum(text_weight, axis=1)
|
243 |
+
/ (self.weight_loss_kappa * len(text_features))
|
244 |
+
)
|
245 |
+
).detach()
|
246 |
+
total_loss = (
|
247 |
+
F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
|
248 |
+
+ F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
|
249 |
+
+ F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
|
250 |
+
+ F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
|
251 |
+
) / 4
|
252 |
+
else:
|
253 |
+
if self.world_size > 1:
|
254 |
+
all_audio_features, all_text_features = gather_features(
|
255 |
+
audio_features=audio_features,
|
256 |
+
text_features=text_features,
|
257 |
+
local_loss=self.local_loss,
|
258 |
+
gather_with_grad=self.gather_with_grad,
|
259 |
+
rank=self.rank,
|
260 |
+
world_size=self.world_size,
|
261 |
+
use_horovod=self.use_horovod,
|
262 |
+
mlp_loss=self.mlp_loss,
|
263 |
+
)
|
264 |
+
|
265 |
+
if self.local_loss:
|
266 |
+
logits_per_audio = (
|
267 |
+
logit_scale_a * audio_features @ all_text_features.T
|
268 |
+
)
|
269 |
+
logits_per_text = (
|
270 |
+
logit_scale_a * text_features @ all_audio_features.T
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
logits_per_audio = (
|
274 |
+
logit_scale_a * all_audio_features @ all_text_features.T
|
275 |
+
)
|
276 |
+
logits_per_text = logits_per_audio.T
|
277 |
+
else:
|
278 |
+
logits_per_audio = logit_scale_a * audio_features @ text_features.T
|
279 |
+
logits_per_text = logit_scale_a * text_features @ audio_features.T
|
280 |
+
|
281 |
+
# calculated ground-truth and cache if enabled
|
282 |
+
num_logits = logits_per_audio.shape[0]
|
283 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
284 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
285 |
+
if self.world_size > 1 and self.local_loss:
|
286 |
+
labels = labels + num_logits * self.rank
|
287 |
+
if self.cache_labels:
|
288 |
+
self.labels[device] = labels
|
289 |
+
self.prev_num_logits = num_logits
|
290 |
+
else:
|
291 |
+
labels = self.labels[device]
|
292 |
+
if not self.weighted_loss:
|
293 |
+
total_loss = (
|
294 |
+
F.cross_entropy(logits_per_audio, labels)
|
295 |
+
+ F.cross_entropy(logits_per_text, labels)
|
296 |
+
) / 2
|
297 |
+
else:
|
298 |
+
audio_weight = (all_audio_features @ all_audio_features.T).detach()
|
299 |
+
audio_weight = (
|
300 |
+
torch.exp(
|
301 |
+
torch.sum(audio_weight, axis=1)
|
302 |
+
/ (self.weight_loss_kappa * len(all_audio_features))
|
303 |
+
)
|
304 |
+
).detach()
|
305 |
+
text_weight = (all_text_features @ all_text_features.T).detach()
|
306 |
+
text_weight = (
|
307 |
+
torch.exp(
|
308 |
+
torch.sum(text_weight, axis=1)
|
309 |
+
/ (self.weight_loss_kappa * len(all_text_features))
|
310 |
+
)
|
311 |
+
).detach()
|
312 |
+
total_loss = (
|
313 |
+
F.cross_entropy(logits_per_audio, labels, weight=text_weight)
|
314 |
+
+ F.cross_entropy(logits_per_text, labels, weight=audio_weight)
|
315 |
+
) / 2
|
316 |
+
return total_loss
|
317 |
+
|
318 |
+
|
319 |
+
def lp_gather_features(pred, target, world_size=1, use_horovod=False):
|
320 |
+
if use_horovod:
|
321 |
+
assert hvd is not None, "Please install horovod"
|
322 |
+
with torch.no_grad():
|
323 |
+
all_preds = hvd.allgather(pred)
|
324 |
+
all_targets = hvd.allgath(target)
|
325 |
+
else:
|
326 |
+
gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
|
327 |
+
gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
|
328 |
+
|
329 |
+
dist.all_gather(gathered_preds, pred)
|
330 |
+
dist.all_gather(gathered_targets, target)
|
331 |
+
all_preds = torch.cat(gathered_preds, dim=0)
|
332 |
+
all_targets = torch.cat(gathered_targets, dim=0)
|
333 |
+
|
334 |
+
return all_preds, all_targets
|
335 |
+
|
336 |
+
|
337 |
+
def get_map(pred, target):
|
338 |
+
pred = torch.sigmoid(pred).numpy()
|
339 |
+
target = target.numpy()
|
340 |
+
return np.mean(average_precision_score(target, pred, average=None))
|
341 |
+
|
342 |
+
|
343 |
+
def get_acc(pred, target):
|
344 |
+
pred = torch.argmax(pred, 1).numpy()
|
345 |
+
target = torch.argmax(target, 1).numpy()
|
346 |
+
return accuracy_score(target, pred)
|
347 |
+
|
348 |
+
|
349 |
+
def get_mauc(pred, target):
|
350 |
+
pred = torch.sigmoid(pred).numpy()
|
351 |
+
target = target.numpy()
|
352 |
+
return np.mean(roc_auc_score(target, pred, average=None))
|
353 |
+
|
354 |
+
|
355 |
+
class LPMetrics(object):
|
356 |
+
def __init__(self, metric_names=["map", "acc", "mauc"]):
|
357 |
+
self.metrics = []
|
358 |
+
for name in metric_names:
|
359 |
+
self.metrics.append(self.get_metric(name))
|
360 |
+
self.metric_names = metric_names
|
361 |
+
|
362 |
+
def get_metric(self, name):
|
363 |
+
if name == "map":
|
364 |
+
return get_map
|
365 |
+
elif name == "acc":
|
366 |
+
return get_acc
|
367 |
+
elif name == "mauc":
|
368 |
+
return get_mauc
|
369 |
+
else:
|
370 |
+
raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
|
371 |
+
|
372 |
+
def evaluate_mertics(self, pred, target):
|
373 |
+
metric_dict = {}
|
374 |
+
for i in range(len(self.metric_names)):
|
375 |
+
metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
|
376 |
+
return metric_dict
|
377 |
+
|
378 |
+
|
379 |
+
def calc_celoss(pred, target):
|
380 |
+
target = torch.argmax(target, 1).long()
|
381 |
+
return nn.CrossEntropyLoss()(pred, target)
|
382 |
+
|
383 |
+
|
384 |
+
class LPLoss(nn.Module):
|
385 |
+
def __init__(self, loss_name):
|
386 |
+
super().__init__()
|
387 |
+
if loss_name == "bce":
|
388 |
+
self.loss_func = nn.BCEWithLogitsLoss()
|
389 |
+
elif loss_name == "ce":
|
390 |
+
self.loss_func = calc_celoss
|
391 |
+
elif loss_name == "mse":
|
392 |
+
self.loss_func = nn.MSELoss()
|
393 |
+
else:
|
394 |
+
raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
|
395 |
+
|
396 |
+
def forward(self, pred, target):
|
397 |
+
loss = self.loss_func(pred, target)
|
398 |
+
return loss
|
audioldm/clap/open_clip/model.py
ADDED
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLAP Model
|
2 |
+
|
3 |
+
Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
Adapted to the Audio Task.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from collections import OrderedDict
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from email.mime import audio
|
10 |
+
from typing import Tuple, Union, Callable, Optional
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
from .timm_model import TimmModel
|
18 |
+
import logging
|
19 |
+
from .utils import freeze_batch_norm_2d
|
20 |
+
|
21 |
+
from .pann_model import create_pann_model
|
22 |
+
from .htsat import create_htsat_model
|
23 |
+
from transformers import BertModel, RobertaModel, BartModel
|
24 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
25 |
+
|
26 |
+
|
27 |
+
class MLPLayers(nn.Module):
|
28 |
+
def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
|
29 |
+
super(MLPLayers, self).__init__()
|
30 |
+
self.nonlin = nonlin
|
31 |
+
self.dropout = dropout
|
32 |
+
|
33 |
+
sequence = []
|
34 |
+
for u0, u1 in zip(units[:-1], units[1:]):
|
35 |
+
sequence.append(nn.Linear(u0, u1))
|
36 |
+
sequence.append(self.nonlin)
|
37 |
+
sequence.append(nn.Dropout(self.dropout))
|
38 |
+
sequence = sequence[:-2]
|
39 |
+
|
40 |
+
self.sequential = nn.Sequential(*sequence)
|
41 |
+
|
42 |
+
def forward(self, X):
|
43 |
+
X = self.sequential(X)
|
44 |
+
return X
|
45 |
+
|
46 |
+
|
47 |
+
class Bottleneck(nn.Module):
|
48 |
+
expansion = 4
|
49 |
+
|
50 |
+
def __init__(self, inplanes, planes, stride=1):
|
51 |
+
super().__init__()
|
52 |
+
|
53 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
54 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
55 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
56 |
+
|
57 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
58 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
59 |
+
|
60 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
61 |
+
|
62 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
63 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
64 |
+
|
65 |
+
self.relu = nn.ReLU(inplace=True)
|
66 |
+
self.downsample = None
|
67 |
+
self.stride = stride
|
68 |
+
|
69 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
70 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
71 |
+
self.downsample = nn.Sequential(
|
72 |
+
OrderedDict(
|
73 |
+
[
|
74 |
+
("-1", nn.AvgPool2d(stride)),
|
75 |
+
(
|
76 |
+
"0",
|
77 |
+
nn.Conv2d(
|
78 |
+
inplanes,
|
79 |
+
planes * self.expansion,
|
80 |
+
1,
|
81 |
+
stride=1,
|
82 |
+
bias=False,
|
83 |
+
),
|
84 |
+
),
|
85 |
+
("1", nn.BatchNorm2d(planes * self.expansion)),
|
86 |
+
]
|
87 |
+
)
|
88 |
+
)
|
89 |
+
|
90 |
+
def forward(self, x: torch.Tensor):
|
91 |
+
identity = x
|
92 |
+
|
93 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
94 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
95 |
+
out = self.avgpool(out)
|
96 |
+
out = self.bn3(self.conv3(out))
|
97 |
+
|
98 |
+
if self.downsample is not None:
|
99 |
+
identity = self.downsample(x)
|
100 |
+
|
101 |
+
out += identity
|
102 |
+
out = self.relu(out)
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
class AttentionPool2d(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
|
109 |
+
):
|
110 |
+
super().__init__()
|
111 |
+
self.positional_embedding = nn.Parameter(
|
112 |
+
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
|
113 |
+
)
|
114 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
115 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
116 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
117 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
118 |
+
self.num_heads = num_heads
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
|
122 |
+
2, 0, 1
|
123 |
+
) # NCHW -> (HW)NC
|
124 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
125 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
126 |
+
x, _ = F.multi_head_attention_forward(
|
127 |
+
query=x,
|
128 |
+
key=x,
|
129 |
+
value=x,
|
130 |
+
embed_dim_to_check=x.shape[-1],
|
131 |
+
num_heads=self.num_heads,
|
132 |
+
q_proj_weight=self.q_proj.weight,
|
133 |
+
k_proj_weight=self.k_proj.weight,
|
134 |
+
v_proj_weight=self.v_proj.weight,
|
135 |
+
in_proj_weight=None,
|
136 |
+
in_proj_bias=torch.cat(
|
137 |
+
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
|
138 |
+
),
|
139 |
+
bias_k=None,
|
140 |
+
bias_v=None,
|
141 |
+
add_zero_attn=False,
|
142 |
+
dropout_p=0,
|
143 |
+
out_proj_weight=self.c_proj.weight,
|
144 |
+
out_proj_bias=self.c_proj.bias,
|
145 |
+
use_separate_proj_weight=True,
|
146 |
+
training=self.training,
|
147 |
+
need_weights=False,
|
148 |
+
)
|
149 |
+
|
150 |
+
return x[0]
|
151 |
+
|
152 |
+
|
153 |
+
class ModifiedResNet(nn.Module):
|
154 |
+
"""
|
155 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
156 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
157 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
158 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
162 |
+
super().__init__()
|
163 |
+
self.output_dim = output_dim
|
164 |
+
self.image_size = image_size
|
165 |
+
|
166 |
+
# the 3-layer stem
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
|
169 |
+
)
|
170 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
width // 2, width // 2, kernel_size=3, padding=1, bias=False
|
173 |
+
)
|
174 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
175 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
176 |
+
self.bn3 = nn.BatchNorm2d(width)
|
177 |
+
self.avgpool = nn.AvgPool2d(2)
|
178 |
+
self.relu = nn.ReLU(inplace=True)
|
179 |
+
|
180 |
+
# residual layers
|
181 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
182 |
+
self.layer1 = self._make_layer(width, layers[0])
|
183 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
184 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
185 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
186 |
+
|
187 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
188 |
+
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
189 |
+
|
190 |
+
self.init_parameters()
|
191 |
+
|
192 |
+
def _make_layer(self, planes, blocks, stride=1):
|
193 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
194 |
+
|
195 |
+
self._inplanes = planes * Bottleneck.expansion
|
196 |
+
for _ in range(1, blocks):
|
197 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
198 |
+
|
199 |
+
return nn.Sequential(*layers)
|
200 |
+
|
201 |
+
def init_parameters(self):
|
202 |
+
if self.attnpool is not None:
|
203 |
+
std = self.attnpool.c_proj.in_features**-0.5
|
204 |
+
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
205 |
+
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
206 |
+
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
207 |
+
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
208 |
+
|
209 |
+
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
210 |
+
for name, param in resnet_block.named_parameters():
|
211 |
+
if name.endswith("bn3.weight"):
|
212 |
+
nn.init.zeros_(param)
|
213 |
+
|
214 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
215 |
+
assert (
|
216 |
+
unlocked_groups == 0
|
217 |
+
), "partial locking not currently supported for this model"
|
218 |
+
for param in self.parameters():
|
219 |
+
param.requires_grad = False
|
220 |
+
if freeze_bn_stats:
|
221 |
+
freeze_batch_norm_2d(self)
|
222 |
+
|
223 |
+
def stem(self, x):
|
224 |
+
for conv, bn in [
|
225 |
+
(self.conv1, self.bn1),
|
226 |
+
(self.conv2, self.bn2),
|
227 |
+
(self.conv3, self.bn3),
|
228 |
+
]:
|
229 |
+
x = self.relu(bn(conv(x)))
|
230 |
+
x = self.avgpool(x)
|
231 |
+
return x
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
x = self.stem(x)
|
235 |
+
x = self.layer1(x)
|
236 |
+
x = self.layer2(x)
|
237 |
+
x = self.layer3(x)
|
238 |
+
x = self.layer4(x)
|
239 |
+
x = self.attnpool(x)
|
240 |
+
|
241 |
+
return x
|
242 |
+
|
243 |
+
|
244 |
+
class LayerNorm(nn.LayerNorm):
|
245 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
246 |
+
|
247 |
+
def forward(self, x: torch.Tensor):
|
248 |
+
orig_type = x.dtype
|
249 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
250 |
+
return x.to(orig_type)
|
251 |
+
|
252 |
+
|
253 |
+
class QuickGELU(nn.Module):
|
254 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
255 |
+
def forward(self, x: torch.Tensor):
|
256 |
+
return x * torch.sigmoid(1.702 * x)
|
257 |
+
|
258 |
+
|
259 |
+
class ResidualAttentionBlock(nn.Module):
|
260 |
+
def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
|
261 |
+
super().__init__()
|
262 |
+
|
263 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
264 |
+
self.ln_1 = LayerNorm(d_model)
|
265 |
+
self.mlp = nn.Sequential(
|
266 |
+
OrderedDict(
|
267 |
+
[
|
268 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
269 |
+
("gelu", act_layer()),
|
270 |
+
("c_proj", nn.Linear(d_model * 4, d_model)),
|
271 |
+
]
|
272 |
+
)
|
273 |
+
)
|
274 |
+
self.ln_2 = LayerNorm(d_model)
|
275 |
+
|
276 |
+
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
277 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
278 |
+
|
279 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
280 |
+
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
|
281 |
+
x = x + self.mlp(self.ln_2(x))
|
282 |
+
return x
|
283 |
+
|
284 |
+
|
285 |
+
class Transformer(nn.Module):
|
286 |
+
def __init__(
|
287 |
+
self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
|
288 |
+
):
|
289 |
+
super().__init__()
|
290 |
+
self.width = width
|
291 |
+
self.layers = layers
|
292 |
+
self.resblocks = nn.ModuleList(
|
293 |
+
[
|
294 |
+
ResidualAttentionBlock(width, heads, act_layer=act_layer)
|
295 |
+
for _ in range(layers)
|
296 |
+
]
|
297 |
+
)
|
298 |
+
|
299 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
300 |
+
for r in self.resblocks:
|
301 |
+
x = r(x, attn_mask=attn_mask)
|
302 |
+
return x
|
303 |
+
|
304 |
+
|
305 |
+
class VisualTransformer(nn.Module):
|
306 |
+
def __init__(
|
307 |
+
self,
|
308 |
+
image_size: int,
|
309 |
+
patch_size: int,
|
310 |
+
width: int,
|
311 |
+
layers: int,
|
312 |
+
heads: int,
|
313 |
+
output_dim: int,
|
314 |
+
act_layer: Callable = nn.GELU,
|
315 |
+
):
|
316 |
+
super().__init__()
|
317 |
+
self.image_size = image_size
|
318 |
+
self.output_dim = output_dim
|
319 |
+
self.conv1 = nn.Conv2d(
|
320 |
+
in_channels=3,
|
321 |
+
out_channels=width,
|
322 |
+
kernel_size=patch_size,
|
323 |
+
stride=patch_size,
|
324 |
+
bias=False,
|
325 |
+
)
|
326 |
+
|
327 |
+
scale = width**-0.5
|
328 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
329 |
+
self.positional_embedding = nn.Parameter(
|
330 |
+
scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
|
331 |
+
)
|
332 |
+
self.ln_pre = LayerNorm(width)
|
333 |
+
|
334 |
+
self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
|
335 |
+
|
336 |
+
self.ln_post = LayerNorm(width)
|
337 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
338 |
+
|
339 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
340 |
+
assert (
|
341 |
+
unlocked_groups == 0
|
342 |
+
), "partial locking not currently supported for this model"
|
343 |
+
for param in self.parameters():
|
344 |
+
param.requires_grad = False
|
345 |
+
|
346 |
+
def forward(self, x: torch.Tensor):
|
347 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
348 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
349 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
350 |
+
x = torch.cat(
|
351 |
+
[
|
352 |
+
self.class_embedding.to(x.dtype)
|
353 |
+
+ torch.zeros(
|
354 |
+
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
|
355 |
+
),
|
356 |
+
x,
|
357 |
+
],
|
358 |
+
dim=1,
|
359 |
+
) # shape = [*, grid ** 2 + 1, width]
|
360 |
+
x = x + self.positional_embedding.to(x.dtype)
|
361 |
+
x = self.ln_pre(x)
|
362 |
+
|
363 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
364 |
+
x = self.text_branch(x)
|
365 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
366 |
+
|
367 |
+
x = self.ln_post(x[:, 0, :])
|
368 |
+
|
369 |
+
if self.proj is not None:
|
370 |
+
x = x @ self.proj
|
371 |
+
|
372 |
+
return x
|
373 |
+
|
374 |
+
|
375 |
+
@dataclass
|
376 |
+
class CLAPVisionCfg:
|
377 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
378 |
+
width: int = 768
|
379 |
+
patch_size: int = 16
|
380 |
+
image_size: Union[Tuple[int, int], int] = 224
|
381 |
+
timm_model_name: str = (
|
382 |
+
None # a valid model name overrides layers, width, patch_size
|
383 |
+
)
|
384 |
+
timm_model_pretrained: bool = (
|
385 |
+
False # use (imagenet) pretrained weights for named model
|
386 |
+
)
|
387 |
+
timm_pool: str = (
|
388 |
+
"avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
389 |
+
)
|
390 |
+
timm_proj: str = (
|
391 |
+
"linear" # linear projection for timm model output ('linear', 'mlp', '')
|
392 |
+
)
|
393 |
+
|
394 |
+
|
395 |
+
# Audio Config Class
|
396 |
+
@dataclass
|
397 |
+
class CLAPAudioCfp:
|
398 |
+
model_type: str = "PANN"
|
399 |
+
model_name: str = "Cnn14"
|
400 |
+
sample_rate: int = 48000
|
401 |
+
# Param
|
402 |
+
audio_length: int = 1024
|
403 |
+
window_size: int = 1024
|
404 |
+
hop_size: int = 1024
|
405 |
+
fmin: int = 50
|
406 |
+
fmax: int = 14000
|
407 |
+
class_num: int = 527
|
408 |
+
mel_bins: int = 64
|
409 |
+
clip_samples: int = 480000
|
410 |
+
|
411 |
+
|
412 |
+
@dataclass
|
413 |
+
class CLAPTextCfg:
|
414 |
+
context_length: int
|
415 |
+
vocab_size: int
|
416 |
+
width: int
|
417 |
+
heads: int
|
418 |
+
layers: int
|
419 |
+
model_type: str
|
420 |
+
|
421 |
+
|
422 |
+
class CLAP(nn.Module):
|
423 |
+
def __init__(
|
424 |
+
self,
|
425 |
+
embed_dim: int,
|
426 |
+
audio_cfg: CLAPAudioCfp,
|
427 |
+
text_cfg: CLAPTextCfg,
|
428 |
+
quick_gelu: bool = False,
|
429 |
+
enable_fusion: bool = False,
|
430 |
+
fusion_type: str = "None",
|
431 |
+
joint_embed_shape: int = 512,
|
432 |
+
mlp_act: str = "relu",
|
433 |
+
):
|
434 |
+
super().__init__()
|
435 |
+
if isinstance(audio_cfg, dict):
|
436 |
+
audio_cfg = CLAPAudioCfp(**audio_cfg)
|
437 |
+
if isinstance(text_cfg, dict):
|
438 |
+
text_cfg = CLAPTextCfg(**text_cfg)
|
439 |
+
|
440 |
+
self.audio_cfg = audio_cfg
|
441 |
+
self.text_cfg = text_cfg
|
442 |
+
self.enable_fusion = enable_fusion
|
443 |
+
self.fusion_type = fusion_type
|
444 |
+
self.joint_embed_shape = joint_embed_shape
|
445 |
+
self.mlp_act = mlp_act
|
446 |
+
|
447 |
+
self.context_length = text_cfg.context_length
|
448 |
+
|
449 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
450 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
451 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
452 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
453 |
+
|
454 |
+
if mlp_act == "relu":
|
455 |
+
mlp_act_layer = nn.ReLU()
|
456 |
+
elif mlp_act == "gelu":
|
457 |
+
mlp_act_layer = nn.GELU()
|
458 |
+
else:
|
459 |
+
raise NotImplementedError
|
460 |
+
|
461 |
+
# audio branch
|
462 |
+
# audio branch parameters
|
463 |
+
if audio_cfg.model_type == "PANN":
|
464 |
+
self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
|
465 |
+
elif audio_cfg.model_type == "HTSAT":
|
466 |
+
self.audio_branch = create_htsat_model(
|
467 |
+
audio_cfg, enable_fusion, fusion_type
|
468 |
+
)
|
469 |
+
else:
|
470 |
+
logging.error(f"Model config for {audio_cfg.model_type} not found")
|
471 |
+
raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
|
472 |
+
|
473 |
+
# text branch
|
474 |
+
# text branch parameters
|
475 |
+
if text_cfg.model_type == "transformer":
|
476 |
+
self.text_branch = Transformer(
|
477 |
+
width=text_cfg.width,
|
478 |
+
layers=text_cfg.layers,
|
479 |
+
heads=text_cfg.heads,
|
480 |
+
act_layer=act_layer,
|
481 |
+
)
|
482 |
+
self.vocab_size = text_cfg.vocab_size
|
483 |
+
self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
|
484 |
+
self.positional_embedding = nn.Parameter(
|
485 |
+
torch.empty(self.context_length, text_cfg.width)
|
486 |
+
)
|
487 |
+
self.ln_final = LayerNorm(text_cfg.width)
|
488 |
+
self.text_transform = MLPLayers(
|
489 |
+
units=[
|
490 |
+
self.joint_embed_shape,
|
491 |
+
self.joint_embed_shape,
|
492 |
+
self.joint_embed_shape,
|
493 |
+
],
|
494 |
+
dropout=0.1,
|
495 |
+
)
|
496 |
+
self.text_projection = nn.Sequential(
|
497 |
+
nn.Linear(text_cfg.width, self.joint_embed_shape),
|
498 |
+
mlp_act_layer,
|
499 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
500 |
+
)
|
501 |
+
elif text_cfg.model_type == "bert":
|
502 |
+
self.text_branch = BertModel.from_pretrained("bert-base-uncased")
|
503 |
+
self.text_transform = MLPLayers(
|
504 |
+
units=[
|
505 |
+
self.joint_embed_shape,
|
506 |
+
self.joint_embed_shape,
|
507 |
+
self.joint_embed_shape,
|
508 |
+
],
|
509 |
+
dropout=0.1,
|
510 |
+
)
|
511 |
+
self.text_projection = nn.Sequential(
|
512 |
+
nn.Linear(768, self.joint_embed_shape),
|
513 |
+
mlp_act_layer,
|
514 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
515 |
+
)
|
516 |
+
elif text_cfg.model_type == "roberta":
|
517 |
+
self.text_branch = RobertaModel.from_pretrained("roberta-base")
|
518 |
+
self.text_transform = MLPLayers(
|
519 |
+
units=[
|
520 |
+
self.joint_embed_shape,
|
521 |
+
self.joint_embed_shape,
|
522 |
+
self.joint_embed_shape,
|
523 |
+
],
|
524 |
+
dropout=0.1,
|
525 |
+
)
|
526 |
+
self.text_projection = nn.Sequential(
|
527 |
+
nn.Linear(768, self.joint_embed_shape),
|
528 |
+
mlp_act_layer,
|
529 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
530 |
+
)
|
531 |
+
elif text_cfg.model_type == "bart":
|
532 |
+
self.text_branch = BartModel.from_pretrained("facebook/bart-base")
|
533 |
+
self.text_transform = MLPLayers(
|
534 |
+
units=[
|
535 |
+
self.joint_embed_shape,
|
536 |
+
self.joint_embed_shape,
|
537 |
+
self.joint_embed_shape,
|
538 |
+
],
|
539 |
+
dropout=0.1,
|
540 |
+
)
|
541 |
+
self.text_projection = nn.Sequential(
|
542 |
+
nn.Linear(768, self.joint_embed_shape),
|
543 |
+
mlp_act_layer,
|
544 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
545 |
+
)
|
546 |
+
else:
|
547 |
+
logging.error(f"Model config for {text_cfg.model_type} not found")
|
548 |
+
raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
|
549 |
+
self.text_branch_type = text_cfg.model_type
|
550 |
+
# text branch parameters
|
551 |
+
|
552 |
+
# audio branch parameters
|
553 |
+
self.audio_transform = MLPLayers(
|
554 |
+
units=[
|
555 |
+
self.joint_embed_shape,
|
556 |
+
self.joint_embed_shape,
|
557 |
+
self.joint_embed_shape,
|
558 |
+
],
|
559 |
+
dropout=0.1,
|
560 |
+
)
|
561 |
+
|
562 |
+
# below here is text branch parameters
|
563 |
+
|
564 |
+
# ============================================================================================================
|
565 |
+
self.audio_projection = nn.Sequential(
|
566 |
+
nn.Linear(embed_dim, self.joint_embed_shape),
|
567 |
+
mlp_act_layer,
|
568 |
+
nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
|
569 |
+
)
|
570 |
+
|
571 |
+
self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
572 |
+
self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
573 |
+
self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
|
574 |
+
|
575 |
+
self.init_text_branch_parameters()
|
576 |
+
|
577 |
+
def init_text_branch_parameters(self):
|
578 |
+
if self.text_branch_type == "transformer":
|
579 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
580 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
581 |
+
proj_std = (self.text_branch.width**-0.5) * (
|
582 |
+
(2 * self.text_branch.layers) ** -0.5
|
583 |
+
)
|
584 |
+
attn_std = self.text_branch.width**-0.5
|
585 |
+
fc_std = (2 * self.text_branch.width) ** -0.5
|
586 |
+
for block in self.text_branch.resblocks:
|
587 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
588 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
589 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
590 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
591 |
+
if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
|
592 |
+
width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
|
593 |
+
elif self.text_branch_type == "bart":
|
594 |
+
width = self.text_branch.shared.weight.shape[-1]
|
595 |
+
else:
|
596 |
+
width = self.text_branch.width
|
597 |
+
nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
|
598 |
+
nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
|
599 |
+
|
600 |
+
# deprecated
|
601 |
+
# if hasattr(self.visual, 'init_parameters'):
|
602 |
+
# self.visual.init_parameters()
|
603 |
+
|
604 |
+
# if self.text_projection is not None:
|
605 |
+
# nn.init.normal_(self.text_projection, std=width**-0.5)
|
606 |
+
|
607 |
+
def build_attention_mask(self):
|
608 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
609 |
+
# pytorch uses additive attention mask; fill with -inf
|
610 |
+
mask = torch.empty(self.context_length, self.context_length)
|
611 |
+
mask.fill_(float("-inf"))
|
612 |
+
mask.triu_(1) # zero out the lower diagonal
|
613 |
+
return mask
|
614 |
+
|
615 |
+
def encode_audio(self, audio, device):
|
616 |
+
return self.audio_branch(
|
617 |
+
audio, mixup_lambda=None, device=device
|
618 |
+
) # mix lambda needs to add
|
619 |
+
|
620 |
+
# def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
|
621 |
+
# tmp = {}
|
622 |
+
# for k in x[0].keys():
|
623 |
+
# tmp[k] = []
|
624 |
+
# for i in range(len(x)):
|
625 |
+
# tmp[k].append(x[i][k][:77])
|
626 |
+
# for k in x[0].keys():
|
627 |
+
# tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
|
628 |
+
# return tmp
|
629 |
+
|
630 |
+
def encode_text(self, text, device):
|
631 |
+
if self.text_branch_type == "transformer":
|
632 |
+
text = text.to(device=device, non_blocking=True)
|
633 |
+
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
|
634 |
+
|
635 |
+
x = x + self.positional_embedding
|
636 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
637 |
+
x = self.text_branch(x, attn_mask=self.attn_mask)
|
638 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
639 |
+
x = self.ln_final(x)
|
640 |
+
|
641 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
642 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
643 |
+
x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
|
644 |
+
elif self.text_branch_type == "bert":
|
645 |
+
# text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
|
646 |
+
# text = BatchEncoding(text)
|
647 |
+
x = self.text_branch(
|
648 |
+
input_ids=text["input_ids"].to(device=device, non_blocking=True),
|
649 |
+
attention_mask=text["attention_mask"].to(
|
650 |
+
device=device, non_blocking=True
|
651 |
+
),
|
652 |
+
token_type_ids=text["token_type_ids"].to(
|
653 |
+
device=device, non_blocking=True
|
654 |
+
),
|
655 |
+
)["pooler_output"]
|
656 |
+
x = self.text_projection(x)
|
657 |
+
elif self.text_branch_type == "roberta":
|
658 |
+
x = self.text_branch(
|
659 |
+
input_ids=text["input_ids"].to(device=device, non_blocking=True),
|
660 |
+
attention_mask=text["attention_mask"].to(
|
661 |
+
device=device, non_blocking=True
|
662 |
+
),
|
663 |
+
)["pooler_output"]
|
664 |
+
x = self.text_projection(x)
|
665 |
+
elif self.text_branch_type == "bart":
|
666 |
+
x = torch.mean(
|
667 |
+
self.text_branch(
|
668 |
+
input_ids=text["input_ids"].to(device=device, non_blocking=True),
|
669 |
+
attention_mask=text["attention_mask"].to(
|
670 |
+
device=device, non_blocking=True
|
671 |
+
),
|
672 |
+
)["encoder_last_hidden_state"],
|
673 |
+
axis=1,
|
674 |
+
)
|
675 |
+
x = self.text_projection(x)
|
676 |
+
else:
|
677 |
+
logging.error(f"Model type {self.text_branch_type} not found")
|
678 |
+
raise RuntimeError(f"Model type {self.text_branch_type} not found.")
|
679 |
+
return x
|
680 |
+
|
681 |
+
def forward(self, audio, text, device=None):
|
682 |
+
"""Forward audio and text into the CLAP
|
683 |
+
|
684 |
+
Parameters
|
685 |
+
----------
|
686 |
+
audio: torch.Tensor (batch_size, audio_length)
|
687 |
+
the time-domain audio input / the batch of mel_spec and longer list.
|
688 |
+
text: torch.Tensor () // need to add
|
689 |
+
the text token input
|
690 |
+
"""
|
691 |
+
if device is None:
|
692 |
+
if audio is not None:
|
693 |
+
device = audio.device
|
694 |
+
elif text is not None:
|
695 |
+
device = text.device
|
696 |
+
if audio is None and text is None:
|
697 |
+
# a hack to get the logit scale
|
698 |
+
return self.logit_scale_a.exp(), self.logit_scale_t.exp()
|
699 |
+
elif audio is None:
|
700 |
+
return self.encode_text(text, device=device)
|
701 |
+
elif text is None:
|
702 |
+
return self.audio_projection(
|
703 |
+
self.encode_audio(audio, device=device)["embedding"]
|
704 |
+
)
|
705 |
+
audio_features = self.audio_projection(
|
706 |
+
self.encode_audio(audio, device=device)["embedding"]
|
707 |
+
)
|
708 |
+
audio_features = F.normalize(audio_features, dim=-1)
|
709 |
+
|
710 |
+
text_features = self.encode_text(text, device=device)
|
711 |
+
# print("text_features", text_features)
|
712 |
+
# print("text_features.shape", text_features.shape)
|
713 |
+
# print("text_features.type", type(text_features))
|
714 |
+
text_features = F.normalize(text_features, dim=-1)
|
715 |
+
|
716 |
+
audio_features_mlp = self.audio_transform(audio_features)
|
717 |
+
text_features_mlp = self.text_transform(text_features)
|
718 |
+
# Four outputs: audio features (basic & MLP), text features (basic & MLP)
|
719 |
+
return (
|
720 |
+
audio_features,
|
721 |
+
text_features,
|
722 |
+
audio_features_mlp,
|
723 |
+
text_features_mlp,
|
724 |
+
self.logit_scale_a.exp(),
|
725 |
+
self.logit_scale_t.exp(),
|
726 |
+
)
|
727 |
+
|
728 |
+
def get_logit_scale(self):
|
729 |
+
return self.logit_scale_a.exp(), self.logit_scale_t.exp()
|
730 |
+
|
731 |
+
def get_text_embedding(self, data):
|
732 |
+
"""Get the text embedding from the model
|
733 |
+
|
734 |
+
Parameters
|
735 |
+
----------
|
736 |
+
data: torch.Tensor
|
737 |
+
a tensor of text embedding
|
738 |
+
|
739 |
+
Returns
|
740 |
+
----------
|
741 |
+
text_embed: torch.Tensor
|
742 |
+
a tensor of text_embeds (N, D)
|
743 |
+
|
744 |
+
"""
|
745 |
+
device = next(self.parameters()).device
|
746 |
+
for k in data:
|
747 |
+
data[k] = data[k].to(device)
|
748 |
+
if(len(data[k].size()) < 2):
|
749 |
+
data[k] = data[k].unsqueeze(0)
|
750 |
+
text_embeds = self.encode_text(data, device=device)
|
751 |
+
text_embeds = F.normalize(text_embeds, dim=-1)
|
752 |
+
|
753 |
+
return text_embeds
|
754 |
+
|
755 |
+
def get_audio_embedding(self, data):
|
756 |
+
"""Get the audio embedding from the model
|
757 |
+
|
758 |
+
Parameters
|
759 |
+
----------
|
760 |
+
data: a list of dict
|
761 |
+
the audio input dict list from 'get_audio_feature' method
|
762 |
+
|
763 |
+
Returns
|
764 |
+
----------
|
765 |
+
audio_embed: torch.Tensor
|
766 |
+
a tensor of audio_embeds (N, D)
|
767 |
+
|
768 |
+
"""
|
769 |
+
device = next(self.parameters()).device
|
770 |
+
input_dict = {}
|
771 |
+
keys = data[0].keys()
|
772 |
+
for k in keys:
|
773 |
+
input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
|
774 |
+
device
|
775 |
+
)
|
776 |
+
|
777 |
+
audio_embeds = self.audio_projection(
|
778 |
+
self.encode_audio(input_dict, device=device)["embedding"]
|
779 |
+
)
|
780 |
+
audio_embeds = F.normalize(audio_embeds, dim=-1)
|
781 |
+
|
782 |
+
return audio_embeds
|
783 |
+
|
784 |
+
def audio_infer(self, audio, hopsize=None, device=None):
|
785 |
+
"""Forward one audio and produce the audio embedding
|
786 |
+
|
787 |
+
Parameters
|
788 |
+
----------
|
789 |
+
audio: (audio_length)
|
790 |
+
the time-domain audio input, notice that it must be only one input
|
791 |
+
hopsize: int
|
792 |
+
the overlap hopsize as the sliding window
|
793 |
+
|
794 |
+
Returns
|
795 |
+
----------
|
796 |
+
output_dict: {
|
797 |
+
key: [n, (embedding_shape)] if "HTS-AT"
|
798 |
+
or
|
799 |
+
key: [(embedding_shape)] if "PANN"
|
800 |
+
}
|
801 |
+
the list of key values of the audio branch
|
802 |
+
|
803 |
+
"""
|
804 |
+
|
805 |
+
assert not self.training, "the inference mode must be run at eval stage"
|
806 |
+
output_dict = {}
|
807 |
+
# PANN
|
808 |
+
if self.audio_cfg.model_type == "PANN":
|
809 |
+
audio_input = audio.unsqueeze(dim=0)
|
810 |
+
output_dict[key] = self.encode_audio(audio_input, device=device)[
|
811 |
+
key
|
812 |
+
].squeeze(dim=0)
|
813 |
+
elif self.audio_cfg.model_type == "HTSAT":
|
814 |
+
# repeat
|
815 |
+
audio_len = len(audio)
|
816 |
+
k = self.audio_cfg.clip_samples // audio_len
|
817 |
+
if k > 1:
|
818 |
+
audio = audio.repeat(k)
|
819 |
+
audio_len = len(audio)
|
820 |
+
|
821 |
+
if hopsize is None:
|
822 |
+
hopsize = min(hopsize, audio_len)
|
823 |
+
|
824 |
+
if audio_len > self.audio_cfg.clip_samples:
|
825 |
+
audio_input = [
|
826 |
+
audio[pos : pos + self.audio_cfg.clip_samples].clone()
|
827 |
+
for pos in range(
|
828 |
+
0, audio_len - self.audio_cfg.clip_samples, hopsize
|
829 |
+
)
|
830 |
+
]
|
831 |
+
audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
|
832 |
+
audio_input = torch.stack(audio_input)
|
833 |
+
output_dict[key] = self.encode_audio(audio_input, device=device)[key]
|
834 |
+
else:
|
835 |
+
audio_input = audio.unsqueeze(dim=0)
|
836 |
+
output_dict[key] = self.encode_audio(audio_input, device=device)[
|
837 |
+
key
|
838 |
+
].squeeze(dim=0)
|
839 |
+
|
840 |
+
return output_dict
|
841 |
+
|
842 |
+
|
843 |
+
def convert_weights_to_fp16(model: nn.Module):
|
844 |
+
"""Convert applicable model parameters to fp16"""
|
845 |
+
|
846 |
+
def _convert_weights_to_fp16(l):
|
847 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
848 |
+
l.weight.data = l.weight.data.half()
|
849 |
+
if l.bias is not None:
|
850 |
+
l.bias.data = l.bias.data.half()
|
851 |
+
|
852 |
+
if isinstance(l, nn.MultiheadAttention):
|
853 |
+
for attr in [
|
854 |
+
*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
|
855 |
+
"in_proj_bias",
|
856 |
+
"bias_k",
|
857 |
+
"bias_v",
|
858 |
+
]:
|
859 |
+
tensor = getattr(l, attr)
|
860 |
+
if tensor is not None:
|
861 |
+
tensor.data = tensor.data.half()
|
862 |
+
|
863 |
+
for name in ["text_projection", "proj"]:
|
864 |
+
if hasattr(l, name):
|
865 |
+
attr = getattr(l, name)
|
866 |
+
if attr is not None:
|
867 |
+
attr.data = attr.data.half()
|
868 |
+
|
869 |
+
model.apply(_convert_weights_to_fp16)
|
870 |
+
|
871 |
+
|
872 |
+
# Ignore the state dict of the vision part
|
873 |
+
def build_model_from_openai_state_dict(
|
874 |
+
state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
|
875 |
+
):
|
876 |
+
|
877 |
+
embed_dim = model_cfg["embed_dim"]
|
878 |
+
audio_cfg = model_cfg["audio_cfg"]
|
879 |
+
text_cfg = model_cfg["text_cfg"]
|
880 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
881 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
882 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
883 |
+
transformer_heads = transformer_width // 64
|
884 |
+
transformer_layers = len(
|
885 |
+
set(
|
886 |
+
k.split(".")[2]
|
887 |
+
for k in state_dict
|
888 |
+
if k.startswith(f"transformer.resblocks")
|
889 |
+
)
|
890 |
+
)
|
891 |
+
|
892 |
+
audio_cfg = CLAPAudioCfp(**audio_cfg)
|
893 |
+
text_cfg = CLAPTextCfg(**text_cfg)
|
894 |
+
|
895 |
+
model = CLAP(
|
896 |
+
embed_dim,
|
897 |
+
audio_cfg=audio_cfg,
|
898 |
+
text_cfg=text_cfg,
|
899 |
+
quick_gelu=True, # OpenAI models were trained with QuickGELU
|
900 |
+
enable_fusion=enable_fusion,
|
901 |
+
fusion_type=fusion_type,
|
902 |
+
)
|
903 |
+
state_dict["logit_scale_a"] = state_dict["logit_scale"]
|
904 |
+
state_dict["logit_scale_t"] = state_dict["logit_scale"]
|
905 |
+
pop_keys = list(state_dict.keys())[::]
|
906 |
+
# pop the visual branch saved weights
|
907 |
+
for key in pop_keys:
|
908 |
+
if key.startswith("visual."):
|
909 |
+
state_dict.pop(key, None)
|
910 |
+
|
911 |
+
for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
|
912 |
+
state_dict.pop(key, None)
|
913 |
+
|
914 |
+
# not use fp16
|
915 |
+
# convert_weights_to_fp16(model)
|
916 |
+
model.load_state_dict(state_dict, strict=False)
|
917 |
+
return model.eval()
|
918 |
+
|
919 |
+
|
920 |
+
def trace_model(model, batch_size=256, device=torch.device("cpu")):
|
921 |
+
model.eval()
|
922 |
+
audio_length = model.audio_cfg.audio_length
|
923 |
+
example_audio = torch.ones((batch_size, audio_length), device=device)
|
924 |
+
example_text = torch.zeros(
|
925 |
+
(batch_size, model.context_length), dtype=torch.int, device=device
|
926 |
+
)
|
927 |
+
model = torch.jit.trace_module(
|
928 |
+
model,
|
929 |
+
inputs=dict(
|
930 |
+
forward=(example_audio, example_text),
|
931 |
+
encode_text=(example_text,),
|
932 |
+
encode_image=(example_audio,),
|
933 |
+
),
|
934 |
+
)
|
935 |
+
model.audio_cfg.audio_length = audio_length # Question: what does this do?
|
936 |
+
return model
|
audioldm/clap/open_clip/model_configs/HTSAT-base.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "base"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/HTSAT-large.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "large"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1536,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "tiny"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/HTSAT-tiny.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "HTSAT",
|
14 |
+
"model_name": "tiny"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/PANN-10.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn10"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 18000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 960000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 360,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 8000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 4
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1536,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/PANN-14.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 2048,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn14"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/PANN-6.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"audio_cfg": {
|
4 |
+
"audio_length": 1024,
|
5 |
+
"clip_samples": 480000,
|
6 |
+
"mel_bins": 64,
|
7 |
+
"sample_rate": 48000,
|
8 |
+
"window_size": 1024,
|
9 |
+
"hop_size": 480,
|
10 |
+
"fmin": 50,
|
11 |
+
"fmax": 14000,
|
12 |
+
"class_num": 527,
|
13 |
+
"model_type": "PANN",
|
14 |
+
"model_name": "Cnn6"
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 512,
|
20 |
+
"heads": 8,
|
21 |
+
"layers": 12
|
22 |
+
}
|
23 |
+
}
|
audioldm/clap/open_clip/model_configs/RN101-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
23,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
audioldm/clap/open_clip/model_configs/RN101.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
23,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audioldm/clap/open_clip/model_configs/RN50-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
6,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
audioldm/clap/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
6,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audioldm/clap/open_clip/model_configs/RN50x16.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 384,
|
5 |
+
"layers": [
|
6 |
+
6,
|
7 |
+
8,
|
8 |
+
18,
|
9 |
+
8
|
10 |
+
],
|
11 |
+
"width": 96,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 768,
|
18 |
+
"heads": 12,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audioldm/clap/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 288,
|
5 |
+
"layers": [
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
10,
|
9 |
+
6
|
10 |
+
],
|
11 |
+
"width": 80,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 640,
|
18 |
+
"heads": 10,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
audioldm/clap/open_clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 16
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 512,
|
13 |
+
"heads": 8,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": 12,
|
7 |
+
"width": 768,
|
8 |
+
"patch_size": 32
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 512,
|
14 |
+
"heads": 8,
|
15 |
+
"layers": 12
|
16 |
+
}
|
17 |
+
}
|
audioldm/clap/open_clip/model_configs/ViT-B-32.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 32
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 512,
|
13 |
+
"heads": 8,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
audioldm/clap/open_clip/model_configs/ViT-L-14.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"patch_size": 14
|
8 |
+
},
|
9 |
+
"text_cfg": {
|
10 |
+
"context_length": 77,
|
11 |
+
"vocab_size": 49408,
|
12 |
+
"width": 768,
|
13 |
+
"heads": 12,
|
14 |
+
"layers": 12
|
15 |
+
}
|
16 |
+
}
|
audioldm/clap/open_clip/openai.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" OpenAI pretrained model functions
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import warnings
|
8 |
+
from typing import Union, List
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .model import build_model_from_openai_state_dict
|
13 |
+
from .pretrained import (
|
14 |
+
get_pretrained_url,
|
15 |
+
list_pretrained_tag_models,
|
16 |
+
download_pretrained,
|
17 |
+
)
|
18 |
+
|
19 |
+
__all__ = ["list_openai_models", "load_openai_model"]
|
20 |
+
|
21 |
+
|
22 |
+
def list_openai_models() -> List[str]:
|
23 |
+
"""Returns the names of available CLIP models"""
|
24 |
+
return list_pretrained_tag_models("openai")
|
25 |
+
|
26 |
+
|
27 |
+
def load_openai_model(
|
28 |
+
name: str,
|
29 |
+
model_cfg,
|
30 |
+
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
31 |
+
jit=True,
|
32 |
+
cache_dir=os.path.expanduser("~/.cache/clip"),
|
33 |
+
enable_fusion: bool = False,
|
34 |
+
fusion_type: str = "None",
|
35 |
+
):
|
36 |
+
"""Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
|
37 |
+
|
38 |
+
Parameters
|
39 |
+
----------
|
40 |
+
name : str
|
41 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
42 |
+
device : Union[str, torch.device]
|
43 |
+
The device to put the loaded model
|
44 |
+
jit : bool
|
45 |
+
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
46 |
+
|
47 |
+
Returns
|
48 |
+
-------
|
49 |
+
model : torch.nn.Module
|
50 |
+
The CLAP model
|
51 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
52 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
53 |
+
"""
|
54 |
+
if get_pretrained_url(name, "openai"):
|
55 |
+
model_path = download_pretrained(
|
56 |
+
get_pretrained_url(name, "openai"), root=cache_dir
|
57 |
+
)
|
58 |
+
elif os.path.isfile(name):
|
59 |
+
model_path = name
|
60 |
+
else:
|
61 |
+
raise RuntimeError(
|
62 |
+
f"Model {name} not found; available models = {list_openai_models()}"
|
63 |
+
)
|
64 |
+
|
65 |
+
try:
|
66 |
+
# loading JIT archive
|
67 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
68 |
+
state_dict = None
|
69 |
+
except RuntimeError:
|
70 |
+
# loading saved state dict
|
71 |
+
if jit:
|
72 |
+
warnings.warn(
|
73 |
+
f"File {model_path} is not a JIT archive. Loading as a state dict instead"
|
74 |
+
)
|
75 |
+
jit = False
|
76 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
77 |
+
|
78 |
+
if not jit:
|
79 |
+
try:
|
80 |
+
model = build_model_from_openai_state_dict(
|
81 |
+
state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
|
82 |
+
).to(device)
|
83 |
+
except KeyError:
|
84 |
+
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
|
85 |
+
model = build_model_from_openai_state_dict(
|
86 |
+
sd, model_cfg, enable_fusion, fusion_type
|
87 |
+
).to(device)
|
88 |
+
|
89 |
+
if str(device) == "cpu":
|
90 |
+
model.float()
|
91 |
+
return model
|
92 |
+
|
93 |
+
# patch the device names
|
94 |
+
device_holder = torch.jit.trace(
|
95 |
+
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
|
96 |
+
)
|
97 |
+
device_node = [
|
98 |
+
n
|
99 |
+
for n in device_holder.graph.findAllNodes("prim::Constant")
|
100 |
+
if "Device" in repr(n)
|
101 |
+
][-1]
|
102 |
+
|
103 |
+
def patch_device(module):
|
104 |
+
try:
|
105 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
106 |
+
except RuntimeError:
|
107 |
+
graphs = []
|
108 |
+
|
109 |
+
if hasattr(module, "forward1"):
|
110 |
+
graphs.append(module.forward1.graph)
|
111 |
+
|
112 |
+
for graph in graphs:
|
113 |
+
for node in graph.findAllNodes("prim::Constant"):
|
114 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith(
|
115 |
+
"cuda"
|
116 |
+
):
|
117 |
+
node.copyAttributes(device_node)
|
118 |
+
|
119 |
+
model.apply(patch_device)
|
120 |
+
patch_device(model.encode_audio)
|
121 |
+
patch_device(model.encode_text)
|
122 |
+
|
123 |
+
# patch dtype to float32 on CPU
|
124 |
+
if str(device) == "cpu":
|
125 |
+
float_holder = torch.jit.trace(
|
126 |
+
lambda: torch.ones([]).float(), example_inputs=[]
|
127 |
+
)
|
128 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
129 |
+
float_node = float_input.node()
|
130 |
+
|
131 |
+
def patch_float(module):
|
132 |
+
try:
|
133 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
134 |
+
except RuntimeError:
|
135 |
+
graphs = []
|
136 |
+
|
137 |
+
if hasattr(module, "forward1"):
|
138 |
+
graphs.append(module.forward1.graph)
|
139 |
+
|
140 |
+
for graph in graphs:
|
141 |
+
for node in graph.findAllNodes("aten::to"):
|
142 |
+
inputs = list(node.inputs())
|
143 |
+
for i in [
|
144 |
+
1,
|
145 |
+
2,
|
146 |
+
]: # dtype can be the second or third argument to aten::to()
|
147 |
+
if inputs[i].node()["value"] == 5:
|
148 |
+
inputs[i].node().copyAttributes(float_node)
|
149 |
+
|
150 |
+
model.apply(patch_float)
|
151 |
+
patch_float(model.encode_audio)
|
152 |
+
patch_float(model.encode_text)
|
153 |
+
model.float()
|
154 |
+
|
155 |
+
model.audio_branch.audio_length = model.audio_cfg.audio_length
|
156 |
+
return model
|
audioldm/clap/open_clip/pann_model.py
ADDED
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
|
2 |
+
# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
|
3 |
+
# Some layers are re-designed for CLAP
|
4 |
+
import os
|
5 |
+
|
6 |
+
os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
|
12 |
+
from torchlibrosa.augmentation import SpecAugmentation
|
13 |
+
|
14 |
+
from .utils import do_mixup, interpolate, pad_framewise_output
|
15 |
+
from .feature_fusion import iAFF, AFF, DAF
|
16 |
+
|
17 |
+
|
18 |
+
def init_layer(layer):
|
19 |
+
"""Initialize a Linear or Convolutional layer."""
|
20 |
+
nn.init.xavier_uniform_(layer.weight)
|
21 |
+
|
22 |
+
if hasattr(layer, "bias"):
|
23 |
+
if layer.bias is not None:
|
24 |
+
layer.bias.data.fill_(0.0)
|
25 |
+
|
26 |
+
def init_bn(bn):
|
27 |
+
"""Initialize a Batchnorm layer."""
|
28 |
+
bn.bias.data.fill_(0.0)
|
29 |
+
bn.weight.data.fill_(1.0)
|
30 |
+
|
31 |
+
|
32 |
+
class ConvBlock(nn.Module):
|
33 |
+
def __init__(self, in_channels, out_channels):
|
34 |
+
|
35 |
+
super(ConvBlock, self).__init__()
|
36 |
+
|
37 |
+
self.conv1 = nn.Conv2d(
|
38 |
+
in_channels=in_channels,
|
39 |
+
out_channels=out_channels,
|
40 |
+
kernel_size=(3, 3),
|
41 |
+
stride=(1, 1),
|
42 |
+
padding=(1, 1),
|
43 |
+
bias=False,
|
44 |
+
)
|
45 |
+
|
46 |
+
self.conv2 = nn.Conv2d(
|
47 |
+
in_channels=out_channels,
|
48 |
+
out_channels=out_channels,
|
49 |
+
kernel_size=(3, 3),
|
50 |
+
stride=(1, 1),
|
51 |
+
padding=(1, 1),
|
52 |
+
bias=False,
|
53 |
+
)
|
54 |
+
|
55 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
56 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
57 |
+
|
58 |
+
self.init_weight()
|
59 |
+
|
60 |
+
def init_weight(self):
|
61 |
+
init_layer(self.conv1)
|
62 |
+
init_layer(self.conv2)
|
63 |
+
init_bn(self.bn1)
|
64 |
+
init_bn(self.bn2)
|
65 |
+
|
66 |
+
def forward(self, input, pool_size=(2, 2), pool_type="avg"):
|
67 |
+
|
68 |
+
x = input
|
69 |
+
x = F.relu_(self.bn1(self.conv1(x)))
|
70 |
+
x = F.relu_(self.bn2(self.conv2(x)))
|
71 |
+
if pool_type == "max":
|
72 |
+
x = F.max_pool2d(x, kernel_size=pool_size)
|
73 |
+
elif pool_type == "avg":
|
74 |
+
x = F.avg_pool2d(x, kernel_size=pool_size)
|
75 |
+
elif pool_type == "avg+max":
|
76 |
+
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
77 |
+
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
78 |
+
x = x1 + x2
|
79 |
+
else:
|
80 |
+
raise Exception("Incorrect argument!")
|
81 |
+
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
class ConvBlock5x5(nn.Module):
|
86 |
+
def __init__(self, in_channels, out_channels):
|
87 |
+
|
88 |
+
super(ConvBlock5x5, self).__init__()
|
89 |
+
|
90 |
+
self.conv1 = nn.Conv2d(
|
91 |
+
in_channels=in_channels,
|
92 |
+
out_channels=out_channels,
|
93 |
+
kernel_size=(5, 5),
|
94 |
+
stride=(1, 1),
|
95 |
+
padding=(2, 2),
|
96 |
+
bias=False,
|
97 |
+
)
|
98 |
+
|
99 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
100 |
+
|
101 |
+
self.init_weight()
|
102 |
+
|
103 |
+
def init_weight(self):
|
104 |
+
init_layer(self.conv1)
|
105 |
+
init_bn(self.bn1)
|
106 |
+
|
107 |
+
def forward(self, input, pool_size=(2, 2), pool_type="avg"):
|
108 |
+
|
109 |
+
x = input
|
110 |
+
x = F.relu_(self.bn1(self.conv1(x)))
|
111 |
+
if pool_type == "max":
|
112 |
+
x = F.max_pool2d(x, kernel_size=pool_size)
|
113 |
+
elif pool_type == "avg":
|
114 |
+
x = F.avg_pool2d(x, kernel_size=pool_size)
|
115 |
+
elif pool_type == "avg+max":
|
116 |
+
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
117 |
+
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
118 |
+
x = x1 + x2
|
119 |
+
else:
|
120 |
+
raise Exception("Incorrect argument!")
|
121 |
+
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class AttBlock(nn.Module):
|
126 |
+
def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
|
127 |
+
super(AttBlock, self).__init__()
|
128 |
+
|
129 |
+
self.activation = activation
|
130 |
+
self.temperature = temperature
|
131 |
+
self.att = nn.Conv1d(
|
132 |
+
in_channels=n_in,
|
133 |
+
out_channels=n_out,
|
134 |
+
kernel_size=1,
|
135 |
+
stride=1,
|
136 |
+
padding=0,
|
137 |
+
bias=True,
|
138 |
+
)
|
139 |
+
self.cla = nn.Conv1d(
|
140 |
+
in_channels=n_in,
|
141 |
+
out_channels=n_out,
|
142 |
+
kernel_size=1,
|
143 |
+
stride=1,
|
144 |
+
padding=0,
|
145 |
+
bias=True,
|
146 |
+
)
|
147 |
+
|
148 |
+
self.bn_att = nn.BatchNorm1d(n_out)
|
149 |
+
self.init_weights()
|
150 |
+
|
151 |
+
def init_weights(self):
|
152 |
+
init_layer(self.att)
|
153 |
+
init_layer(self.cla)
|
154 |
+
init_bn(self.bn_att)
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
# x: (n_samples, n_in, n_time)
|
158 |
+
norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
|
159 |
+
cla = self.nonlinear_transform(self.cla(x))
|
160 |
+
x = torch.sum(norm_att * cla, dim=2)
|
161 |
+
return x, norm_att, cla
|
162 |
+
|
163 |
+
def nonlinear_transform(self, x):
|
164 |
+
if self.activation == "linear":
|
165 |
+
return x
|
166 |
+
elif self.activation == "sigmoid":
|
167 |
+
return torch.sigmoid(x)
|
168 |
+
|
169 |
+
|
170 |
+
class Cnn14(nn.Module):
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
sample_rate,
|
174 |
+
window_size,
|
175 |
+
hop_size,
|
176 |
+
mel_bins,
|
177 |
+
fmin,
|
178 |
+
fmax,
|
179 |
+
classes_num,
|
180 |
+
enable_fusion=False,
|
181 |
+
fusion_type="None",
|
182 |
+
):
|
183 |
+
|
184 |
+
super(Cnn14, self).__init__()
|
185 |
+
|
186 |
+
window = "hann"
|
187 |
+
center = True
|
188 |
+
pad_mode = "reflect"
|
189 |
+
ref = 1.0
|
190 |
+
amin = 1e-10
|
191 |
+
top_db = None
|
192 |
+
|
193 |
+
self.enable_fusion = enable_fusion
|
194 |
+
self.fusion_type = fusion_type
|
195 |
+
|
196 |
+
# Spectrogram extractor
|
197 |
+
self.spectrogram_extractor = Spectrogram(
|
198 |
+
n_fft=window_size,
|
199 |
+
hop_length=hop_size,
|
200 |
+
win_length=window_size,
|
201 |
+
window=window,
|
202 |
+
center=center,
|
203 |
+
pad_mode=pad_mode,
|
204 |
+
freeze_parameters=True,
|
205 |
+
)
|
206 |
+
|
207 |
+
# Logmel feature extractor
|
208 |
+
self.logmel_extractor = LogmelFilterBank(
|
209 |
+
sr=sample_rate,
|
210 |
+
n_fft=window_size,
|
211 |
+
n_mels=mel_bins,
|
212 |
+
fmin=fmin,
|
213 |
+
fmax=fmax,
|
214 |
+
ref=ref,
|
215 |
+
amin=amin,
|
216 |
+
top_db=top_db,
|
217 |
+
freeze_parameters=True,
|
218 |
+
)
|
219 |
+
|
220 |
+
# Spec augmenter
|
221 |
+
self.spec_augmenter = SpecAugmentation(
|
222 |
+
time_drop_width=64,
|
223 |
+
time_stripes_num=2,
|
224 |
+
freq_drop_width=8,
|
225 |
+
freq_stripes_num=2,
|
226 |
+
)
|
227 |
+
|
228 |
+
self.bn0 = nn.BatchNorm2d(64)
|
229 |
+
|
230 |
+
if (self.enable_fusion) and (self.fusion_type == "channel_map"):
|
231 |
+
self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
|
232 |
+
else:
|
233 |
+
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
|
234 |
+
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
|
235 |
+
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
|
236 |
+
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
|
237 |
+
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
|
238 |
+
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
|
239 |
+
|
240 |
+
self.fc1 = nn.Linear(2048, 2048, bias=True)
|
241 |
+
self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
|
242 |
+
|
243 |
+
if (self.enable_fusion) and (
|
244 |
+
self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
|
245 |
+
):
|
246 |
+
self.mel_conv1d = nn.Sequential(
|
247 |
+
nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
|
248 |
+
nn.BatchNorm1d(64), # No Relu
|
249 |
+
)
|
250 |
+
if self.fusion_type == "daf_1d":
|
251 |
+
self.fusion_model = DAF()
|
252 |
+
elif self.fusion_type == "aff_1d":
|
253 |
+
self.fusion_model = AFF(channels=64, type="1D")
|
254 |
+
elif self.fusion_type == "iaff_1d":
|
255 |
+
self.fusion_model = iAFF(channels=64, type="1D")
|
256 |
+
|
257 |
+
if (self.enable_fusion) and (
|
258 |
+
self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
|
259 |
+
):
|
260 |
+
self.mel_conv2d = nn.Sequential(
|
261 |
+
nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
|
262 |
+
nn.BatchNorm2d(64),
|
263 |
+
nn.ReLU(inplace=True),
|
264 |
+
)
|
265 |
+
|
266 |
+
if self.fusion_type == "daf_2d":
|
267 |
+
self.fusion_model = DAF()
|
268 |
+
elif self.fusion_type == "aff_2d":
|
269 |
+
self.fusion_model = AFF(channels=64, type="2D")
|
270 |
+
elif self.fusion_type == "iaff_2d":
|
271 |
+
self.fusion_model = iAFF(channels=64, type="2D")
|
272 |
+
self.init_weight()
|
273 |
+
|
274 |
+
def init_weight(self):
|
275 |
+
init_bn(self.bn0)
|
276 |
+
init_layer(self.fc1)
|
277 |
+
init_layer(self.fc_audioset)
|
278 |
+
|
279 |
+
def forward(self, input, mixup_lambda=None, device=None):
|
280 |
+
"""
|
281 |
+
Input: (batch_size, data_length)"""
|
282 |
+
|
283 |
+
if self.enable_fusion and input["longer"].sum() == 0:
|
284 |
+
# if no audio is longer than 10s, then randomly select one audio to be longer
|
285 |
+
input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
|
286 |
+
|
287 |
+
if not self.enable_fusion:
|
288 |
+
x = self.spectrogram_extractor(
|
289 |
+
input["waveform"].to(device=device, non_blocking=True)
|
290 |
+
) # (batch_size, 1, time_steps, freq_bins)
|
291 |
+
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
|
292 |
+
|
293 |
+
x = x.transpose(1, 3)
|
294 |
+
x = self.bn0(x)
|
295 |
+
x = x.transpose(1, 3)
|
296 |
+
else:
|
297 |
+
longer_list = input["longer"].to(device=device, non_blocking=True)
|
298 |
+
x = input["mel_fusion"].to(device=device, non_blocking=True)
|
299 |
+
longer_list_idx = torch.where(longer_list)[0]
|
300 |
+
x = x.transpose(1, 3)
|
301 |
+
x = self.bn0(x)
|
302 |
+
x = x.transpose(1, 3)
|
303 |
+
if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
|
304 |
+
new_x = x[:, 0:1, :, :].clone().contiguous()
|
305 |
+
# local processing
|
306 |
+
if len(longer_list_idx) > 0:
|
307 |
+
fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
|
308 |
+
FB, FC, FT, FF = fusion_x_local.size()
|
309 |
+
fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
|
310 |
+
fusion_x_local = torch.permute(
|
311 |
+
fusion_x_local, (0, 2, 1)
|
312 |
+
).contiguous()
|
313 |
+
fusion_x_local = self.mel_conv1d(fusion_x_local)
|
314 |
+
fusion_x_local = fusion_x_local.view(
|
315 |
+
FB, FC, FF, fusion_x_local.size(-1)
|
316 |
+
)
|
317 |
+
fusion_x_local = (
|
318 |
+
torch.permute(fusion_x_local, (0, 2, 1, 3))
|
319 |
+
.contiguous()
|
320 |
+
.flatten(2)
|
321 |
+
)
|
322 |
+
if fusion_x_local.size(-1) < FT:
|
323 |
+
fusion_x_local = torch.cat(
|
324 |
+
[
|
325 |
+
fusion_x_local,
|
326 |
+
torch.zeros(
|
327 |
+
(FB, FF, FT - fusion_x_local.size(-1)),
|
328 |
+
device=device,
|
329 |
+
),
|
330 |
+
],
|
331 |
+
dim=-1,
|
332 |
+
)
|
333 |
+
else:
|
334 |
+
fusion_x_local = fusion_x_local[:, :, :FT]
|
335 |
+
# 1D fusion
|
336 |
+
new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
|
337 |
+
new_x[longer_list_idx] = self.fusion_model(
|
338 |
+
new_x[longer_list_idx], fusion_x_local
|
339 |
+
)
|
340 |
+
x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
|
341 |
+
else:
|
342 |
+
x = new_x
|
343 |
+
elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
|
344 |
+
x = x # no change
|
345 |
+
|
346 |
+
if self.training:
|
347 |
+
x = self.spec_augmenter(x)
|
348 |
+
# Mixup on spectrogram
|
349 |
+
if self.training and mixup_lambda is not None:
|
350 |
+
x = do_mixup(x, mixup_lambda)
|
351 |
+
if (self.enable_fusion) and (
|
352 |
+
self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
|
353 |
+
):
|
354 |
+
global_x = x[:, 0:1, :, :]
|
355 |
+
|
356 |
+
# global processing
|
357 |
+
B, C, H, W = global_x.shape
|
358 |
+
global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
|
359 |
+
if len(longer_list_idx) > 0:
|
360 |
+
local_x = x[longer_list_idx, 1:, :, :].contiguous()
|
361 |
+
TH = global_x.size(-2)
|
362 |
+
# local processing
|
363 |
+
B, C, H, W = local_x.shape
|
364 |
+
local_x = local_x.view(B * C, 1, H, W)
|
365 |
+
local_x = self.mel_conv2d(local_x)
|
366 |
+
local_x = local_x.view(
|
367 |
+
B, C, local_x.size(1), local_x.size(2), local_x.size(3)
|
368 |
+
)
|
369 |
+
local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
|
370 |
+
TB, TC, _, TW = local_x.size()
|
371 |
+
if local_x.size(-2) < TH:
|
372 |
+
local_x = torch.cat(
|
373 |
+
[
|
374 |
+
local_x,
|
375 |
+
torch.zeros(
|
376 |
+
(TB, TC, TH - local_x.size(-2), TW),
|
377 |
+
device=global_x.device,
|
378 |
+
),
|
379 |
+
],
|
380 |
+
dim=-2,
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
local_x = local_x[:, :, :TH, :]
|
384 |
+
|
385 |
+
global_x[longer_list_idx] = self.fusion_model(
|
386 |
+
global_x[longer_list_idx], local_x
|
387 |
+
)
|
388 |
+
x = global_x
|
389 |
+
else:
|
390 |
+
x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
|
391 |
+
|
392 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
393 |
+
x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
|
394 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
395 |
+
x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
|
396 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
397 |
+
x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
|
398 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
399 |
+
x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
|
400 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
401 |
+
x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
|
402 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
403 |
+
x = torch.mean(x, dim=3)
|
404 |
+
|
405 |
+
latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
|
406 |
+
latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
|
407 |
+
latent_x = latent_x1 + latent_x2
|
408 |
+
latent_x = latent_x.transpose(1, 2)
|
409 |
+
latent_x = F.relu_(self.fc1(latent_x))
|
410 |
+
latent_output = interpolate(latent_x, 32)
|
411 |
+
|
412 |
+
(x1, _) = torch.max(x, dim=2)
|
413 |
+
x2 = torch.mean(x, dim=2)
|
414 |
+
x = x1 + x2
|
415 |
+
x = F.dropout(x, p=0.5, training=self.training)
|
416 |
+
x = F.relu_(self.fc1(x))
|
417 |
+
embedding = F.dropout(x, p=0.5, training=self.training)
|
418 |
+
clipwise_output = torch.sigmoid(self.fc_audioset(x))
|
419 |
+
|
420 |
+
output_dict = {
|
421 |
+
"clipwise_output": clipwise_output,
|
422 |
+
"embedding": embedding,
|
423 |
+
"fine_grained_embedding": latent_output,
|
424 |
+
}
|
425 |
+
return output_dict
|
426 |
+
|
427 |
+
|
428 |
+
class Cnn6(nn.Module):
|
429 |
+
def __init__(
|
430 |
+
self,
|
431 |
+
sample_rate,
|
432 |
+
window_size,
|
433 |
+
hop_size,
|
434 |
+
mel_bins,
|
435 |
+
fmin,
|
436 |
+
fmax,
|
437 |
+
classes_num,
|
438 |
+
enable_fusion=False,
|
439 |
+
fusion_type="None",
|
440 |
+
):
|
441 |
+
|
442 |
+
super(Cnn6, self).__init__()
|
443 |
+
|
444 |
+
window = "hann"
|
445 |
+
center = True
|
446 |
+
pad_mode = "reflect"
|
447 |
+
ref = 1.0
|
448 |
+
amin = 1e-10
|
449 |
+
top_db = None
|
450 |
+
|
451 |
+
self.enable_fusion = enable_fusion
|
452 |
+
self.fusion_type = fusion_type
|
453 |
+
|
454 |
+
# Spectrogram extractor
|
455 |
+
self.spectrogram_extractor = Spectrogram(
|
456 |
+
n_fft=window_size,
|
457 |
+
hop_length=hop_size,
|
458 |
+
win_length=window_size,
|
459 |
+
window=window,
|
460 |
+
center=center,
|
461 |
+
pad_mode=pad_mode,
|
462 |
+
freeze_parameters=True,
|
463 |
+
)
|
464 |
+
|
465 |
+
# Logmel feature extractor
|
466 |
+
self.logmel_extractor = LogmelFilterBank(
|
467 |
+
sr=sample_rate,
|
468 |
+
n_fft=window_size,
|
469 |
+
n_mels=mel_bins,
|
470 |
+
fmin=fmin,
|
471 |
+
fmax=fmax,
|
472 |
+
ref=ref,
|
473 |
+
amin=amin,
|
474 |
+
top_db=top_db,
|
475 |
+
freeze_parameters=True,
|
476 |
+
)
|
477 |
+
|
478 |
+
# Spec augmenter
|
479 |
+
self.spec_augmenter = SpecAugmentation(
|
480 |
+
time_drop_width=64,
|
481 |
+
time_stripes_num=2,
|
482 |
+
freq_drop_width=8,
|
483 |
+
freq_stripes_num=2,
|
484 |
+
)
|
485 |
+
|
486 |
+
self.bn0 = nn.BatchNorm2d(64)
|
487 |
+
|
488 |
+
self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
|
489 |
+
self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
|
490 |
+
self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
|
491 |
+
self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
|
492 |
+
|
493 |
+
self.fc1 = nn.Linear(512, 512, bias=True)
|
494 |
+
self.fc_audioset = nn.Linear(512, classes_num, bias=True)
|
495 |
+
|
496 |
+
self.init_weight()
|
497 |
+
|
498 |
+
def init_weight(self):
|
499 |
+
init_bn(self.bn0)
|
500 |
+
init_layer(self.fc1)
|
501 |
+
init_layer(self.fc_audioset)
|
502 |
+
|
503 |
+
def forward(self, input, mixup_lambda=None, device=None):
|
504 |
+
"""
|
505 |
+
Input: (batch_size, data_length)"""
|
506 |
+
|
507 |
+
x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
|
508 |
+
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
|
509 |
+
|
510 |
+
x = x.transpose(1, 3)
|
511 |
+
x = self.bn0(x)
|
512 |
+
x = x.transpose(1, 3)
|
513 |
+
|
514 |
+
if self.training:
|
515 |
+
x = self.spec_augmenter(x)
|
516 |
+
|
517 |
+
# Mixup on spectrogram
|
518 |
+
if self.training and mixup_lambda is not None:
|
519 |
+
x = do_mixup(x, mixup_lambda)
|
520 |
+
|
521 |
+
x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
|
522 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
523 |
+
x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
|
524 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
525 |
+
x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
|
526 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
527 |
+
x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
|
528 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
529 |
+
x = torch.mean(x, dim=3)
|
530 |
+
|
531 |
+
latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
|
532 |
+
latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
|
533 |
+
latent_x = latent_x1 + latent_x2
|
534 |
+
latent_x = latent_x.transpose(1, 2)
|
535 |
+
latent_x = F.relu_(self.fc1(latent_x))
|
536 |
+
latent_output = interpolate(latent_x, 16)
|
537 |
+
|
538 |
+
(x1, _) = torch.max(x, dim=2)
|
539 |
+
x2 = torch.mean(x, dim=2)
|
540 |
+
x = x1 + x2
|
541 |
+
x = F.dropout(x, p=0.5, training=self.training)
|
542 |
+
x = F.relu_(self.fc1(x))
|
543 |
+
embedding = F.dropout(x, p=0.5, training=self.training)
|
544 |
+
clipwise_output = torch.sigmoid(self.fc_audioset(x))
|
545 |
+
|
546 |
+
output_dict = {
|
547 |
+
"clipwise_output": clipwise_output,
|
548 |
+
"embedding": embedding,
|
549 |
+
"fine_grained_embedding": latent_output,
|
550 |
+
}
|
551 |
+
|
552 |
+
return output_dict
|
553 |
+
|
554 |
+
|
555 |
+
class Cnn10(nn.Module):
|
556 |
+
def __init__(
|
557 |
+
self,
|
558 |
+
sample_rate,
|
559 |
+
window_size,
|
560 |
+
hop_size,
|
561 |
+
mel_bins,
|
562 |
+
fmin,
|
563 |
+
fmax,
|
564 |
+
classes_num,
|
565 |
+
enable_fusion=False,
|
566 |
+
fusion_type="None",
|
567 |
+
):
|
568 |
+
|
569 |
+
super(Cnn10, self).__init__()
|
570 |
+
|
571 |
+
window = "hann"
|
572 |
+
center = True
|
573 |
+
pad_mode = "reflect"
|
574 |
+
ref = 1.0
|
575 |
+
amin = 1e-10
|
576 |
+
top_db = None
|
577 |
+
|
578 |
+
self.enable_fusion = enable_fusion
|
579 |
+
self.fusion_type = fusion_type
|
580 |
+
|
581 |
+
# Spectrogram extractor
|
582 |
+
self.spectrogram_extractor = Spectrogram(
|
583 |
+
n_fft=window_size,
|
584 |
+
hop_length=hop_size,
|
585 |
+
win_length=window_size,
|
586 |
+
window=window,
|
587 |
+
center=center,
|
588 |
+
pad_mode=pad_mode,
|
589 |
+
freeze_parameters=True,
|
590 |
+
)
|
591 |
+
|
592 |
+
# Logmel feature extractor
|
593 |
+
self.logmel_extractor = LogmelFilterBank(
|
594 |
+
sr=sample_rate,
|
595 |
+
n_fft=window_size,
|
596 |
+
n_mels=mel_bins,
|
597 |
+
fmin=fmin,
|
598 |
+
fmax=fmax,
|
599 |
+
ref=ref,
|
600 |
+
amin=amin,
|
601 |
+
top_db=top_db,
|
602 |
+
freeze_parameters=True,
|
603 |
+
)
|
604 |
+
|
605 |
+
# Spec augmenter
|
606 |
+
self.spec_augmenter = SpecAugmentation(
|
607 |
+
time_drop_width=64,
|
608 |
+
time_stripes_num=2,
|
609 |
+
freq_drop_width=8,
|
610 |
+
freq_stripes_num=2,
|
611 |
+
)
|
612 |
+
|
613 |
+
self.bn0 = nn.BatchNorm2d(64)
|
614 |
+
|
615 |
+
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
|
616 |
+
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
|
617 |
+
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
|
618 |
+
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
|
619 |
+
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
|
620 |
+
|
621 |
+
self.fc1 = nn.Linear(1024, 1024, bias=True)
|
622 |
+
self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
|
623 |
+
|
624 |
+
self.init_weight()
|
625 |
+
|
626 |
+
def init_weight(self):
|
627 |
+
init_bn(self.bn0)
|
628 |
+
init_layer(self.fc1)
|
629 |
+
init_layer(self.fc_audioset)
|
630 |
+
|
631 |
+
def forward(self, input, mixup_lambda=None, device=None):
|
632 |
+
"""
|
633 |
+
Input: (batch_size, data_length)"""
|
634 |
+
|
635 |
+
x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
|
636 |
+
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
|
637 |
+
|
638 |
+
x = x.transpose(1, 3)
|
639 |
+
x = self.bn0(x)
|
640 |
+
x = x.transpose(1, 3)
|
641 |
+
|
642 |
+
if self.training:
|
643 |
+
x = self.spec_augmenter(x)
|
644 |
+
|
645 |
+
# Mixup on spectrogram
|
646 |
+
if self.training and mixup_lambda is not None:
|
647 |
+
x = do_mixup(x, mixup_lambda)
|
648 |
+
|
649 |
+
x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
|
650 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
651 |
+
x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
|
652 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
653 |
+
x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
|
654 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
655 |
+
x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
|
656 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
657 |
+
x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
|
658 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
659 |
+
x = torch.mean(x, dim=3)
|
660 |
+
|
661 |
+
latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
|
662 |
+
latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
|
663 |
+
latent_x = latent_x1 + latent_x2
|
664 |
+
latent_x = latent_x.transpose(1, 2)
|
665 |
+
latent_x = F.relu_(self.fc1(latent_x))
|
666 |
+
latent_output = interpolate(latent_x, 32)
|
667 |
+
|
668 |
+
(x1, _) = torch.max(x, dim=2)
|
669 |
+
x2 = torch.mean(x, dim=2)
|
670 |
+
x = x1 + x2
|
671 |
+
x = F.dropout(x, p=0.5, training=self.training)
|
672 |
+
x = F.relu_(self.fc1(x))
|
673 |
+
embedding = F.dropout(x, p=0.5, training=self.training)
|
674 |
+
clipwise_output = torch.sigmoid(self.fc_audioset(x))
|
675 |
+
|
676 |
+
output_dict = {
|
677 |
+
"clipwise_output": clipwise_output,
|
678 |
+
"embedding": embedding,
|
679 |
+
"fine_grained_embedding": latent_output,
|
680 |
+
}
|
681 |
+
|
682 |
+
return output_dict
|
683 |
+
|
684 |
+
|
685 |
+
def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
|
686 |
+
try:
|
687 |
+
ModelProto = eval(audio_cfg.model_name)
|
688 |
+
model = ModelProto(
|
689 |
+
sample_rate=audio_cfg.sample_rate,
|
690 |
+
window_size=audio_cfg.window_size,
|
691 |
+
hop_size=audio_cfg.hop_size,
|
692 |
+
mel_bins=audio_cfg.mel_bins,
|
693 |
+
fmin=audio_cfg.fmin,
|
694 |
+
fmax=audio_cfg.fmax,
|
695 |
+
classes_num=audio_cfg.class_num,
|
696 |
+
enable_fusion=enable_fusion,
|
697 |
+
fusion_type=fusion_type,
|
698 |
+
)
|
699 |
+
return model
|
700 |
+
except:
|
701 |
+
raise RuntimeError(
|
702 |
+
f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
|
703 |
+
)
|
audioldm/clap/open_clip/pretrained.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
_RN50 = dict(
|
9 |
+
openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
10 |
+
yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
|
11 |
+
cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
|
12 |
+
)
|
13 |
+
|
14 |
+
_RN50_quickgelu = dict(
|
15 |
+
openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
16 |
+
yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
|
17 |
+
cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
|
18 |
+
)
|
19 |
+
|
20 |
+
_RN101 = dict(
|
21 |
+
openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
22 |
+
yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
|
23 |
+
)
|
24 |
+
|
25 |
+
_RN101_quickgelu = dict(
|
26 |
+
openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
27 |
+
yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
|
28 |
+
)
|
29 |
+
|
30 |
+
_RN50x4 = dict(
|
31 |
+
openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
32 |
+
)
|
33 |
+
|
34 |
+
_RN50x16 = dict(
|
35 |
+
openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
36 |
+
)
|
37 |
+
|
38 |
+
_RN50x64 = dict(
|
39 |
+
openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
40 |
+
)
|
41 |
+
|
42 |
+
_VITB32 = dict(
|
43 |
+
openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
44 |
+
laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
|
45 |
+
laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
|
46 |
+
laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
|
47 |
+
)
|
48 |
+
|
49 |
+
_VITB32_quickgelu = dict(
|
50 |
+
openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
51 |
+
laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
|
52 |
+
laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
|
53 |
+
laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
|
54 |
+
)
|
55 |
+
|
56 |
+
_VITB16 = dict(
|
57 |
+
openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
58 |
+
)
|
59 |
+
|
60 |
+
_VITL14 = dict(
|
61 |
+
openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
62 |
+
)
|
63 |
+
|
64 |
+
_PRETRAINED = {
|
65 |
+
"RN50": _RN50,
|
66 |
+
"RN50-quickgelu": _RN50_quickgelu,
|
67 |
+
"RN101": _RN101,
|
68 |
+
"RN101-quickgelu": _RN101_quickgelu,
|
69 |
+
"RN50x4": _RN50x4,
|
70 |
+
"RN50x16": _RN50x16,
|
71 |
+
"ViT-B-32": _VITB32,
|
72 |
+
"ViT-B-32-quickgelu": _VITB32_quickgelu,
|
73 |
+
"ViT-B-16": _VITB16,
|
74 |
+
"ViT-L-14": _VITL14,
|
75 |
+
}
|
76 |
+
|
77 |
+
|
78 |
+
def list_pretrained(as_str: bool = False):
|
79 |
+
"""returns list of pretrained models
|
80 |
+
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
|
81 |
+
"""
|
82 |
+
return [
|
83 |
+
":".join([k, t]) if as_str else (k, t)
|
84 |
+
for k in _PRETRAINED.keys()
|
85 |
+
for t in _PRETRAINED[k].keys()
|
86 |
+
]
|
87 |
+
|
88 |
+
|
89 |
+
def list_pretrained_tag_models(tag: str):
|
90 |
+
"""return all models having the specified pretrain tag"""
|
91 |
+
models = []
|
92 |
+
for k in _PRETRAINED.keys():
|
93 |
+
if tag in _PRETRAINED[k]:
|
94 |
+
models.append(k)
|
95 |
+
return models
|
96 |
+
|
97 |
+
|
98 |
+
def list_pretrained_model_tags(model: str):
|
99 |
+
"""return all pretrain tags for the specified model architecture"""
|
100 |
+
tags = []
|
101 |
+
if model in _PRETRAINED:
|
102 |
+
tags.extend(_PRETRAINED[model].keys())
|
103 |
+
return tags
|
104 |
+
|
105 |
+
|
106 |
+
def get_pretrained_url(model: str, tag: str):
|
107 |
+
if model not in _PRETRAINED:
|
108 |
+
return ""
|
109 |
+
model_pretrained = _PRETRAINED[model]
|
110 |
+
if tag not in model_pretrained:
|
111 |
+
return ""
|
112 |
+
return model_pretrained[tag]
|
113 |
+
|
114 |
+
|
115 |
+
def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
116 |
+
os.makedirs(root, exist_ok=True)
|
117 |
+
filename = os.path.basename(url)
|
118 |
+
|
119 |
+
if "openaipublic" in url:
|
120 |
+
expected_sha256 = url.split("/")[-2]
|
121 |
+
else:
|
122 |
+
expected_sha256 = ""
|
123 |
+
|
124 |
+
download_target = os.path.join(root, filename)
|
125 |
+
|
126 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
127 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
128 |
+
|
129 |
+
if os.path.isfile(download_target):
|
130 |
+
if expected_sha256:
|
131 |
+
if (
|
132 |
+
hashlib.sha256(open(download_target, "rb").read()).hexdigest()
|
133 |
+
== expected_sha256
|
134 |
+
):
|
135 |
+
return download_target
|
136 |
+
else:
|
137 |
+
warnings.warn(
|
138 |
+
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
139 |
+
)
|
140 |
+
else:
|
141 |
+
return download_target
|
142 |
+
|
143 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
144 |
+
with tqdm(
|
145 |
+
total=int(source.info().get("Content-Length")),
|
146 |
+
ncols=80,
|
147 |
+
unit="iB",
|
148 |
+
unit_scale=True,
|
149 |
+
) as loop:
|
150 |
+
while True:
|
151 |
+
buffer = source.read(8192)
|
152 |
+
if not buffer:
|
153 |
+
break
|
154 |
+
|
155 |
+
output.write(buffer)
|
156 |
+
loop.update(len(buffer))
|
157 |
+
|
158 |
+
if (
|
159 |
+
expected_sha256
|
160 |
+
and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
|
161 |
+
!= expected_sha256
|
162 |
+
):
|
163 |
+
raise RuntimeError(
|
164 |
+
f"Model has been downloaded but the SHA256 checksum does not not match"
|
165 |
+
)
|
166 |
+
|
167 |
+
return download_target
|
audioldm/clap/open_clip/timm_model.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" timm model adapter
|
2 |
+
|
3 |
+
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
4 |
+
"""
|
5 |
+
from collections import OrderedDict
|
6 |
+
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
try:
|
10 |
+
import timm
|
11 |
+
from timm.models.layers import Mlp, to_2tuple
|
12 |
+
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
13 |
+
from timm.models.layers.attention_pool2d import (
|
14 |
+
AttentionPool2d as AbsAttentionPool2d,
|
15 |
+
)
|
16 |
+
except ImportError as e:
|
17 |
+
timm = None
|
18 |
+
|
19 |
+
from .utils import freeze_batch_norm_2d
|
20 |
+
|
21 |
+
|
22 |
+
class TimmModel(nn.Module):
|
23 |
+
"""timm model adapter
|
24 |
+
# FIXME this adapter is a work in progress, may change in ways that break weight compat
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
model_name,
|
30 |
+
embed_dim,
|
31 |
+
image_size=224,
|
32 |
+
pool="avg",
|
33 |
+
proj="linear",
|
34 |
+
drop=0.0,
|
35 |
+
pretrained=False,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
if timm is None:
|
39 |
+
raise RuntimeError("Please `pip install timm` to use timm models.")
|
40 |
+
|
41 |
+
self.image_size = to_2tuple(image_size)
|
42 |
+
self.trunk = timm.create_model(model_name, pretrained=pretrained)
|
43 |
+
feat_size = self.trunk.default_cfg.get("pool_size", None)
|
44 |
+
feature_ndim = 1 if not feat_size else 2
|
45 |
+
if pool in ("abs_attn", "rot_attn"):
|
46 |
+
assert feature_ndim == 2
|
47 |
+
# if attn pooling used, remove both classifier and default pool
|
48 |
+
self.trunk.reset_classifier(0, global_pool="")
|
49 |
+
else:
|
50 |
+
# reset global pool if pool config set, otherwise leave as network default
|
51 |
+
reset_kwargs = dict(global_pool=pool) if pool else {}
|
52 |
+
self.trunk.reset_classifier(0, **reset_kwargs)
|
53 |
+
prev_chs = self.trunk.num_features
|
54 |
+
|
55 |
+
head_layers = OrderedDict()
|
56 |
+
if pool == "abs_attn":
|
57 |
+
head_layers["pool"] = AbsAttentionPool2d(
|
58 |
+
prev_chs, feat_size=feat_size, out_features=embed_dim
|
59 |
+
)
|
60 |
+
prev_chs = embed_dim
|
61 |
+
elif pool == "rot_attn":
|
62 |
+
head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
63 |
+
prev_chs = embed_dim
|
64 |
+
else:
|
65 |
+
assert proj, "projection layer needed if non-attention pooling is used."
|
66 |
+
|
67 |
+
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
68 |
+
if proj == "linear":
|
69 |
+
head_layers["drop"] = nn.Dropout(drop)
|
70 |
+
head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
|
71 |
+
elif proj == "mlp":
|
72 |
+
head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
|
73 |
+
|
74 |
+
self.head = nn.Sequential(head_layers)
|
75 |
+
|
76 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
77 |
+
"""lock modules
|
78 |
+
Args:
|
79 |
+
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
80 |
+
"""
|
81 |
+
if not unlocked_groups:
|
82 |
+
# lock full model
|
83 |
+
for param in self.trunk.parameters():
|
84 |
+
param.requires_grad = False
|
85 |
+
if freeze_bn_stats:
|
86 |
+
freeze_batch_norm_2d(self.trunk)
|
87 |
+
else:
|
88 |
+
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
89 |
+
try:
|
90 |
+
# FIXME import here until API stable and in an official release
|
91 |
+
from timm.models.helpers import group_parameters, group_modules
|
92 |
+
except ImportError:
|
93 |
+
raise RuntimeError(
|
94 |
+
"Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
|
95 |
+
)
|
96 |
+
matcher = self.trunk.group_matcher()
|
97 |
+
gparams = group_parameters(self.trunk, matcher)
|
98 |
+
max_layer_id = max(gparams.keys())
|
99 |
+
max_layer_id = max_layer_id - unlocked_groups
|
100 |
+
for group_idx in range(max_layer_id + 1):
|
101 |
+
group = gparams[group_idx]
|
102 |
+
for param in group:
|
103 |
+
self.trunk.get_parameter(param).requires_grad = False
|
104 |
+
if freeze_bn_stats:
|
105 |
+
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
106 |
+
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
107 |
+
freeze_batch_norm_2d(self.trunk, gmodules)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
x = self.trunk(x)
|
111 |
+
x = self.head(x)
|
112 |
+
return x
|
audioldm/clap/open_clip/tokenizer.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP tokenizer
|
2 |
+
|
3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
import gzip
|
6 |
+
import html
|
7 |
+
import os
|
8 |
+
from functools import lru_cache
|
9 |
+
from typing import Union, List
|
10 |
+
|
11 |
+
import ftfy
|
12 |
+
import regex as re
|
13 |
+
import torch
|
14 |
+
|
15 |
+
|
16 |
+
@lru_cache()
|
17 |
+
def default_bpe():
|
18 |
+
return os.path.join(
|
19 |
+
os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
@lru_cache()
|
24 |
+
def bytes_to_unicode():
|
25 |
+
"""
|
26 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
27 |
+
The reversible bpe codes work on unicode strings.
|
28 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
29 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
30 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
31 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
32 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
33 |
+
"""
|
34 |
+
bs = (
|
35 |
+
list(range(ord("!"), ord("~") + 1))
|
36 |
+
+ list(range(ord("¡"), ord("¬") + 1))
|
37 |
+
+ list(range(ord("®"), ord("ÿ") + 1))
|
38 |
+
)
|
39 |
+
cs = bs[:]
|
40 |
+
n = 0
|
41 |
+
for b in range(2**8):
|
42 |
+
if b not in bs:
|
43 |
+
bs.append(b)
|
44 |
+
cs.append(2**8 + n)
|
45 |
+
n += 1
|
46 |
+
cs = [chr(n) for n in cs]
|
47 |
+
return dict(zip(bs, cs))
|
48 |
+
|
49 |
+
|
50 |
+
def get_pairs(word):
|
51 |
+
"""Return set of symbol pairs in a word.
|
52 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
53 |
+
"""
|
54 |
+
pairs = set()
|
55 |
+
prev_char = word[0]
|
56 |
+
for char in word[1:]:
|
57 |
+
pairs.add((prev_char, char))
|
58 |
+
prev_char = char
|
59 |
+
return pairs
|
60 |
+
|
61 |
+
|
62 |
+
def basic_clean(text):
|
63 |
+
text = ftfy.fix_text(text)
|
64 |
+
text = html.unescape(html.unescape(text))
|
65 |
+
return text.strip()
|
66 |
+
|
67 |
+
|
68 |
+
def whitespace_clean(text):
|
69 |
+
text = re.sub(r"\s+", " ", text)
|
70 |
+
text = text.strip()
|
71 |
+
return text
|
72 |
+
|
73 |
+
|
74 |
+
class SimpleTokenizer(object):
|
75 |
+
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
76 |
+
self.byte_encoder = bytes_to_unicode()
|
77 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
78 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
|
79 |
+
merges = merges[1 : 49152 - 256 - 2 + 1]
|
80 |
+
merges = [tuple(merge.split()) for merge in merges]
|
81 |
+
vocab = list(bytes_to_unicode().values())
|
82 |
+
vocab = vocab + [v + "</w>" for v in vocab]
|
83 |
+
for merge in merges:
|
84 |
+
vocab.append("".join(merge))
|
85 |
+
if not special_tokens:
|
86 |
+
special_tokens = ["<start_of_text>", "<end_of_text>"]
|
87 |
+
else:
|
88 |
+
special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
|
89 |
+
vocab.extend(special_tokens)
|
90 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
91 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
92 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
93 |
+
self.cache = {t: t for t in special_tokens}
|
94 |
+
special = "|".join(special_tokens)
|
95 |
+
self.pat = re.compile(
|
96 |
+
special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
97 |
+
re.IGNORECASE,
|
98 |
+
)
|
99 |
+
|
100 |
+
self.vocab_size = len(self.encoder)
|
101 |
+
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
102 |
+
|
103 |
+
def bpe(self, token):
|
104 |
+
if token in self.cache:
|
105 |
+
return self.cache[token]
|
106 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
107 |
+
pairs = get_pairs(word)
|
108 |
+
|
109 |
+
if not pairs:
|
110 |
+
return token + "</w>"
|
111 |
+
|
112 |
+
while True:
|
113 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
114 |
+
if bigram not in self.bpe_ranks:
|
115 |
+
break
|
116 |
+
first, second = bigram
|
117 |
+
new_word = []
|
118 |
+
i = 0
|
119 |
+
while i < len(word):
|
120 |
+
try:
|
121 |
+
j = word.index(first, i)
|
122 |
+
new_word.extend(word[i:j])
|
123 |
+
i = j
|
124 |
+
except:
|
125 |
+
new_word.extend(word[i:])
|
126 |
+
break
|
127 |
+
|
128 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
129 |
+
new_word.append(first + second)
|
130 |
+
i += 2
|
131 |
+
else:
|
132 |
+
new_word.append(word[i])
|
133 |
+
i += 1
|
134 |
+
new_word = tuple(new_word)
|
135 |
+
word = new_word
|
136 |
+
if len(word) == 1:
|
137 |
+
break
|
138 |
+
else:
|
139 |
+
pairs = get_pairs(word)
|
140 |
+
word = " ".join(word)
|
141 |
+
self.cache[token] = word
|
142 |
+
return word
|
143 |
+
|
144 |
+
def encode(self, text):
|
145 |
+
bpe_tokens = []
|
146 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
147 |
+
for token in re.findall(self.pat, text):
|
148 |
+
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
149 |
+
bpe_tokens.extend(
|
150 |
+
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
|
151 |
+
)
|
152 |
+
return bpe_tokens
|
153 |
+
|
154 |
+
def decode(self, tokens):
|
155 |
+
text = "".join([self.decoder[token] for token in tokens])
|
156 |
+
text = (
|
157 |
+
bytearray([self.byte_decoder[c] for c in text])
|
158 |
+
.decode("utf-8", errors="replace")
|
159 |
+
.replace("</w>", " ")
|
160 |
+
)
|
161 |
+
return text
|
162 |
+
|
163 |
+
|
164 |
+
_tokenizer = SimpleTokenizer()
|
165 |
+
|
166 |
+
|
167 |
+
def tokenize(
|
168 |
+
texts: Union[str, List[str]], context_length: int = 77
|
169 |
+
) -> torch.LongTensor:
|
170 |
+
"""
|
171 |
+
Returns the tokenized representation of given input string(s)
|
172 |
+
|
173 |
+
Parameters
|
174 |
+
----------
|
175 |
+
texts : Union[str, List[str]]
|
176 |
+
An input string or a list of input strings to tokenize
|
177 |
+
context_length : int
|
178 |
+
The context length to use; all CLIP models use 77 as the context length
|
179 |
+
|
180 |
+
Returns
|
181 |
+
-------
|
182 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
183 |
+
"""
|
184 |
+
if isinstance(texts, str):
|
185 |
+
texts = [texts]
|
186 |
+
|
187 |
+
sot_token = _tokenizer.encoder["<start_of_text>"]
|
188 |
+
eot_token = _tokenizer.encoder["<end_of_text>"]
|
189 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
190 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
191 |
+
|
192 |
+
for i, tokens in enumerate(all_tokens):
|
193 |
+
if len(tokens) > context_length:
|
194 |
+
tokens = tokens[:context_length] # Truncate
|
195 |
+
result[i, : len(tokens)] = torch.tensor(tokens)
|
196 |
+
|
197 |
+
return result
|
audioldm/clap/open_clip/transform.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import (
|
2 |
+
Normalize,
|
3 |
+
Compose,
|
4 |
+
RandomResizedCrop,
|
5 |
+
InterpolationMode,
|
6 |
+
ToTensor,
|
7 |
+
Resize,
|
8 |
+
CenterCrop,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
def _convert_to_rgb(image):
|
13 |
+
return image.convert("RGB")
|
14 |
+
|
15 |
+
|
16 |
+
def image_transform(
|
17 |
+
image_size: int,
|
18 |
+
is_train: bool,
|
19 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
20 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
21 |
+
):
|
22 |
+
normalize = Normalize(mean=mean, std=std)
|
23 |
+
if is_train:
|
24 |
+
return Compose(
|
25 |
+
[
|
26 |
+
RandomResizedCrop(
|
27 |
+
image_size,
|
28 |
+
scale=(0.9, 1.0),
|
29 |
+
interpolation=InterpolationMode.BICUBIC,
|
30 |
+
),
|
31 |
+
_convert_to_rgb,
|
32 |
+
ToTensor(),
|
33 |
+
normalize,
|
34 |
+
]
|
35 |
+
)
|
36 |
+
else:
|
37 |
+
return Compose(
|
38 |
+
[
|
39 |
+
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
40 |
+
CenterCrop(image_size),
|
41 |
+
_convert_to_rgb,
|
42 |
+
ToTensor(),
|
43 |
+
normalize,
|
44 |
+
]
|
45 |
+
)
|
audioldm/clap/open_clip/utils.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn as nn
|
4 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
5 |
+
import logging
|
6 |
+
# import h5py
|
7 |
+
from tqdm import tqdm
|
8 |
+
import random
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import pathlib
|
12 |
+
|
13 |
+
# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later.
|
14 |
+
dataset_split = {
|
15 |
+
"audiocaps": ["train", "valid", "test"],
|
16 |
+
"audioset": ["balanced_train", "unbalanced_train", "eval"],
|
17 |
+
"BBCSoundEffects": ["train", "test"],
|
18 |
+
"Clotho": ["train", "test", "valid"],
|
19 |
+
"free_to_use_sounds": ["train", "test"],
|
20 |
+
"paramount_motion": ["train", "test"],
|
21 |
+
"sonniss_game_effects": ["train", "test"],
|
22 |
+
"wesoundeffects": ["train", "test"],
|
23 |
+
"MACS": ["train", "test"],
|
24 |
+
"freesound": ["train", "test"],
|
25 |
+
"FSD50K": ["train", "test", "valid"],
|
26 |
+
"fsd50k_class_label": ["train", "test", "valid"],
|
27 |
+
"esc50": ["train", "test"],
|
28 |
+
"audiostock": ["train", "test"],
|
29 |
+
"freesound_no_overlap_noesc50": ["train", "test"],
|
30 |
+
"epidemic_sound_effects": ["train", "test"],
|
31 |
+
"VGGSound": ["train", "test"],
|
32 |
+
"urbansound8k_class_label": ["train", "test"],
|
33 |
+
"audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
|
34 |
+
"epidemic_sound_effects_t5": ["train", "test"],
|
35 |
+
"WavText5K": ["train", "test"],
|
36 |
+
"esc50_no_overlap": ["train", "test"],
|
37 |
+
"usd8k_no_overlap": ["train", "test"],
|
38 |
+
"fsd50k_200_class_label": ["train", "test", "valid"],
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
def freeze_batch_norm_2d(module, module_match={}, name=""):
|
43 |
+
"""
|
44 |
+
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
45 |
+
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
46 |
+
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
module (torch.nn.Module): Any PyTorch module.
|
50 |
+
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
51 |
+
name (str): Full module name (prefix)
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
torch.nn.Module: Resulting module
|
55 |
+
|
56 |
+
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
57 |
+
"""
|
58 |
+
res = module
|
59 |
+
is_match = True
|
60 |
+
if module_match:
|
61 |
+
is_match = name in module_match
|
62 |
+
if is_match and isinstance(
|
63 |
+
module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
|
64 |
+
):
|
65 |
+
res = FrozenBatchNorm2d(module.num_features)
|
66 |
+
res.num_features = module.num_features
|
67 |
+
res.affine = module.affine
|
68 |
+
if module.affine:
|
69 |
+
res.weight.data = module.weight.data.clone().detach()
|
70 |
+
res.bias.data = module.bias.data.clone().detach()
|
71 |
+
res.running_mean.data = module.running_mean.data
|
72 |
+
res.running_var.data = module.running_var.data
|
73 |
+
res.eps = module.eps
|
74 |
+
else:
|
75 |
+
for child_name, child in module.named_children():
|
76 |
+
full_child_name = ".".join([name, child_name]) if name else child_name
|
77 |
+
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
78 |
+
if new_child is not child:
|
79 |
+
res.add_module(child_name, new_child)
|
80 |
+
return res
|
81 |
+
|
82 |
+
|
83 |
+
def exist(dataset_name, dataset_type):
|
84 |
+
"""
|
85 |
+
Check if dataset exists
|
86 |
+
"""
|
87 |
+
if dataset_type in dataset_split[dataset_name]:
|
88 |
+
return True
|
89 |
+
else:
|
90 |
+
return False
|
91 |
+
|
92 |
+
|
93 |
+
def get_tar_path_from_dataset_name(
|
94 |
+
dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
Get tar path from dataset name and type
|
98 |
+
"""
|
99 |
+
output = []
|
100 |
+
for n in dataset_names:
|
101 |
+
if full_dataset is not None and n in full_dataset:
|
102 |
+
current_dataset_types = dataset_split[n]
|
103 |
+
else:
|
104 |
+
current_dataset_types = dataset_types
|
105 |
+
for s in current_dataset_types:
|
106 |
+
tmp = []
|
107 |
+
if islocal:
|
108 |
+
sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
|
109 |
+
if not os.path.exists(sizefilepath_):
|
110 |
+
sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
|
111 |
+
else:
|
112 |
+
sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
|
113 |
+
if not os.path.exists(sizefilepath_):
|
114 |
+
continue
|
115 |
+
sizes = json.load(open(sizefilepath_, "r"))
|
116 |
+
for k in sizes.keys():
|
117 |
+
if islocal:
|
118 |
+
tmp.append(f"{dataset_path}/{n}/{s}/{k}")
|
119 |
+
else:
|
120 |
+
tmp.append(
|
121 |
+
f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
|
122 |
+
)
|
123 |
+
if proportion != 1:
|
124 |
+
tmp = random.sample(tmp, int(proportion * len(tmp)))
|
125 |
+
output.append(tmp)
|
126 |
+
return sum(output, [])
|
127 |
+
|
128 |
+
|
129 |
+
def get_tar_path_from_txts(txt_path, islocal, proportion=1):
|
130 |
+
"""
|
131 |
+
Get tar path from txt path
|
132 |
+
"""
|
133 |
+
if isinstance(txt_path, (list, tuple)):
|
134 |
+
return sum(
|
135 |
+
[
|
136 |
+
get_tar_path_from_txts(
|
137 |
+
txt_path[i], islocal=islocal, proportion=proportion
|
138 |
+
)
|
139 |
+
for i in range(len(txt_path))
|
140 |
+
],
|
141 |
+
[],
|
142 |
+
)
|
143 |
+
if isinstance(txt_path, str):
|
144 |
+
with open(txt_path) as f:
|
145 |
+
lines = f.readlines()
|
146 |
+
if islocal:
|
147 |
+
lines = [
|
148 |
+
lines[i]
|
149 |
+
.split("\n")[0]
|
150 |
+
.replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
|
151 |
+
for i in range(len(lines))
|
152 |
+
]
|
153 |
+
else:
|
154 |
+
lines = [
|
155 |
+
lines[i].split("\n")[0].replace(".tar", ".tar -")
|
156 |
+
for i in range(len(lines))
|
157 |
+
]
|
158 |
+
if proportion != 1:
|
159 |
+
print("Sampling tars with proportion of {}".format(proportion))
|
160 |
+
lines = random.sample(lines, int(proportion * len(lines)))
|
161 |
+
return lines
|
162 |
+
|
163 |
+
|
164 |
+
def get_mix_lambda(mixup_alpha, batch_size):
|
165 |
+
mixup_lambdas = [
|
166 |
+
np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
|
167 |
+
]
|
168 |
+
return np.array(mixup_lambdas).astype(np.float32)
|
169 |
+
|
170 |
+
|
171 |
+
def do_mixup(x, mixup_lambda):
|
172 |
+
"""
|
173 |
+
Args:
|
174 |
+
x: (batch_size , ...)
|
175 |
+
mixup_lambda: (batch_size,)
|
176 |
+
Returns:
|
177 |
+
out: (batch_size, ...)
|
178 |
+
"""
|
179 |
+
out = (
|
180 |
+
x.transpose(0, -1) * mixup_lambda
|
181 |
+
+ torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
|
182 |
+
).transpose(0, -1)
|
183 |
+
return out
|
184 |
+
|
185 |
+
|
186 |
+
def interpolate(x, ratio):
|
187 |
+
"""Interpolate data in time domain. This is used to compensate the
|
188 |
+
resolution reduction in downsampling of a CNN.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
x: (batch_size, time_steps, classes_num)
|
192 |
+
ratio: int, ratio to interpolate
|
193 |
+
Returns:
|
194 |
+
upsampled: (batch_size, time_steps * ratio, classes_num)
|
195 |
+
"""
|
196 |
+
(batch_size, time_steps, classes_num) = x.shape
|
197 |
+
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
|
198 |
+
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
|
199 |
+
return upsampled
|
200 |
+
|
201 |
+
|
202 |
+
def pad_framewise_output(framewise_output, frames_num):
|
203 |
+
"""Pad framewise_output to the same length as input frames. The pad value
|
204 |
+
is the same as the value of the last frame.
|
205 |
+
Args:
|
206 |
+
framewise_output: (batch_size, frames_num, classes_num)
|
207 |
+
frames_num: int, number of frames to pad
|
208 |
+
Outputs:
|
209 |
+
output: (batch_size, frames_num, classes_num)
|
210 |
+
"""
|
211 |
+
pad = framewise_output[:, -1:, :].repeat(
|
212 |
+
1, frames_num - framewise_output.shape[1], 1
|
213 |
+
)
|
214 |
+
"""tensor for padding"""
|
215 |
+
|
216 |
+
output = torch.cat((framewise_output, pad), dim=1)
|
217 |
+
"""(batch_size, frames_num, classes_num)"""
|
218 |
+
|
219 |
+
|
220 |
+
# def process_ipc(index_path, classes_num, filename):
|
221 |
+
# # load data
|
222 |
+
# logging.info("Load Data...............")
|
223 |
+
# ipc = [[] for _ in range(classes_num)]
|
224 |
+
# with h5py.File(index_path, "r") as f:
|
225 |
+
# for i in tqdm(range(len(f["target"]))):
|
226 |
+
# t_class = np.where(f["target"][i])[0]
|
227 |
+
# for t in t_class:
|
228 |
+
# ipc[t].append(i)
|
229 |
+
# print(ipc)
|
230 |
+
# np.save(filename, ipc)
|
231 |
+
# logging.info("Load Data Succeed...............")
|
232 |
+
|
233 |
+
|
234 |
+
def save_to_dict(s, o_={}):
|
235 |
+
sp = s.split(": ")
|
236 |
+
o_.update({sp[0]: float(sp[1])})
|
237 |
+
return o_
|
238 |
+
|
239 |
+
|
240 |
+
def get_data_from_log(txt_path):
|
241 |
+
"""
|
242 |
+
Output dictionary from out.txt log file
|
243 |
+
"""
|
244 |
+
with open(txt_path) as f:
|
245 |
+
lines = f.readlines()
|
246 |
+
val_data = {}
|
247 |
+
train_data = {}
|
248 |
+
train_losses = []
|
249 |
+
train_losses_epoch = []
|
250 |
+
for i in range(len(lines)):
|
251 |
+
if "| INFO |" in lines[i]:
|
252 |
+
if "Eval Epoch" in lines[i]:
|
253 |
+
if "val_loss" in lines[i]:
|
254 |
+
# float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
|
255 |
+
line = lines[i].split("Eval Epoch: ")[-1]
|
256 |
+
num_epoch = int(line.split(" ")[0].split(" ")[0])
|
257 |
+
d = {
|
258 |
+
line.split(" ")[0]
|
259 |
+
.split(" ")[1]
|
260 |
+
.replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
|
261 |
+
}
|
262 |
+
for i in range(1, len(line.split(" "))):
|
263 |
+
d = save_to_dict(line.split(" ")[i], d)
|
264 |
+
val_data[num_epoch] = d
|
265 |
+
elif "Train Epoch" in lines[i]:
|
266 |
+
num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
|
267 |
+
loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
|
268 |
+
train_losses.append(loss)
|
269 |
+
train_losses_epoch.append(num_epoch)
|
270 |
+
for i in range(len(train_losses)):
|
271 |
+
train_data[i] = {
|
272 |
+
"num_epoch": train_losses_epoch[i],
|
273 |
+
"train_loss": train_losses[i],
|
274 |
+
}
|
275 |
+
return train_data, val_data
|
276 |
+
|
277 |
+
|
278 |
+
def save_p(obj, filename):
|
279 |
+
import pickle
|
280 |
+
|
281 |
+
try:
|
282 |
+
from deepdiff import DeepDiff
|
283 |
+
except:
|
284 |
+
os.system("pip install deepdiff")
|
285 |
+
from deepdiff import DeepDiff
|
286 |
+
with open(filename, "wb") as file:
|
287 |
+
pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
|
288 |
+
with open(filename, "rb") as file:
|
289 |
+
z = pickle.load(file)
|
290 |
+
assert (
|
291 |
+
DeepDiff(obj, z, ignore_string_case=True) == {}
|
292 |
+
), "there is something wrong with the saving process"
|
293 |
+
return
|
294 |
+
|
295 |
+
|
296 |
+
def load_p(filename):
|
297 |
+
import pickle
|
298 |
+
|
299 |
+
with open(filename, "rb") as file:
|
300 |
+
z = pickle.load(file)
|
301 |
+
return z
|
302 |
+
|
303 |
+
|
304 |
+
def save_json(data, name="data.json"):
|
305 |
+
import json
|
306 |
+
|
307 |
+
with open(name, "w") as fp:
|
308 |
+
json.dump(data, fp)
|
309 |
+
return
|
310 |
+
|
311 |
+
|
312 |
+
def load_json(name):
|
313 |
+
import json
|
314 |
+
|
315 |
+
with open(name, "r") as fp:
|
316 |
+
data = json.load(fp)
|
317 |
+
return data
|
318 |
+
|
319 |
+
|
320 |
+
from multiprocessing import Process, Manager
|
321 |
+
from multiprocessing import Process, Value, Array
|
322 |
+
from ctypes import c_wchar
|
323 |
+
|
324 |
+
|
325 |
+
def load_class_label(path):
|
326 |
+
# https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
|
327 |
+
# https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
|
328 |
+
out = None
|
329 |
+
if path is not None:
|
330 |
+
if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
|
331 |
+
out = load_p(path)
|
332 |
+
elif pathlib.Path(path).suffix in [".json", ".txt"]:
|
333 |
+
out = load_json(path)
|
334 |
+
elif pathlib.Path(path).suffix in [".npy", ".npz"]:
|
335 |
+
out = np.load(path)
|
336 |
+
elif pathlib.Path(path).suffix in [".csv"]:
|
337 |
+
import pandas as pd
|
338 |
+
|
339 |
+
out = pd.read_csv(path)
|
340 |
+
return out
|
341 |
+
# if out is None:
|
342 |
+
# return None
|
343 |
+
# else:
|
344 |
+
# key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
|
345 |
+
# val = Array('i', out.values(), lock=False)
|
346 |
+
# return (key, val)
|
347 |
+
|
348 |
+
|
349 |
+
from torch import optim
|
350 |
+
|
351 |
+
|
352 |
+
def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
|
353 |
+
if optimizer_name.lower() == "adamw":
|
354 |
+
optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
|
355 |
+
elif optimizer_name.lower() == "sgd":
|
356 |
+
optimizer = optim.SGD(params, lr=lr, momentum=momentum)
|
357 |
+
elif optimizer_name.lower() == "adam":
|
358 |
+
optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
|
359 |
+
else:
|
360 |
+
raise ValueError("optimizer name is not correct")
|
361 |
+
return optimizer
|
audioldm/clap/open_clip/version.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.2.1"
|