waidhoferj committed
Commit 557fb53 • Parent: e82ec2b
Refactor config style and reorganize files
Changed files:
- .gitignore +1 -0
- TODO.md +5 -2
- environment.yml +5 -0
- models/audio_spectrogram_transformer.py +117 -76
- models/config/decision_tree.yaml +47 -0
- models/config/train.yaml +5 -5
- models/config/train_local.yaml +47 -36
- models/decision_tree.py +124 -37
- models/residual.py +82 -86
- models/training_environment.py +90 -0
- models/utils.py +47 -20
- models/wav2vec2.py +84 -0
- preprocessing/dataset.py +230 -198
- preprocessing/pipelines.py +56 -42
- preprocessing/preprocess.py +66 -44
- tests.py +0 -22
- tests/test_datasets.py +17 -0
- tests/test_pipelines.py +13 -0
- tests/utils.py +7 -0
- train.py +9 -176
.gitignore
CHANGED
@@ -9,3 +9,4 @@ lightning_logs
 .lr_find_*
 .cache
 .vscode
+models/weights/ast
TODO.md
CHANGED
@@ -6,10 +6,13 @@
 - Create an attention-based network
 - ✅ Increase parameter count in network
 - Verify that labels really match what is on the music4dance site
-- Read the Medium series about audio DL
+- ✅ Read the Medium series about audio DL
 - double check \_rectify_duration
 - ✅ Filter out songs that have only one vote
+- ✅ Download songs from [Best Ballroom](https://www.youtube.com/channel/UC0bYSnzAFMwPiEjmVsrvmRg)
+
+- ✅ fix nan values

 ## Notes

-2xM60 insufficient memory.
+2xM60 insufficient memory for the AST.
environment.yml
CHANGED
@@ -23,6 +23,11 @@ dependencies:
   - scikit-learn
   - tensorboard
   - transformers
+  - accelerate
+  - pytest
+
   - pip:
       - evaluate
       - wakepy
+      - soundfile
+      - youtube_dl
models/audio_spectrogram_transformer.py
CHANGED
@@ -1,93 +1,138 @@
from typing import Any
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    TrainingArguments,
    Trainer,
    ASTConfig,
    ASTFeatureExtractor,
    ASTForAudioClassification,
)
import torch
from torch import nn
from models.training_environment import TrainingEnvironment
from preprocessing.pipelines import WaveformTrainingPipeline

from preprocessing.dataset import (
    DanceDataModule,
    HuggingFaceDatasetWrapper,
    get_datasets,
)
from preprocessing.dataset import get_music4dance_examples
from .utils import get_id_label_mapping, compute_hf_metrics

import pytorch_lightning as pl
from pytorch_lightning import callbacks as cb

MODEL_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"


class AST(nn.Module):
    def __init__(self, labels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        id2label, label2id = get_id_label_mapping(labels)
        config = ASTConfig(
            hidden_size=300,
            num_attention_heads=5,
            num_hidden_layers=3,
            id2label=id2label,
            label2id=label2id,
            num_labels=len(label2id),
            ignore_mismatched_sizes=True,
        )
        self.model = ASTForAudioClassification(config)

    def forward(self, x):
        return self.model(x).logits


class ASTExtractorWrapper:
    def __init__(self, sampling_rate=16000, return_tensors="pt") -> None:
        self.extractor = ASTFeatureExtractor()
        self.sampling_rate = sampling_rate
        self.return_tensors = return_tensors
        self.waveform_pipeline = WaveformTrainingPipeline()  # TODO configure from yaml

    def __call__(self, x) -> Any:
        x = self.waveform_pipeline(x)
        device = x.device
        x = x.squeeze(0).numpy()
        x = self.extractor(
            x, return_tensors=self.return_tensors, sampling_rate=self.sampling_rate
        )
        return x["input_values"].squeeze(0).to(device)


def train_lightning_ast(config: dict):
    """
    work on integration between waveform dataset and environment. Should work for both HF and PTL.
    """
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    pl.seed_everything(SEED, workers=True)
    feature_extractor = ASTExtractorWrapper()
    dataset = get_datasets(config["datasets"], feature_extractor)
    data = DanceDataModule(
        dataset,
        target_classes=TARGET_CLASSES,
        **config["data_module"],
    )

    model = AST(TARGET_CLASSES).to(DEVICE)
    label_weights = data.get_label_weights().to(DEVICE)
    criterion = nn.CrossEntropyLoss(
        label_weights
    )  # LabelWeightedBCELoss(label_weights)
    train_env = TrainingEnvironment(model, criterion, config)
    callbacks = [
        # cb.LearningRateFinder(update_attr=True),
        cb.EarlyStopping("val/loss", patience=5),
        cb.RichProgressBar(),
    ]
    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
    trainer.fit(train_env, datamodule=data)
    trainer.test(train_env, datamodule=data)


def train_huggingface_ast(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    OUTPUT_DIR = "models/weights/ast"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["data_module"]["min_epochs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    pl.seed_everything(SEED, workers=True)
    dataset = get_datasets(config["datasets"])
    hf_dataset = HuggingFaceDatasetWrapper(dataset)
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    preprocess_waveform = lambda wf: feature_extractor(
        wf,
        sampling_rate=train_ds.resample_frequency,
        # padding="max_length",
        # return_tensors="pt",
    )
    hf_dataset.append_to_pipeline(preprocess_waveform)
    test_proportion = config["data_module"]["test_proportion"]
    train_proporition = 1 - test_proportion
    train_ds, test_ds = torch.utils.data.random_split(
        hf_dataset, [train_proporition, test_proportion]
    )

    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(TARGET_CLASSES),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
@@ -100,7 +145,7 @@ def train(
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=DEVICE == "mps",
    )

    trainer = Trainer(
@@ -109,11 +154,7 @@
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=feature_extractor,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
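For reference, the bare ASTFeatureExtractor that ASTExtractorWrapper builds on turns a 1-D waveform into a fixed-size log-mel patch grid. A standalone check is sketched below; WaveformTrainingPipeline is project-specific and omitted, and the printed shape reflects the transformers AST defaults, so treat the exact numbers as assumptions.

# Standalone sketch of the HF extractor used inside ASTExtractorWrapper.
import numpy as np
from transformers import ASTFeatureExtractor

extractor = ASTFeatureExtractor()
waveform = np.zeros(16000 * 6, dtype=np.float32)  # 6 s of silence at 16 kHz
features = extractor(waveform, sampling_rate=16000, return_tensors="pt")
print(features["input_values"].shape)  # torch.Size([1, 1024, 128]) with default settings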
models/config/decision_tree.yaml
ADDED
@@ -0,0 +1,47 @@
global:
  id: decision_tree
  device: mps
  seed: 42
  dance_ids:
    - ATN
    - BCH
    - CHA
    - ECS
    - HST
    - JIV
    - QST
    - RMB
    - SFT
    - SLS
    - SMB
    - SWZ
    - TGO
    - VWZ
    - WCS
data_module:
  song_data_path: data/songs_cleaned.csv
  song_audio_path: data/samples
  batch_size: 32
  num_workers: 7
  min_votes: 1
  dataset_kwargs:
    audio_window_duration: 6
    audio_window_jitter: 1.5
    audio_pipeline_kwargs:
      mask_count: 0 # Don't mask the data
      snr_mean: 15.0 # Pretty much eliminate the noise
      freq_mask_size: 10
      time_mask_size: 80

trainer:
  log_every_n_steps: 15
  accelerator: gpu
  max_epochs: 50
  min_epochs: 5
  fast_dev_run: False
  # gradient_clip_val: 0.5
  # overfit_batches: 1
training_environment:
  learning_rate: 0.00053
model:
  n_channels: 128
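Unlike the other configs, this file nests the run-level settings under a global key, which matches how train_decision_tree (in models/decision_tree.py below) indexes the parsed dict. A small sketch of those lookups, assuming PyYAML:

# Sketch of how the nested keys are consumed; the commented values mirror the YAML above.
import yaml

with open("models/config/decision_tree.yaml") as f:
    config = yaml.safe_load(f)

dance_ids = config["global"]["dance_ids"]                  # ["ATN", "BCH", ...]
device = config["global"]["device"]                        # "mps"
epochs = config["trainer"]["min_epochs"]                   # 5
song_data_path = config["data_module"]["song_data_path"]   # "data/songs_cleaned.csv"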
models/config/train.yaml
CHANGED
@@ -27,11 +27,11 @@ data_module:
  dataset_kwargs:
    audio_window_duration: 6
    audio_window_jitter: 1.5
    # audio_pipeline_kwargs:
    #   mask_count: 0 # Don't mask the data
    #   snr_mean: 15.0 # Pretty much eliminate the noise
    #   freq_mask_size: 10
    #   time_mask_size: 80

trainer:
  log_every_n_steps: 15
models/config/train_local.yaml
CHANGED
@@ -1,47 +1,58 @@
training_fn: audio_spectrogram_transformer.train_lightning_ast
device: mps
seed: 42
dance_ids: &dance_ids
  - BCH
  - CHA
  - JIV
  - ECS
  - QST
  - RMB
  - SFT
  - SLS
  - SMB
  - SWZ
  - TGO
  - VWZ
  - WCS

data_module:
  batch_size: 64
  num_workers: 10
  test_proportion: 0.2

datasets:
  preprocessing.dataset.BestBallroomDataset:
    audio_dir: data/ballroom-songs
    class_list: *dance_ids
    audio_window_jitter: 0.7

  preprocessing.dataset.Music4DanceDataset:
    song_data_path: data/songs_cleaned.csv
    song_audio_path: data/samples # data/samples
    class_list: *dance_ids
    multi_label: False
    min_votes: 1
    audio_window_jitter: 0.7

model:
  n_channels: 128

feature_extractor:
  mask_count: 0 # Don't mask the data
  snr_mean: 15.0 # Pretty much eliminate the noise
  freq_mask_size: 10
  time_mask_size: 80

trainer:
  log_every_n_steps: 15
  accelerator: gpu
  max_epochs: 50
  min_epochs: 7
  fast_dev_run: False
  # gradient_clip_val: 0.5
  # overfit_batches: 1

training_environment:
  learning_rate: 0.00053
  log_spectrograms: False
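The new training_fn key and the dotted dataset class names suggest that train.py (cut down to a few lines in this commit and not shown here) resolves training functions and dataset classes dynamically from the config. The sketch below is an assumption about that dispatch, not the actual contents of train.py.

# Assumed dispatch: resolve "audio_spectrogram_transformer.train_lightning_ast" to a
# callable inside the models package and hand it the parsed config.
import importlib

import yaml


def resolve_training_fn(dotted: str):
    module_name, fn_name = dotted.rsplit(".", 1)
    module = importlib.import_module(f"models.{module_name}")  # assumed package prefix
    return getattr(module, fn_name)


if __name__ == "__main__":
    with open("models/config/train_local.yaml") as f:
        config = yaml.safe_load(f)
    train = resolve_training_fn(config["training_fn"])
    train(config)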
models/decision_tree.py
CHANGED
@@ -1,3 +1,4 @@
import pytorch_lightning as pl
from sklearn.base import ClassifierMixin, BaseEstimator
import pandas as pd
from torch import nn
@@ -5,8 +6,14 @@ import torch
from typing import Iterator
import numpy as np
import json
from torch.utils.data import random_split
from tqdm import tqdm
import librosa
from joblib import dump, load
from os import path
import os

from preprocessing.dataset import get_music4dance_examples

DANCE_INFO_FILE = "data/dance_info.csv"
dance_info_df = pd.read_csv(
@@ -24,9 +31,8 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
    - BPM
    """

    def __init__(self, device="cpu", lr=1e-4, verbose=True) -> None:
        self.device = device
        self.verbose = verbose
        self.lr = lr
        self.classifiers = {}
@@ -44,41 +50,40 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
        x: (specs, bpms). The first element is the spectrogram, second element is the bpm. spec shape should be (channel, freq_bins, sr * time)
        y: (batch_size, n_classes)
        """
        epoch_loss = 0
        pred_count = 0
        data_loader = zip(x, y)
        if self.verbose:
            data_loader = tqdm(data_loader, total=len(y))
        for (spec, bpm), label in data_loader:
            # find all models that are in the bpm range
            matching_dances = self.get_valid_dances_from_bpm(bpm)
            spec = torch.from_numpy(spec).to(self.device)
            for dance in matching_dances:
                if dance not in self.classifiers or dance not in self.optimizers:
                    classifier = DanceCNN().to(self.device)
                    self.classifiers[dance] = classifier
                    self.optimizers[dance] = torch.optim.Adam(
                        classifier.parameters(), lr=self.lr
                    )
            models = [
                (dance, model, self.optimizers[dance])
                for dance, model in self.classifiers.items()
                if dance in matching_dances
            ]
            for model_i, (dance, model, opt) in enumerate(models, start=1):
                opt.zero_grad()
                output = model(spec)
                target = torch.tensor([float(dance == label)], device=self.device)
                loss = self.criterion(output, target)
                epoch_loss += loss.item()
                pred_count += 1
                loss.backward()
                if self.verbose:
                    data_loader.set_description(
                        f"model: {model_i}/{len(models)}, loss: {loss.item()}"
                    )
                opt.step()

    def predict(self, x) -> list[str]:
        results = []
@@ -90,6 +95,52 @@ class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
            results.append(matching_dances[dance_i])
        return results

    def save(self, folder: str):
        # Create a folder
        classifier_path = path.join(folder, "classifier")
        os.makedirs(classifier_path, exist_ok=True)

        # Swap out model reference
        classifiers = self.classifiers
        optimizers = self.optimizers
        criterion = self.criterion

        self.classifiers = None
        self.optimizers = None
        self.criterion = None

        # Save the Pth models
        for dance, classifier in classifiers.items():
            torch.save(
                classifier.state_dict(), path.join(classifier_path, dance + ".pth")
            )

        # Save the Sklearn model
        dump(path.join(folder, "sklearn.joblib"))

        # Reload values
        self.classifiers = classifiers
        self.optimizers = optimizers
        self.criterion = criterion

    @staticmethod
    def from_config(folder: str, device="cpu") -> "DanceTreeClassifier":
        # load in weights
        model_paths = (
            p for p in os.listdir(path.join(folder, "classifier")) if p.endswith("pth")
        )
        classifiers = {}
        for model_path in model_paths:
            dance = model_path.split(".")[0]
            model = DanceCNN().to(device)
            model.load_state_dict(
                torch.load(path.join(folder, "classifier", model_path))
            )
            classifiers[dance] = model
        wrapper = load(path.join(folder, "sklearn.joblib"))
        wrapper.classifiers = classifiers
        return wrapper


class DanceCNN(nn.Module):
    def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
@@ -136,7 +187,6 @@ def features_from_path(
    num_frames = audio_window_duration * sr
    tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
    spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
    spec_normalized = (spec - spec.mean()) / spec.std()
    spec_padded = librosa.util.fix_length(
        spec_normalized, size=sr * audio_duration, axis=1
@@ -145,3 +195,40 @@ def features_from_path(
    for i in range(audio_duration // audio_window_duration):
        spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
        yield (spec_window, tempo)


def train_decision_tree(config: dict):
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    EPOCHS = config["trainer"]["min_epochs"]
    song_data_path = config["data_module"]["song_data_path"]
    song_audio_path = config["data_module"]["song_audio_path"]
    pl.seed_everything(SEED, workers=True)

    df = pd.read_csv(song_data_path)
    x, y = get_music4dance_examples(
        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
    )
    # Convert y back to string classes
    y = np.array(TARGET_CLASSES)[y.argmax(-1)]
    train_i, test_i = random_split(
        np.arange(len(x)), [0.1, 0.9]
    )  # Temporary to test efficacy
    train_paths, train_y = x[train_i], y[train_i]
    model = DanceTreeClassifier(device=DEVICE)
    for epoch in tqdm(range(1, EPOCHS + 1)):
        # Shuffle the data
        i = np.arange(len(train_paths))
        np.random.shuffle(i)
        train_paths = train_paths[i]
        train_y = train_y[i]
        train_x = features_from_path(train_paths)
        model.fit(train_x, train_y)

    # evaluate the model
    preds = model.predict(x[test_i])
    accuracy = (preds == y[test_i]).mean()
    print(f"{accuracy=}")
    model.save("models/weights/decision_tree")
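A usage sketch for the new persistence methods follows. One caveat: joblib's dump(value, filename) takes the object to serialize as its first argument, while save() above calls dump() with only a path, so the round trip below assumes that call is intended as dump(self, path.join(folder, "sklearn.joblib")).

# Assumed round trip for the new save()/from_config() API on DanceTreeClassifier.
from models.decision_tree import DanceTreeClassifier

model = DanceTreeClassifier(device="cpu")
# ... fit on (spectrogram, bpm) windows from features_from_path(...) ...
model.save("models/weights/decision_tree")

# Restores the sklearn wrapper via joblib and reloads each per-dance DanceCNN .pth file.
restored = DanceTreeClassifier.from_config("models/weights/decision_tree", device="cpu")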
models/residual.py
CHANGED
@@ -1,18 +1,25 @@
import pytorch_lightning as pl
from pytorch_lightning import callbacks as cb
import torch
from torch import nn
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import torchaudio
import yaml
from models.training_environment import TrainingEnvironment
from preprocessing.dataset import DanceDataModule, get_datasets
from preprocessing.pipelines import (
    SpectrogramTrainingPipeline,
    WaveformPreprocessing,
)

# Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py


class ResidualDancer(nn.Module):
    def __init__(self, n_channels=128, n_classes=50):
        super().__init__()

        self.n_channels = n_channels
@@ -25,17 +32,17 @@ class ResidualDancer(nn.Module):
        self.res_layers = nn.Sequential(
            ResBlock(1, n_channels, stride=2),
            ResBlock(n_channels, n_channels, stride=2),
            ResBlock(n_channels, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 4, stride=2),
        )

        # Dense
        self.dense1 = nn.Linear(n_channels * 4, n_channels * 4)
        self.bn = nn.BatchNorm1d(n_channels * 4)
        self.dense2 = nn.Linear(n_channels * 4, n_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
@@ -56,24 +63,34 @@ class ResidualDancer(nn.Module):
        x = F.relu(x)
        x = self.dropout(x)
        x = self.dense2(x)
        # x = nn.Sigmoid()(x)

        return x


class ResBlock(nn.Module):
    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super().__init__()
        # convolution
        self.conv_1 = nn.Conv2d(
            input_channels, output_channels, shape, stride=stride, padding=shape // 2
        )
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(
            output_channels, output_channels, shape, padding=shape // 2
        )
        self.bn_2 = nn.BatchNorm2d(output_channels)

        # residual
        self.diff = False
        if (stride != 1) or (input_channels != output_channels):
            self.conv_3 = nn.Conv2d(
                input_channels,
                output_channels,
                shape,
                stride=stride,
                padding=shape // 2,
            )
            self.bn_3 = nn.BatchNorm2d(output_channels)
            self.diff = True
        self.relu = nn.ReLU()
@@ -89,79 +106,31 @@ class ResBlock(nn.Module):
        out = self.relu(out)
        return out


class DancePredictor:
    def __init__(
        self,
        weight_path: str,
        labels: list[str],
        expected_duration=6,
        threshold=0.5,
        resample_frequency=16000,
        device="cpu",
    ):
        super().__init__()

        self.expected_duration = expected_duration
        self.threshold = threshold
        self.resample_frequency = resample_frequency
        self.preprocess_waveform = WaveformPreprocessing(
            resample_frequency * expected_duration
        )
        self.audio_to_spectrogram = lambda x: x  # TODO: Fix
        self.labels = np.array(labels)
        self.device = device
        self.model = self.get_model(weight_path)

    def get_model(self, weight_path: str) -> nn.Module:
        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
        model = ResidualDancer(n_classes=len(self.labels))
        for key in list(weights):
@@ -170,21 +139,25 @@ class DancePredictor:
        return model.to(self.device).eval()

    @classmethod
    def from_config(cls, config_path: str) -> "DancePredictor":
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        return DancePredictor(**config)

    @torch.no_grad()
    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
        if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
            waveform = waveform.transpose(1, 0)
        elif len(waveform.shape) == 1:
            waveform = np.expand_dims(waveform, 0)
        waveform = torch.from_numpy(waveform.astype("int16"))
        waveform = torchaudio.functional.apply_codec(
            waveform, sample_rate, "wav", channels_first=True
        )

        waveform = torchaudio.functional.resample(
            waveform, sample_rate, self.resample_frequency
        )
        waveform = self.preprocess_waveform(waveform)
        spectrogram = self.audio_to_spectrogram(waveform)
        spectrogram = spectrogram.unsqueeze(0).to(self.device)
@@ -194,8 +167,31 @@ class DancePredictor:
        result_mask = results > self.threshold
        probs = results[result_mask]
        dances = self.labels[result_mask]

        return {dance: float(prob) for dance, prob in zip(dances, probs)}


def train_residual_dancer(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    pl.seed_everything(SEED, workers=True)
    feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
    dataset = get_datasets(config["datasets"], feature_extractor)

    data = DanceDataModule(dataset, **config["data_module"])
    model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
    label_weights = data.get_label_weights().to(DEVICE)
    criterion = nn.CrossEntropyLoss(label_weights)

    train_env = TrainingEnvironment(model, criterion, config)
    callbacks = [
        # cb.LearningRateFinder(update_attr=True),
        cb.EarlyStopping("val/loss", patience=5),
        cb.StochasticWeightAveraging(1e-2),
        cb.RichProgressBar(),
        cb.DeviceStatsMonitor(),
    ]
    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
    trainer.fit(train_env, datamodule=data)
    trainer.test(train_env, datamodule=data)
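Because the Sigmoid at the end of ResidualDancer.forward is now commented out, the model returns raw logits and train_residual_dancer can feed them straight into CrossEntropyLoss; consumers that threshold probabilities (such as DancePredictor) have to apply the activation themselves. A small illustration with made-up shapes:

# Illustrative only: 4 songs, 14 dance classes; the random logits stand in for model output.
import torch

logits = torch.randn(4, 14)                    # ResidualDancer now returns raw logits
single_label = torch.softmax(logits, dim=-1)   # pairs with CrossEntropyLoss training
multi_label = torch.sigmoid(logits)            # pairs with LabelWeightedBCELoss training
predicted_dances = multi_label > 0.5           # boolean mask, as DancePredictor thresholds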
models/training_environment.py
ADDED
@@ -0,0 +1,90 @@
from models.utils import calculate_metrics


import pytorch_lightning as pl
import torch
import torch.nn as nn


class TrainingEnvironment(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        config: dict,
        learning_rate=1e-4,
        log_spectrograms=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.log_spectrograms = log_spectrograms
        self.config = config
        self.has_multi_label_predictions = (
            not type(criterion).__name__ == "CrossEntropyLoss"
        )
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def training_step(
        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
    ) -> torch.Tensor:
        features, labels = batch
        outputs = self.model(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(
            outputs,
            labels,
            prefix="train/",
            multi_label=self.has_multi_label_predictions,
        )
        self.log_dict(metrics, prog_bar=True)
        # Log spectrograms
        if self.log_spectrograms and batch_index % 100 == 0:
            tensorboard = self.logger.experiment
            img_index = torch.randint(0, len(features), (1,)).item()
            img = features[img_index][0]
            img = (img - img.min()) / (img.max() - img.min())
            tensorboard.add_image(
                f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
            )
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
    ):
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
        )
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
        x, y = batch
        preds = self.model(x)
        self.log_dict(
            calculate_metrics(
                preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
            ),
            prog_bar=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss",
        }
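TrainingEnvironment is now shared by the ResidualDancer and AST Lightning paths. The self-contained sketch below shows how it plugs into a stock pl.Trainer with a toy linear model and one-hot targets; the real data modules supply spectrogram batches instead, and a recent PyTorch (1.10+) is assumed so CrossEntropyLoss accepts probability targets.

# Toy end-to-end check: any nn.Module + loss wrapped in TrainingEnvironment trains
# under a plain pl.Trainer. The tensors are stand-ins for real spectrogram batches.
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from models.training_environment import TrainingEnvironment

x = torch.randn(64, 10)
y = torch.eye(3)[torch.randint(0, 3, (64,))]  # one-hot targets, as the datasets provide
loader = DataLoader(TensorDataset(x, y), batch_size=16)

env = TrainingEnvironment(nn.Linear(10, 3), nn.CrossEntropyLoss(), config={})
trainer = pl.Trainer(max_epochs=1, accelerator="cpu", logger=False)
trainer.fit(env, train_dataloaders=loader, val_dataloaders=loader)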
models/utils.py
CHANGED
@@ -1,14 +1,20 @@
import torch.nn as nn
import torch
import numpy as np
import evaluate
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score


accuracy = evaluate.load("accuracy")


class LabelWeightedBCELoss(nn.Module):
    """
    Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
    Allows for the weighing of each probability distribution wrt loss.
    """

    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
        super().__init__()
        self.label_weights = label_weights

@@ -17,46 +23,67 @@ class LabelWeightedBCELoss(nn.Module):
                self.reduction = torch.mean
            case "sum":
                self.reduction = torch.sum

    def _log(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clamp_min(torch.log(x), -100)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        losses = -self.label_weights * (
            target * self._log(input) + (1 - target) * self._log(1 - input)
        )
        return self.reduction(losses)


# TODO: Code a onehot


def calculate_metrics(
    pred, target, threshold=0.5, prefix="", multi_label=True
) -> dict[str, torch.Tensor]:
    target = target.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()
    params = {
        "y_true": target if multi_label else target.argmax(1),
        "y_pred": np.array(pred > threshold, dtype=float)
        if multi_label
        else pred.argmax(1),
        "zero_division": 0,
        "average": "macro",
    }
    metrics = {
        "precision": precision_score(**params),
        "recall": recall_score(**params),
        "f1": f1_score(**params),
        "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
    }
    return {
        prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
    }


class EarlyStopping:
    def __init__(self, patience=0):
        self.patience = patience
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        if self.last_measure <= val:
            self.consecutive_increase += 1
        else:
            self.consecutive_increase = 0
        self.last_measure = val

        return self.patience < self.consecutive_increase


def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    id2label = {str(i): label for i, label in enumerate(labels)}
    label2id = {label: str(i) for i, label in enumerate(labels)}

    return id2label, label2id


def compute_hf_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
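The helpers gathered here are shared across the trainers; a quick worked example of the label mapping and the macro metrics on a toy multi-label batch:

# Toy check of get_id_label_mapping and calculate_metrics (values verified by hand).
import torch

from models.utils import calculate_metrics, get_id_label_mapping

id2label, label2id = get_id_label_mapping(["CHA", "JIV", "SWZ"])
# id2label == {"0": "CHA", "1": "JIV", "2": "SWZ"}; label2id inverts it.

pred = torch.tensor([[0.9, 0.2, 0.1], [0.1, 0.8, 0.7]])
target = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
metrics = calculate_metrics(pred, target, prefix="demo/", multi_label=True)
# Predictions thresholded at 0.5 match the targets exactly, so
# metrics["demo/f1"] == metrics["demo/accuracy"] == tensor(1.)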
models/wav2vec2.py
ADDED
@@ -0,0 +1,84 @@
import os
from typing import Any
import pytorch_lightning as pl
from torch.utils.data import random_split
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

from preprocessing.dataset import (
    HuggingFaceDatasetWrapper,
    BestBallroomDataset,
    get_datasets,
)
from preprocessing.pipelines import WaveformTrainingPipeline

from .utils import get_id_label_mapping, compute_hf_metrics

MODEL_CHECKPOINT = "facebook/wav2vec2-base"


class Wav2VecFeatureExtractor:
    def __init__(self) -> None:
        self.waveform_pipeline = WaveformTrainingPipeline()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            MODEL_CHECKPOINT,
        )

    def __call__(self, waveform) -> Any:
        waveform = self.waveform_pipeline(waveform)
        return self.feature_extractor(
            waveform, sampling_rate=self.feature_extractor.sampling_rate
        )

    def __getattr__(self, attr):
        return getattr(self.feature_extractor, attr)


def train_wav_model(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    OUTPUT_DIR = "models/weights/wav2vec2"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["trainer"]["min_epochs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    pl.seed_everything(SEED, workers=True)
    dataset = get_datasets(config["datasets"])
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
    test_proportion = config["data_module"]["test_proportion"]
    train_proporition = 1 - test_proportion
    train_ds, test_ds = random_split(dataset, [train_proporition, test_proportion])
    feature_extractor = Wav2VecFeatureExtractor()
    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(TARGET_CLASSES),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=DEVICE == "mps",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=feature_extractor,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
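The __getattr__ forwarder on Wav2VecFeatureExtractor is what lets the HF Trainer treat the wrapper as an ordinary feature extractor: anything the wrapper does not define itself (sampling_rate, save_pretrained, and so on) falls through to the wrapped extractor. A generic illustration of that delegation pattern, not taken from the repository:

# __getattr__ is only consulted when normal attribute lookup fails, so attributes set
# in __init__ still win and everything else is forwarded to the wrapped object.
class Delegating:
    def __init__(self, inner):
        self.inner = inner

    def __getattr__(self, attr):
        return getattr(self.inner, attr)


wrapped = Delegating(inner="abc")
print(wrapped.upper())  # "ABC" — resolved on the wrapped object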
preprocessing/dataset.py
CHANGED
@@ -1,15 +1,21 @@
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
-
from torch.utils.data import Dataset, DataLoader, random_split
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
5 |
import torchaudio as ta
|
6 |
-
from .pipelines import AudioTrainingPipeline
|
7 |
import pytorch_lightning as pl
|
8 |
-
|
9 |
-
from
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
class SongDataset(Dataset):
|
@@ -17,60 +23,67 @@ class SongDataset(Dataset):
|
|
17 |
self,
|
18 |
audio_paths: list[str],
|
19 |
dance_labels: list[np.ndarray],
|
20 |
-
|
21 |
audio_window_duration=6, # seconds
|
22 |
-
audio_window_jitter=
|
23 |
-
audio_pipeline_kwargs={},
|
24 |
-
resample_frequency=16000,
|
25 |
):
|
26 |
-
assert (
|
27 |
-
audio_duration % audio_window_duration == 0
|
28 |
-
), "Audio window should divide duration evenly."
|
29 |
assert (
|
30 |
audio_window_duration > audio_window_jitter
|
31 |
), "Jitter should be a small fraction of the audio window duration."
|
32 |
|
33 |
self.audio_paths = audio_paths
|
34 |
self.dance_labels = dance_labels
|
35 |
-
|
36 |
-
self.
|
|
|
|
|
|
|
37 |
self.audio_window_duration = int(audio_window_duration)
|
|
|
38 |
self.audio_window_jitter = audio_window_jitter
|
39 |
-
self.audio_duration = int(audio_duration)
|
40 |
-
|
41 |
-
self.audio_pipeline = AudioTrainingPipeline(
|
42 |
-
self.sample_rate,
|
43 |
-
resample_frequency,
|
44 |
-
audio_window_duration,
|
45 |
-
**audio_pipeline_kwargs,
|
46 |
-
)
|
47 |
|
48 |
def __len__(self):
|
49 |
-
return
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
spectrogram = self.audio_pipeline(waveform)
|
57 |
|
|
|
58 |
dance_labels = self._label_from_index(idx)
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
-
def
|
70 |
-
|
|
|
71 |
|
72 |
def _backtrace_audio_path(self, index: int) -> str:
|
73 |
-
return self.audio_paths[self.
|
74 |
|
75 |
def _validate_output(self, x, y):
|
76 |
is_finite = not torch.any(torch.isinf(x))
|
@@ -80,16 +93,18 @@ class SongDataset(Dataset):
|
|
80 |
return all((is_finite, is_numerical, has_data, is_binary))
|
81 |
|
82 |
def _waveform_from_index(self, idx: int) -> torch.Tensor:
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
|
87 |
jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
|
88 |
jitter = int(
|
89 |
torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
|
90 |
)
|
91 |
-
frame_offset = (
|
92 |
-
frame_index * self.audio_window_duration * self.sample_rate
|
|
|
|
|
93 |
)
|
94 |
num_frames = self.sample_rate * self.audio_window_duration
|
95 |
waveform, sample_rate = ta.load(
|
@@ -101,41 +116,21 @@ class SongDataset(Dataset):
|
|
101 |
return waveform
|
102 |
|
103 |
def _label_from_index(self, idx: int) -> torch.Tensor:
|
104 |
-
return torch.from_numpy(self.dance_labels[self.
|
105 |
|
106 |
|
107 |
-
class
|
108 |
"""
|
109 |
-
|
110 |
"""
|
111 |
|
112 |
-
def __init__(self, *args,
|
113 |
super().__init__(*args, **kwargs)
|
114 | -         self.
115 | -         self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
116 |           self.pipeline = []
117 |
118 |       def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
119 | -
120 | -         assert (
121 | -             waveform.shape[1] > 10
122 | -         ), f"No data found: {self._backtrace_audio_path(idx)}"
123 | -         # resample the waveform
124 | -         waveform = self.resampler(waveform)
125 | -
126 | -         waveform = waveform.mean(0)
127 | -
128 | -         dance_labels = self._label_from_index(idx)
129 | -         return waveform, dance_labels
130 | -
131 | -
132 | - class HuggingFaceWaveformSongDataset(WaveformSongDataset):
133 | -     def __init__(self, *args, **kwargs):
134 | -         super().__init__(*args, **kwargs)
135 | -         self.pipeline = []
136 | -
137 | -     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
138 | -         x, y = super().__getitem__(idx)
139 |           if len(self.pipeline) > 0:
140 |               for fn in self.pipeline:
141 |                   x = fn(x)
@@ -146,59 +141,158 @@ class HuggingFaceWaveformSongDataset(WaveformSongDataset):
146 |               "label": dance_labels,
147 |           }
148 |
149 | -     def
150 |           """
151 | -
152 |           """
153 |           self.pipeline.append(fn)
154 |
155 |
156 |   class DanceDataModule(pl.LightningDataModule):
157 |       def __init__(
158 |           self,
159 | -
160 | -         song_audio_path="data/samples",
161 |           test_proportion=0.15,
162 |           val_proportion=0.1,
163 |           target_classes: list[str] = None,
164 | -         min_votes=1,
165 |           batch_size: int = 64,
166 |           num_workers=10,
167 | -         dataset_cls=None,
168 | -         dataset_kwargs={},
169 |       ):
170 |           super().__init__()
171 | -         self.song_data_path = song_data_path
172 | -         self.song_audio_path = song_audio_path
173 |           self.val_proportion = val_proportion
174 |           self.test_proportion = test_proportion
175 |           self.train_proportion = 1.0 - test_proportion - val_proportion
176 |           self.target_classes = target_classes
177 |           self.batch_size = batch_size
178 |           self.num_workers = num_workers
179 | -         self.
180 | -         self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset
181 | -
182 | -         df = pd.read_csv(song_data_path)
183 | -         self.x, self.y = get_examples(
184 | -             df,
185 | -             self.song_audio_path,
186 | -             class_list=self.target_classes,
187 | -             multi_label=True,
188 | -             min_votes=min_votes,
189 | -         )
190 |
191 |       def setup(self, stage: str):
192 | -
193 | -
194 |               [self.train_proportion, self.val_proportion, self.test_proportion],
195 |           )
196 | -         self.train_ds = self._dataset_from_indices(train_i)
197 | -         self.val_ds = self._dataset_from_indices(val_i)
198 | -         self.test_ds = self._dataset_from_indices(test_i)
199 | -
200 | -     def _dataset_from_indices(self, idx: list[int]) -> SongDataset:
201 | -         return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)
202 |
203 |       def train_dataloader(self):
204 |           return DataLoader(
@@ -210,110 +304,48 @@ class DanceDataModule(pl.LightningDataModule):
210 |
211 |       def val_dataloader(self):
212 |           return DataLoader(
213 | -             self.val_ds,
214 |           )
215 |
216 |       def test_dataloader(self):
217 |           return DataLoader(
218 | -             self.test_ds,
219 |           )
220 |
221 |       def get_label_weights(self):
222 | -
223 | -
224 |
225 |
226-249 | -  (deleted lines cut off in this rendering)
250 | -         )
251 | -
252 | -     def preprocess_inputs(self, x):
253 | -         device = x.device
254 | -         x = list(x.squeeze(1).cpu().numpy())
255 | -         x = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000)
256 | -         return x["input_values"].to(device)
257 | -
258 | -     def training_step(
259 | -         self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
260 | -     ) -> torch.Tensor:
261 | -         features, labels = batch
262 | -         features = self.preprocess_inputs(features)
263 | -         outputs = self.model(features).logits
264 | -         outputs = nn.Sigmoid()(
265 | -             outputs
266 | -         )  # good for multi label classification, should be softmax otherwise
267 | -         loss = self.criterion(outputs, labels)
268 | -         metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
269 | -         self.log_dict(metrics, prog_bar=True)
270 | -         return loss
271 | -
272 | -     def validation_step(
273 | -         self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int
274 | -     ):
275 | -         x, y = batch
276 | -         x = self.preprocess_inputs(x)
277 | -         preds = self.model(x).logits
278 | -         preds = nn.Sigmoid()(preds)
279 | -         metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
280 | -         metrics["val/loss"] = self.criterion(preds, y)
281 | -         self.log_dict(metrics, prog_bar=True)
282 | -
283 | -     def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
284 | -         x, y = batch
285 | -         x = self.preprocess_inputs(x)
286 | -         preds = self.model(x).logits
287 | -         preds = nn.Sigmoid()(preds)
288 | -         self.log_dict(
289 | -             calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True
290 | -         )
291 | -
292 | -     def configure_optimizers(self):
293 | -         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
294 | -         # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
295 | -         return [optimizer]
296 | -
297 | -
298 | - def calculate_metrics(
299 | -     pred, target, threshold=0.5, prefix="", multi_label=True
300 | - ) -> dict[str, torch.Tensor]:
301 | -     target = target.detach().cpu().numpy()
302 | -     pred = pred.detach().cpu().numpy()
303 | -     params = {
304 | -         "y_true": target if multi_label else target.argmax(1),
305 | -         "y_pred": np.array(pred > threshold, dtype=float)
306 | -         if multi_label
307 | -         else pred.argmax(1),
308 | -         "zero_division": 0,
309 | -         "average": "macro",
310 | -     }
311 | -     metrics = {
312 | -         "precision": precision_score(**params),
313 | -         "recall": recall_score(**params),
314 | -         "f1": f1_score(**params),
315 | -         "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
316 | -     }
317 | -     return {
318 | -         prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
319 | -     }
1  | + import importlib
2  | + import os
3  | + from typing import Any
4  |   import torch
5  | + from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
6  |   import numpy as np
7  |   import pandas as pd
8  |   import torchaudio as ta
9  |   import pytorch_lightning as pl
10 | +
11 | + from preprocessing.preprocess import (
12 | +     fix_dance_rating_counts,
13 | +     get_unique_labels,
14 | +     has_valid_audio,
15 | +     url_to_filename,
16 | +     vectorize_label_probs,
17 | +     vectorize_multi_label,
18 | + )
19 |
20 |
21 |   class SongDataset(Dataset):
22 |       def __init__(
23 |           self,
24 |           audio_paths: list[str],
25 |           dance_labels: list[np.ndarray],
26 | +         audio_start_offset=6,  # seconds
27 |           audio_window_duration=6,  # seconds
28 | +         audio_window_jitter=1.0,  # seconds
29 |       ):
30 |           assert (
31 |               audio_window_duration > audio_window_jitter
32 |           ), "Jitter should be a small fraction of the audio window duration."
33 |
34 |           self.audio_paths = audio_paths
35 |           self.dance_labels = dance_labels
36 | +         audio_metadata = [ta.info(audio) for audio in audio_paths]
37 | +         self.audio_durations = [
38 | +             meta.num_frames / meta.sample_rate for meta in audio_metadata
39 | +         ]
40 | +         self.sample_rate = audio_metadata[0].sample_rate  # assuming same sample rate
41 |           self.audio_window_duration = int(audio_window_duration)
42 | +         self.audio_start_offset = audio_start_offset
43 |           self.audio_window_jitter = audio_window_jitter
44 |
45 |       def __len__(self):
46 | +         return int(
47 | +             sum(
48 | +                 max(duration - self.audio_start_offset, 0) // self.audio_window_duration
49 | +                 for duration in self.audio_durations
50 | +             )
51 | +         )
52 |
53 |       def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
54 | +         if isinstance(idx, list):
55 | +             return [
56 | +                 (self._waveform_from_index(i), self._label_from_index(i)) for i in idx
57 | +             ]
58 |
59 | +         waveform = self._waveform_from_index(idx)
60 |           dance_labels = self._label_from_index(idx)
61 | +         return waveform, dance_labels
62 |
63 | +     def _idx2audio_idx(self, idx: int) -> int:
64 | +         return self._get_audio_loc_from_idx(idx)[0]
65 | +
66 | +     def _get_audio_loc_from_idx(self, idx: int) -> tuple[int, int]:
67 | +         """
68 | +         Converts dataset index to the indices that reference the target audio path
69 | +         and window offset.
70 | +         """
71 | +         total_slices = 0
72 | +         for audio_index, duration in enumerate(self.audio_durations):
73 | +             audio_slices = max(
74 | +                 (duration - self.audio_start_offset) // self.audio_window_duration, 1
75 | +             )
76 | +             if total_slices + audio_slices > idx:
77 | +                 frame_index = idx - total_slices
78 | +                 return audio_index, frame_index
79 | +             total_slices += audio_slices
80 |
81 | +     def get_label_weights(self):
82 | +         n_examples, n_classes = self.dance_labels.shape
83 | +         return torch.from_numpy(n_examples / (n_classes * sum(self.dance_labels)))
84 |
85 |       def _backtrace_audio_path(self, index: int) -> str:
86 | +         return self.audio_paths[self._idx2audio_idx(index)]
87 |
88 |       def _validate_output(self, x, y):
89 |           is_finite = not torch.any(torch.isinf(x))
93 |           return all((is_finite, is_numerical, has_data, is_binary))
94 |
95 |       def _waveform_from_index(self, idx: int) -> torch.Tensor:
96 | +         audio_index, frame_index = self._get_audio_loc_from_idx(idx)
97 | +         audio_filepath = self.audio_paths[audio_index]
98 | +         num_windows = self.audio_durations[audio_index] // self.audio_window_duration
99 |           jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
100 |          jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
101 |          jitter = int(
102 |              torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
103 |          )
104 | +        frame_offset = int(
105 | +            frame_index * self.audio_window_duration * self.sample_rate
106 | +            + jitter
107 | +            + self.audio_start_offset * self.sample_rate
108 |          )
109 |          num_frames = self.sample_rate * self.audio_window_duration
110 |          waveform, sample_rate = ta.load(
116 |          return waveform
117 |
118 |      def _label_from_index(self, idx: int) -> torch.Tensor:
119 | +        return torch.from_numpy(self.dance_labels[self._idx2audio_idx(idx)])
120 |
121 |
122 | + class HuggingFaceDatasetWrapper(Dataset):
123 |      """
124 | +    Makes a standard PyTorch Dataset compatible with a HuggingFace Trainer.
125 |      """
126 |
127 | +    def __init__(self, dataset, *args, **kwargs):
128 |          super().__init__(*args, **kwargs)
129 | +        self.dataset = dataset
130 |          self.pipeline = []
131 |
132 |      def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
133 | +        x, y = self.dataset[idx]
134 |          if len(self.pipeline) > 0:
135 |              for fn in self.pipeline:
136 |                  x = fn(x)
141 |              "label": dance_labels,
142 |          }
143 |
144 | +    def __len__(self):
145 | +        return len(self.dataset)
146 | +
147 | +    def append_to_pipeline(self, fn):
148 |          """
149 | +        Adds a preprocessing step to the dataset.
150 |          """
151 |          self.pipeline.append(fn)
152 |
153 |
154 | + class BestBallroomDataset(Dataset):
155 | +     def __init__(
156 | +         self, audio_dir="data/ballroom-songs", class_list=None, **kwargs
157 | +     ) -> None:
158 | +         super().__init__()
159 | +         song_paths, labels = self.get_examples(audio_dir, class_list)
160 | +         self.song_dataset = SongDataset(song_paths, labels, **kwargs)
161 | +
162 | +     def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
163 | +         return self.song_dataset[index]
164 | +
165 | +     def __len__(self):
166 | +         return len(self.song_dataset)
167 | +
168 | +     def get_examples(self, audio_dir, class_list=None):
169 | +         dances = set(
170 | +             f
171 | +             for f in os.listdir(audio_dir)
172 | +             if os.path.isdir(os.path.join(audio_dir, f))
173 | +         )
174 | +         common_dances = dances
175 | +         if class_list is not None:
176 | +             common_dances = dances & set(class_list)
177 | +             dances = class_list
178 | +         dances = np.array(sorted(dances))
179 | +         song_paths = []
180 | +         labels = []
181 | +         for dance in common_dances:
182 | +             dance_label = (dances == dance).astype("float32")
183 | +             folder_path = os.path.join(audio_dir, dance)
184 | +             folder_contents = [f for f in os.listdir(folder_path) if f.endswith(".wav")]
185 | +             song_paths.extend(os.path.join(folder_path, f) for f in folder_contents)
186 | +             labels.extend([dance_label] * len(folder_contents))
187 | +
188 | +         return np.array(song_paths), np.stack(labels)
189 | +
190 | +
191 | + class Music4DanceDataset(Dataset):
192 | +     def __init__(
193 | +         self,
194 | +         song_data_path,
195 | +         song_audio_path,
196 | +         class_list=None,
197 | +         multi_label=True,
198 | +         min_votes=1,
199 | +         **kwargs,
200 | +     ) -> None:
201 | +         super().__init__()
202 | +         df = pd.read_csv(song_data_path)
203 | +         song_paths, labels = get_music4dance_examples(
204 | +             df,
205 | +             song_audio_path,
206 | +             class_list=class_list,
207 | +             multi_label=multi_label,
208 | +             min_votes=min_votes,
209 | +         )
210 | +         self.song_dataset = SongDataset(song_paths, labels, **kwargs)
211 | +
212 | +     def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
213 | +         return self.song_dataset[index]
214 | +
215 | +     def __len__(self):
216 | +         return len(self.song_dataset)
217 | +
218 | +
219 | + def get_music4dance_examples(
220 | +     df: pd.DataFrame, audio_dir: str, class_list=None, multi_label=True, min_votes=1
221 | + ) -> tuple[np.ndarray, np.ndarray]:
222 | +     sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)].copy(deep=True)
223 | +     sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
224 | +     if class_list is not None:
225 | +         class_list = set(class_list)
226 | +         sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
227 | +             lambda labels: {k: v for k, v in labels.items() if k in class_list}
228 | +             if not pd.isna(labels)
229 | +             and any(label in class_list and amt > 0 for label, amt in labels.items())
230 | +             else np.nan
231 | +         )
232 | +     sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
233 | +     vote_mask = sampled_songs["DanceRating"].apply(
234 | +         lambda dances: any(votes >= min_votes for votes in dances.values())
235 | +     )
236 | +     sampled_songs = sampled_songs[vote_mask]
237 | +     labels = sampled_songs["DanceRating"].apply(
238 | +         lambda dances: {
239 | +             dance: votes for dance, votes in dances.items() if votes >= min_votes
240 | +         }
241 | +     )
242 | +     unique_labels = np.array(get_unique_labels(labels))
243 | +     vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
244 | +     labels = labels.apply(lambda i: vectorizer(i, unique_labels))
245 | +
246 | +     audio_paths = [
247 | +         os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]
248 | +     ]
249 | +
250 | +     return np.array(audio_paths), np.stack(labels)
251 | +
252 | +
253 | + class PipelinedDataset(Dataset):
254 | +     """
255 | +     Adds a feature extractor preprocessing step to a dataset.
256 | +     """
257 | +
258 | +     def __init__(self, dataset, feature_extractor):
259 | +         self._data = dataset
260 | +         self.feature_extractor = feature_extractor
261 | +
262 | +     def __len__(self):
263 | +         return len(self._data)
264 | +
265 | +     def __getitem__(self, index):
266 | +         sample, label = self._data[index]
267 | +
268 | +         features = self.feature_extractor(sample)
269 | +         return features, label
270 | +
271 | +
272 |   class DanceDataModule(pl.LightningDataModule):
273 |       def __init__(
274 |           self,
275 | +         dataset: Dataset,
276 |           test_proportion=0.15,
277 |           val_proportion=0.1,
278 |           target_classes: list[str] = None,
279 |           batch_size: int = 64,
280 |           num_workers=10,
281 |       ):
282 |           super().__init__()
283 |           self.val_proportion = val_proportion
284 |           self.test_proportion = test_proportion
285 |           self.train_proportion = 1.0 - test_proportion - val_proportion
286 |           self.target_classes = target_classes
287 |           self.batch_size = batch_size
288 |           self.num_workers = num_workers
289 | +         self.dataset = dataset
290 |
291 |       def setup(self, stage: str):
292 | +         self.train_ds, self.val_ds, self.test_ds = random_split(
293 | +             self.dataset,
294 |               [self.train_proportion, self.val_proportion, self.test_proportion],
295 |           )
296 |
297 |       def train_dataloader(self):
298 |           return DataLoader(
304 |
305 |       def val_dataloader(self):
306 |           return DataLoader(
307 | +             self.val_ds,
308 | +             batch_size=self.batch_size,
309 | +             num_workers=self.num_workers,
310 |           )
311 |
312 |       def test_dataloader(self):
313 |           return DataLoader(
314 | +             self.test_ds,
315 | +             batch_size=self.batch_size,
316 | +             num_workers=self.num_workers,
317 |           )
318 |
319 |       def get_label_weights(self):
320 | +         weights = [
321 | +             ds.song_dataset.get_label_weights() for ds in self.dataset._data.datasets
322 | +         ]
323 | +         return torch.mean(torch.stack(weights), dim=0)  # TODO: Make this weighted
324 |
325 |
326 | + def find_mean_std(dataset: Dataset, zscore=1.96, moe=0.02, p=0.5):
327 | +     """
328 | +     Estimates the mean and standard deviation of a dataset.
329 | +     """
330 | +     sample_size = int(np.ceil((zscore**2 * p * (1 - p)) / (moe**2)))
331 | +     sample_indices = np.random.choice(
332 | +         np.arange(len(dataset)), size=sample_size, replace=False
333 | +     )
334 | +     mean = 0
335 | +     std = 0
336 | +     for i in sample_indices:
337 | +         features = dataset[i][0]
338 | +         mean += features.mean().item()
339 | +         std += features.std().item()
340 | +     print("std", std / sample_size)
341 | +     print("mean", mean / sample_size)
342 | +
343 | +
344 | + def get_datasets(dataset_config: dict, feature_extractor) -> Dataset:
345 | +     datasets = []
346 | +     for dataset_path, kwargs in dataset_config.items():
347 | +         module_name, class_name = dataset_path.rsplit(".", 1)
348 | +         module = importlib.import_module(module_name)
349 | +         ProvidedDataset = getattr(module, class_name)
350 | +         datasets.append(ProvidedDataset(**kwargs))
351 | +     return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
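For orientation, the sketch below shows one way the pieces introduced above could be wired together: `get_datasets` maps dotted dataset-class paths to constructor kwargs, concatenates the resulting datasets, and wraps them in a `PipelinedDataset`, which `DanceDataModule` then splits and serves. This is not part of the commit; the specific paths, keyword values, and the choice of `SpectrogramTrainingPipeline` as the feature extractor are illustrative assumptions.

```python
# Hypothetical wiring of the refactored dataset classes (paths and kwargs are assumptions).
from preprocessing.dataset import DanceDataModule, get_datasets
from preprocessing.pipelines import SpectrogramTrainingPipeline

dataset_config = {
    "preprocessing.dataset.BestBallroomDataset": {"audio_dir": "data/ballroom-songs"},
    "preprocessing.dataset.Music4DanceDataset": {
        "song_data_path": "data/songs_cleaned.csv",  # assumed CSV location
        "song_audio_path": "data/samples",
    },
}

# Any callable that maps a waveform tensor to model features works as the extractor.
dataset = get_datasets(dataset_config, SpectrogramTrainingPipeline())
data = DanceDataModule(dataset, batch_size=32, num_workers=4)
```

Note that `DanceDataModule.get_label_weights` reaches into `self.dataset._data.datasets`, so it currently assumes exactly this `PipelinedDataset(ConcatDataset(...))` shape.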
preprocessing/pipelines.py CHANGED
@@ -3,29 +3,26 @@ import torchaudio
3  |   from torchaudio import transforms as taT, functional as taF
4  |   import torch.nn as nn
5  |
6  | -
7  | -
8  | -
9  | -
10 | -
11 | -
12 | -
13 | -
14 | -
15 | -
16 |           super().__init__()
17 |           self.input_freq = input_freq
18 |           self.snr_mean = snr_mean
19 | -         self.mask_count = mask_count
20 |           self.noise = self.get_noise(noise_path)
21 | -         self.
22 | -         self.
23 | -
24 | -
25 |           )
26 | -         self.freq_mask = taT.FrequencyMasking(freq_mask_size)
27 | -         self.time_mask = taT.TimeMasking(time_mask_size)
28 | -
29 |
30 |       def get_noise(self, path) -> torch.Tensor:
31 |           if path is None:
@@ -34,13 +31,15 @@ class AudioTrainingPipeline(torch.nn.Module):
34 |           if noise.shape[0] > 1:
35 |               noise = noise.mean(0, keepdim=True)
36 |           if sr != self.input_freq:
37 | -             noise = taF.resample(noise,sr, self.input_freq)
38 |           return noise
39 |
40 | -     def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
41 | -         assert
42 |           num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
43 | -         noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
44 |           noise_power = noise.norm(p=2)
45 |           signal_power = waveform.norm(p=2)
46 |           snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
@@ -49,14 +48,28 @@ class AudioTrainingPipeline(torch.nn.Module):
49 |           noisy_waveform = (scale * waveform + noise) / 2
50 |           return noisy_waveform
51 |
52 | -     def forward(self, waveform:torch.Tensor) -> torch.Tensor:
53 | -
54 | -             waveform = self.resample(waveform)
55 | -         except:
56 | -             print("oops")
57 |           waveform = self.preprocess_waveform(waveform)
58 |           if self.noise is not None:
59 |               waveform = self.add_noise(waveform)
60 |           spec = self.audio_to_spectrogram(waveform)
61 |
62 |           # Spectrogram augmentation
@@ -67,14 +80,11 @@ class AudioTrainingPipeline(torch.nn.Module):
67 |
68 |
69 |   class WaveformPreprocessing(torch.nn.Module):
70 | -
71 | -     def __init__(self, expected_sample_length:int):
72 |           super().__init__()
73 |           self.expected_sample_length = expected_sample_length
74 | -
75 |
76 | -
77 | -     def forward(self, waveform:torch.Tensor) -> torch.Tensor:
78 |           # Take out extra channels
79 |           if waveform.shape[0] > 1:
80 |               waveform = waveform.mean(0, keepdim=True)
@@ -83,30 +93,34 @@ class WaveformPreprocessing(torch.nn.Module):
83 |           waveform = self._rectify_duration(waveform)
84 |           return waveform
85 |
86 | -
87 | -     def _rectify_duration(self,waveform:torch.Tensor):
88 |           expected_samples = self.expected_sample_length
89 |           sample_count = waveform.shape[1]
90 |           if expected_samples == sample_count:
91 |               return waveform
92 |           elif expected_samples > sample_count:
93 |               pad_amount = expected_samples - sample_count
94 | -             return torch.nn.functional.pad(
95 |           else:
96 | -             return waveform[
97 |
98 |
99 | - class AudioToSpectrogram
100 |      def __init__(
101 |          self,
102 |          sample_rate=16000,
103 |      ):
104 | -
105 | -
106 | -
107 |          self.to_db = taT.AmplitudeToDB()
108 |
109 | -    def
110 |          spectrogram = self.spec(waveform)
111 |          spectrogram = self.to_db(spectrogram)
112 | -

3  |   from torchaudio import transforms as taT, functional as taF
4  |   import torch.nn as nn
5  |
6  | +
7  | + class WaveformTrainingPipeline(torch.nn.Module):
8  | +     def __init__(
9  | +         self,
10 | +         input_freq=16000,
11 | +         resample_freq=16000,
12 | +         expected_duration=6,
13 | +         snr_mean=6.0,
14 | +         noise_path=None,
15 | +     ):
16 |           super().__init__()
17 |           self.input_freq = input_freq
18 |           self.snr_mean = snr_mean
19 |           self.noise = self.get_noise(noise_path)
20 | +         self.resample_frequency = resample_freq
21 | +         self.resample = taT.Resample(input_freq, resample_freq)
22 | +
23 | +         self.preprocess_waveform = WaveformPreprocessing(
24 | +             resample_freq * expected_duration
25 |           )
26 |
27 |       def get_noise(self, path) -> torch.Tensor:
28 |           if path is None:
31 |           if noise.shape[0] > 1:
32 |               noise = noise.mean(0, keepdim=True)
33 |           if sr != self.input_freq:
34 | +             noise = taF.resample(noise, sr, self.input_freq)
35 |           return noise
36 |
37 | +     def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
38 | +         assert (
39 | +             self.noise is not None
40 | +         ), "Cannot add noise because a noise file was not provided."
41 |           num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
42 | +         noise = self.noise.repeat(1, num_repeats)[:, : waveform.shape[1]]
43 |           noise_power = noise.norm(p=2)
44 |           signal_power = waveform.norm(p=2)
45 |           snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
48 |           noisy_waveform = (scale * waveform + noise) / 2
49 |           return noisy_waveform
50 |
51 | +     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
52 | +         waveform = self.resample(waveform)
53 |           waveform = self.preprocess_waveform(waveform)
54 |           if self.noise is not None:
55 |               waveform = self.add_noise(waveform)
56 | +         return waveform
57 | +
58 | +
59 | + class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
60 | +     def __init__(
61 | +         self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
62 | +     ):
63 | +         super().__init__(*args, **kwargs)
64 | +         self.mask_count = mask_count
65 | +         self.audio_to_spectrogram = AudioToSpectrogram(
66 | +             sample_rate=self.resample_frequency,
67 | +         )
68 | +         self.freq_mask = taT.FrequencyMasking(freq_mask_size)
69 | +         self.time_mask = taT.TimeMasking(time_mask_size)
70 | +
71 | +     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
72 | +         waveform = super().forward(waveform)
73 |           spec = self.audio_to_spectrogram(waveform)
74 |
75 |           # Spectrogram augmentation
80 |
81 |
82 |   class WaveformPreprocessing(torch.nn.Module):
83 | +     def __init__(self, expected_sample_length: int):
84 |           super().__init__()
85 |           self.expected_sample_length = expected_sample_length
86 |
87 | +     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
88 |           # Take out extra channels
89 |           if waveform.shape[0] > 1:
90 |               waveform = waveform.mean(0, keepdim=True)
93 |           waveform = self._rectify_duration(waveform)
94 |           return waveform
95 |
96 | +     def _rectify_duration(self, waveform: torch.Tensor):
97 |           expected_samples = self.expected_sample_length
98 |           sample_count = waveform.shape[1]
99 |           if expected_samples == sample_count:
100 |              return waveform
101 |          elif expected_samples > sample_count:
102 |              pad_amount = expected_samples - sample_count
103 | +            return torch.nn.functional.pad(
104 | +                waveform, (0, pad_amount), mode="constant", value=0.0
105 | +            )
106 |          else:
107 | +            return waveform[:, :expected_samples]
108 |
109 |
110 | + class AudioToSpectrogram:
111 |      def __init__(
112 |          self,
113 |          sample_rate=16000,
114 |      ):
115 | +        self.spec = taT.MelSpectrogram(
116 | +            sample_rate=sample_rate, n_mels=128, n_fft=1024
117 | +        )  # Note: this doesn't work on mps right now.
118 |          self.to_db = taT.AmplitudeToDB()
119 |
120 | +    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
121 |          spectrogram = self.spec(waveform)
122 |          spectrogram = self.to_db(spectrogram)
123 | +
124 | +        # Normalize
125 | +        spectrogram = (spectrogram - spectrogram.mean()) / (2 * spectrogram.std())
126 | +        return spectrogram
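As a quick sanity check of the refactored pipelines, here is a minimal sketch of pushing a fake waveform through them, assuming the default 16 kHz / 6-second settings shown above; the shapes in the comments follow from those defaults.

```python
import torch
from preprocessing.pipelines import WaveformTrainingPipeline, SpectrogramTrainingPipeline

waveform = torch.randn(1, 16000 * 6)  # 6 s of fake mono audio at the default input_freq

wave_pipeline = WaveformTrainingPipeline()     # resample + trim/pad (+ noise if a path is given)
spec_pipeline = SpectrogramTrainingPipeline()  # same, then mel spectrogram + masking

clean = wave_pipeline(waveform)   # -> (1, 96000)
spec = spec_pipeline(waveform)    # -> (1, 128, time_frames), dB-scaled and normalized
assert spec.dim() == 3            # mirrors tests/test_pipelines.py
```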
preprocessing/preprocess.py CHANGED
@@ -3,7 +3,9 @@ import numpy as np
3  |   import re
4  |   import json
5  |   from pathlib import Path
6  |   import os
7  |   import torchaudio
8  |   import torch
9  |   from tqdm import tqdm
@@ -95,7 +97,6 @@ def vectorize_label_probs(
95 |       for k, v in labels.items():
96 |           item_vec = (unique_labels == k) * v
97 |           label_vec += item_vec
98 | -         lv_cache = label_vec.copy()
99 |       label_vec[label_vec < 0] = 0
100 |      label_vec /= label_vec.sum()
101 |      assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
@@ -113,49 +114,70 @@ def vectorize_multi_label(
113 |      return probs
114 |
115 |
116 | - def
117-147 | -  (deleted lines cut off in this rendering)
148 |
149 |
150 |  if __name__ == "__main__":
151 | -
152 | -    df = pd.read_csv("data/songs.csv")
153 | -    l = links["link"].str.strip()
154 | -    l = l.apply(lambda url: url if "http" in url else np.nan)
155 | -    l = l.dropna()
156 | -    df["Sample"].update(l)
157 | -    addna = lambda url: url if type(url) == str and "http" in url else np.nan
158 | -    df["Sample"] = df["Sample"].apply(addna)
159 | -    is_valid = validate_audio(df["Sample"], "data/samples")
160 | -    df["valid"] = is_valid
161 | -    df.to_csv("data/songs_validated.csv")

3  |   import re
4  |   import json
5  |   from pathlib import Path
6  | + import glob
7  |   import os
8  | + import shutil
9  |   import torchaudio
10 |   import torch
11 |   from tqdm import tqdm
97 |       for k, v in labels.items():
98 |           item_vec = (unique_labels == k) * v
99 |           label_vec += item_vec
100 |      label_vec[label_vec < 0] = 0
101 |      label_vec /= label_vec.sum()
102 |      assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
114 |      return probs
115 |
116 |
117 | + def sort_yt_files(
118 | +     aliases_path="data/dance_aliases.json",
119 | +     all_dances_folder="data/best-ballroom-music",
120 | +     original_location="data/yt-ballroom-music/",
121 | + ):
122 | +     def normalize_string(s):
123 | +         # Lowercase string and remove special characters
124 | +         return re.sub(r"\W+", "", s.lower())
125 | +
126 | +     with open(aliases_path, "r") as f:
127 | +         dances = json.load(f)
128 | +
129 | +     # Normalize the dance inputs and aliases
130 | +     normalized_dances = {
131 | +         normalize_string(dance_id): [normalize_string(alias) for alias in aliases]
132 | +         for dance_id, aliases in dances.items()
133 | +     }
134 | +
135 | +     # For every wav file in the target folder
136 | +     bad_files = []
137 | +     progress_bar = tqdm(os.listdir(all_dances_folder), unit="files moved")
138 | +     for file_name in progress_bar:
139 | +         if file_name.endswith(".wav"):
140 | +             # check if the normalized wav file name contains the normalized dance alias
141 | +             normalized_file_name = normalize_string(file_name)
142 | +
143 | +             matching_dance_ids = [
144 | +                 dance_id
145 | +                 for dance_id, aliases in normalized_dances.items()
146 | +                 if any(alias in normalized_file_name for alias in aliases)
147 | +             ]
148 | +
149 | +             if len(matching_dance_ids) == 0:
150 | +                 # See if the dance is in the path
151 | +                 original_filename = file_name.replace(".wav", "")
152 | +                 matches = glob.glob(
153 | +                     os.path.join(original_location, "**", original_filename),
154 | +                     recursive=True,
155 | +                 )
156 | +                 if len(matches) == 1:
157 | +                     normalized_file_name = normalize_string(matches[0])
158 | +                     matching_dance_ids = [
159 | +                         dance_id
160 | +                         for dance_id, aliases in normalized_dances.items()
161 | +                         if any(alias in normalized_file_name for alias in aliases)
162 | +                     ]
163 | +
164 | +             if "swz" in matching_dance_ids and "vwz" in matching_dance_ids:
165 | +                 matching_dance_ids.remove("swz")
166 | +             if len(matching_dance_ids) > 1 and "lhp" in matching_dance_ids:
167 | +                 matching_dance_ids.remove("lhp")
168 | +
169 | +             if len(matching_dance_ids) != 1:
170 | +                 bad_files.append(file_name)
171 | +                 progress_bar.set_description(f"bad files: {len(bad_files)}")
172 | +                 continue
173 | +             dst = os.path.join("data", "ballroom-songs", matching_dance_ids[0].upper())
174 | +             os.makedirs(dst, exist_ok=True)
175 | +             filepath = os.path.join(all_dances_folder, file_name)
176 | +             shutil.copy(filepath, os.path.join(dst, file_name))
177 | +
178 | +     with open("data/bad_files.json", "w") as f:
179 | +         json.dump(bad_files, f)
180 |
181 |
182 |  if __name__ == "__main__":
183 | +     sort_yt_files()
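`sort_yt_files` assumes `data/dance_aliases.json` maps each dance id to the spellings that may appear in a downloaded file name; the ids `swz`, `vwz`, and `lhp` are the ones referenced by the tie-breaking logic above. A sketch of that structure, in the form it takes after `json.load`, follows; the alias spellings are guesses, not the repository's actual file.

```python
# Illustrative only: the dance ids come from the tie-breaking code above,
# but the alias values are assumptions.
dances = {
    "swz": ["slow waltz", "english waltz", "waltz"],
    "vwz": ["viennese waltz"],
    "lhp": ["lindy hop", "lindy"],
}
```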
tests.py DELETED
@@ -1,22 +0,0 @@
1  | - import torchaudio
2  | - import numpy as np
3  | - from audio_utils import play_audio
4  | - from preprocessing.dataset import SongDataset
5  | -
6  | - def test_audio_splitting():
7  | -
8  | -
9  | -
10 | -     audio_paths = ["data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav"]
11 | -     labels = [np.array([1,0,1,0])]
12 | -     whole_song, sr = torchaudio.load("data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav")
13 | -
14 | -     ds = SongDataset(audio_paths, labels)
15 | -     song_parts = (ds._waveform_from_index(i) for i in range(len(ds)))
16 | -     print("Sample Parts")
17 | -     for part in song_parts:
18 | -         play_audio(part,sr)
19 | -
20 | -
21 | -     print("Whole Sample")
22 | -     play_audio(whole_song,sr)
tests/test_datasets.py ADDED
@@ -0,0 +1,17 @@
1  | + from utils import set_path
2  | + import pytest
3  | +
4  | + set_path()
5  | + from preprocessing.dataset import PipelinedDataset, BestBallroomDataset, SongDataset
6  | + import numpy as np
7  | +
8  | +
9  | + def test_preprocess_dataset():
10 | +     dataset = BestBallroomDataset()
11 | +     dataset = PipelinedDataset(dataset, lambda x: x * 0.0)
12 | +     assert isinstance(dataset._data.song_dataset, SongDataset)
13 | +     assert hasattr(dataset, "feature_extractor")
14 | +     features, _ = dataset[0]
15 | +     assert np.unique(features.numpy())[0] == 0.0
16 | +     with pytest.raises(AttributeError):
17 | +         dataset.foo
tests/test_pipelines.py ADDED
@@ -0,0 +1,13 @@
1  | + from utils import set_path
2  | +
3  | + set_path()
4  | + from preprocessing.dataset import BestBallroomDataset
5  | + from preprocessing.pipelines import SpectrogramTrainingPipeline
6  | +
7  | +
8  | + def test_spectrogram_training_pipeline():
9  | +     ds = BestBallroomDataset()
10 | +     pipeline = SpectrogramTrainingPipeline()
11 | +     waveform, _ = ds[0]
12 | +     out = pipeline(waveform)
13 | +     assert len(out.shape) == 3
tests/utils.py ADDED
@@ -0,0 +1,7 @@
1 | + import sys
2 | + import os
3 | +
4 | +
5 | + # Add parent directory to Python path
6 | + def set_path():
7 | +     sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
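The three new test modules call `set_path()` before importing project code, so they can be run with pytest from the repository root. A minimal invocation sketch follows; it assumes the audio referenced by `BestBallroomDataset`'s default `audio_dir` (`data/ballroom-songs`) is present, since both tests build real datasets rather than mocks.

```python
# Run the new test suite from the repository root (assumes data/ballroom-songs exists).
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pytest", "tests", "-q"], check=True)
```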
train.py CHANGED
@@ -1,49 +1,16 @@
1  | - from torch.utils.data import DataLoader
2  | - import pandas as pd
3  |   from typing import Callable
4  | -
5  | - from torch.utils.data import SubsetRandomSampler
6  | - from sklearn.model_selection import KFold
7  | - import pytorch_lightning as pl
8  | - from pytorch_lightning import callbacks as cb
9  | - from models.utils import LabelWeightedBCELoss
10 | - from models.audio_spectrogram_transformer import (
11 | -     train as train_audio_spectrogram_transformer,
12 | -     get_id_label_mapping,
13 | - )
14 | - from preprocessing.dataset import SongDataset, WaveformTrainingEnvironment
15 | - from preprocessing.preprocess import get_examples
16 | - from models.residual import ResidualDancer, TrainingEnvironment
17 | - from models.decision_tree import DanceTreeClassifier, features_from_path
18 |   import yaml
19 | - from preprocessing.dataset import (
20 | -     DanceDataModule,
21 | -     WaveformSongDataset,
22 | -     HuggingFaceWaveformSongDataset,
23 | - )
24 | - from torch.utils.data import random_split
25 | - import numpy as np
26 | - from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
27 |   from argparse import ArgumentParser
28 |
29 | -
30 | - import torch
31 | - from torch import nn
32 | - from sklearn.utils.class_weight import compute_class_weight
33 |
34 |
35 |   def get_training_fn(id: str) -> Callable:
36 | -
37 | -
38 | -
39 | -     case "ast_hf":
40 | -         return train_ast
41 | -     case "residual_dancer":
42 | -         return train_model
43 | -     case "decision_tree":
44 | -         return train_decision_tree
45 | -     case _:
46 | -         raise Exception(f"Couldn't find a training function for '{id}'.")
47 |
48 |
49 |   def get_config(filepath: str) -> dict:
@@ -52,141 +19,6 @@ def get_config(filepath: str) -> dict:
52 |       return config
53 |
54 |
55 | - def cross_validation(config, k=5):
56 | -     df = pd.read_csv("data/songs.csv")
57 | -     g_config = config["global"]
58 | -     batch_size = config["data_module"]["batch_size"]
59 | -     x, y = get_examples(df, "data/samples", class_list=g_config["dance_ids"])
60 | -     dataset = SongDataset(x, y)
61 | -     splits = KFold(n_splits=k, shuffle=True, random_state=g_config["seed"])
62 | -     trainer = pl.Trainer(accelerator=g_config["device"])
63 | -     for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
64 | -         print(f"Fold {fold+1}")
65 | -         model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
66 | -         train_env = TrainingEnvironment(model, nn.BCELoss())
67 | -         train_sampler = SubsetRandomSampler(train_idx)
68 | -         test_sampler = SubsetRandomSampler(val_idx)
69 | -         train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
70 | -         test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
71 | -         trainer.fit(train_env, train_loader)
72 | -         trainer.test(train_env, test_loader)
73 | -
74 | -
75 | - def train_model(config: dict):
76 | -     TARGET_CLASSES = config["global"]["dance_ids"]
77 | -     DEVICE = config["global"]["device"]
78 | -     SEED = config["global"]["seed"]
79 | -     pl.seed_everything(SEED, workers=True)
80 | -     data = DanceDataModule(target_classes=TARGET_CLASSES, **config["data_module"])
81 | -     model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
82 | -     label_weights = data.get_label_weights().to(DEVICE)
83 | -     criterion = LabelWeightedBCELoss(
84 | -         label_weights
85 | -     )  # nn.CrossEntropyLoss(label_weights)
86 | -     train_env = TrainingEnvironment(model, criterion, config)
87 | -     callbacks = [
88 | -         # cb.LearningRateFinder(update_attr=True),
89 | -         cb.EarlyStopping("val/loss", patience=5),
90 | -         cb.StochasticWeightAveraging(1e-2),
91 | -         cb.RichProgressBar(),
92 | -         cb.DeviceStatsMonitor(),
93 | -     ]
94 | -     trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
95 | -     trainer.fit(train_env, datamodule=data)
96 | -     trainer.test(train_env, datamodule=data)
97 | -
98 | -
99 | - def train_ast(config: dict):
100 | -     TARGET_CLASSES = config["global"]["dance_ids"]
101 | -     DEVICE = config["global"]["device"]
102 | -     SEED = config["global"]["seed"]
103 | -     dataset_kwargs = config["data_module"]["dataset_kwargs"]
104 | -     test_proportion = config["data_module"].get("test_proportion", 0.2)
105 | -     train_proportion = 1.0 - test_proportion
106 | -     song_data_path = "data/songs_cleaned.csv"
107 | -     song_audio_path = "data/samples"
108 | -     pl.seed_everything(SEED, workers=True)
109 | -
110 | -     df = pd.read_csv(song_data_path)
111 | -     x, y = get_examples(
112 | -         df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
113 | -     )
114 | -     train_i, test_i = random_split(
115 | -         np.arange(len(x)), [train_proportion, test_proportion]
116 | -     )
117 | -     train_ds = HuggingFaceWaveformSongDataset(
118 | -         x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000
119 | -     )
120 | -     test_ds = HuggingFaceWaveformSongDataset(
121 | -         x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000
122 | -     )
123 | -     train_audio_spectrogram_transformer(
124 | -         TARGET_CLASSES, train_ds, test_ds, device=DEVICE
125 | -     )
126 | -
127 | -
128 | - def train_ast_lightning(config: dict):
129 | -     """
130 | -     work on integration between waveform dataset and environment. Should work for both HF and PTL.
131 | -     """
132 | -     TARGET_CLASSES = config["global"]["dance_ids"]
133 | -     DEVICE = config["global"]["device"]
134 | -     SEED = config["global"]["seed"]
135 | -     pl.seed_everything(SEED, workers=True)
136 | -     data = DanceDataModule(
137 | -         target_classes=TARGET_CLASSES,
138 | -         dataset_cls=WaveformSongDataset,
139 | -         **config["data_module"],
140 | -     )
141 | -     id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
142 | -     model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
143 | -     feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
144 | -
145 | -     model = AutoModelForAudioClassification.from_pretrained(
146 | -         model_checkpoint,
147 | -         num_labels=len(label2id),
148 | -         label2id=label2id,
149 | -         id2label=id2label,
150 | -         ignore_mismatched_sizes=True,
151 | -     ).to(DEVICE)
152 | -     label_weights = data.get_label_weights().to(DEVICE)
153 | -     criterion = LabelWeightedBCELoss(
154 | -         label_weights
155 | -     )  # nn.CrossEntropyLoss(label_weights)
156 | -     train_env = WaveformTrainingEnvironment(model, criterion, feature_extractor, config)
157 | -     callbacks = [
158 | -         # cb.LearningRateFinder(update_attr=True),
159 | -         cb.EarlyStopping("val/loss", patience=5),
160 | -         cb.StochasticWeightAveraging(1e-2),
161 | -         cb.RichProgressBar(),
162 | -     ]
163 | -     trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
164 | -     trainer.fit(train_env, datamodule=data)
165 | -     trainer.test(train_env, datamodule=data)
166 | -
167 | -
168 | - def train_decision_tree(config: dict):
169 | -     TARGET_CLASSES = config["global"]["dance_ids"]
170 | -     DEVICE = config["global"]["device"]
171 | -     SEED = config["global"]["seed"]
172 | -     song_data_path = config["data_module"]["song_data_path"]
173 | -     song_audio_path = config["data_module"]["song_audio_path"]
174 | -     pl.seed_everything(SEED, workers=True)
175 | -
176 | -     df = pd.read_csv(song_data_path)
177 | -     x, y = get_examples(
178 | -         df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
179 | -     )
180 | -     # Convert y back to string classes
181 | -     y = np.array(TARGET_CLASSES)[y.argmax(-1)]
182 | -     train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
183 | -     train_paths, train_y = x[train_i], y[train_i]
184 | -     train_x = features_from_path(train_paths)
185 | -     model = DanceTreeClassifier(device=DEVICE)
186 | -     model.fit(train_x, train_y)
187 | -     model.save()
188 | -
189 | -
190 |   if __name__ == "__main__":
191 |       parser = ArgumentParser(
192 |           description="Trains models on the dance dataset and saves weights."
@@ -198,6 +30,7 @@ if __name__ == "__main__":
198 |       )
199 |       args = parser.parse_args()
200 |       config = get_config(args.config)
201 | -
202 | -
203 |       train(config)

1  |   from typing import Callable
2  | + import importlib
3  |   import yaml
4  |   from argparse import ArgumentParser
5  | + import os
6  |
7  | + ROOT_DIR = os.path.basename(os.path.dirname(__file__))
8  |
9  |
10 |   def get_training_fn(id: str) -> Callable:
11 | +     module_name, fn_name = id.rsplit(".", 1)
12 | +     module = importlib.import_module("models." + module_name, ROOT_DIR)
13 | +     return getattr(module, fn_name)
14 |
15 |
16 |   def get_config(filepath: str) -> dict:
19 |       return config
20 |
21 |
22 |   if __name__ == "__main__":
23 |       parser = ArgumentParser(
24 |           description="Trains models on the dance dataset and saves weights."
30 |       )
31 |       args = parser.parse_args()
32 |       config = get_config(args.config)
33 | +     training_fn_path = config["training_fn"]
34 | +     print(f"Config: {args.config}\nTrainer Id: {training_fn_path}")
35 | +     train = get_training_fn(training_fn_path)
36 |       train(config)
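With this refactor the trainer is resolved dynamically from the config rather than through the old hard-coded match statement: `get_training_fn` splits the `training_fn` value on its last dot, imports `models.<module>`, and returns the named attribute. The sketch below walks through that lookup for a hypothetical config value; `residual.train_residual_dancer` is an illustrative name, not necessarily what the shipped configs use.

```python
# Hypothetical resolution of a config entry like:  training_fn: residual.train_residual_dancer
import importlib

training_fn_path = "residual.train_residual_dancer"  # assumed example value from a train YAML
module_name, fn_name = training_fn_path.rsplit(".", 1)
module = importlib.import_module("models." + module_name)  # -> models.residual
train = getattr(module, fn_name)
# train(config) is then invoked with the parsed YAML dict, as in train.py's __main__ block.
```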