waidhoferj committed
Commit 0030bc6 • 1 Parent(s): 4b8361a
.gitattributes CHANGED
@@ -1,2 +1,3 @@
  *.wav filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,15 +1,10 @@
  __pycache__
  .DS_Store
- data/samples
- data/spotify-samples
- data/samples-backup.zip
- data/songs.csv
- data/songs_original.csv
+ data
  logs
  gradio_cached_examples
  explore.ipynb
  scrapers/auth
  lightning_logs
- data/backup_1.csv
- data/backup.csv
- data/*.zip
+ .lr_find_*
+ .cache
README.md CHANGED
@@ -11,3 +11,11 @@ pinned: false
  ---

  # Dance Classifier
+
+ Classifies the dance style that best accompanies a provided song. Users record or upload an audio clip, and the model returns a list of matching dance styles.
+
+ ## Getting Started
+
+ 1. Install dependencies: `conda env create --file environment.yml`
+ 2. Activate the environment: `conda activate dancer-net`
+ 3. Start the demo application: `python app.py`
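For readers who want to drive the model without the Gradio app, here is a minimal sketch of programmatic use. It assumes the shipped checkpoint and `models/config/dance-predictor.yaml` are present, and that `DancePredictor` accepts the config keys as constructor arguments (an assumption, not confirmed by this commit):

```python
import yaml
import numpy as np
import torchaudio
from models.residual import DancePredictor

# Assumption: the constructor takes the same keys that appear in the YAML config.
with open("models/config/dance-predictor.yaml") as f:
    config = yaml.safe_load(f)
predictor = DancePredictor(**config)

# __call__ casts its input to int16, so scale float audio into int16 range first.
waveform, sample_rate = torchaudio.load("assets/song-samples/the_long_day_is_over.wav")
pcm = (waveform.numpy() * np.iinfo(np.int16).max).astype(np.int16)

scores = predictor(pcm, sample_rate)  # dict mapping dance styles to confidences
print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:3])
```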
TODO.md ADDED
@@ -0,0 +1,10 @@
+ - ✅ Ensure app.py audio input sounds like training data
+ - Verify that the training spectrogram matches the predict spectrogram
+ - Count number of example misses in dataset loading
+ - Verify windowing and jitter params in Song Dataset
+ - Create an attention-based network
+ - ✅ Increase parameter count in network
+ - Verify that labels really match what is on the music4dance site
+ - Read the Medium series about audio DL
+ - double check \_rectify_duration
+ - Filter out songs that have only one vote
assets/song-samples/take_it_to_the_limit.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c69e0eeb4321c44daaaaf95dd596b1d813b9f7e9b5ef4ac5ae9fe11878d4b13b
- size 5292082
assets/song-samples/{alejandro.wav → the_long_day_is_over.wav} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:85f9a65fc4adb1fc0cbdbfafb7f7268a0934d97a120110d3f3a43375e59cba54
- size 5292078
+ oid sha256:c8f957921bbd5c322f67748aca228dd7ebf9af005692c57d1050299861883214
+ size 5290062
audio_utils.py ADDED
@@ -0,0 +1,42 @@
+ import librosa
+ from IPython.display import Audio, display
+ import matplotlib.pyplot as plt
+ import torch
+ SAMPLE_RIR_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav"
+
+ SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
+
+ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
+     spec = spec.squeeze(0)
+     spec = spec.numpy()
+     fig, axs = plt.subplots(1, 1)
+     axs.set_title(title or "Spectrogram (db)")
+     axs.set_ylabel(ylabel)
+     axs.set_xlabel("frame")
+     im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
+     if xmax:
+         axs.set_xlim((0, xmax))
+     fig.colorbar(im, ax=axs)
+     plt.show(block=False)
+
+ def play_audio(waveform, sample_rate):
+     waveform = waveform.numpy()
+
+     num_channels, num_frames = waveform.shape
+     if num_channels == 1:
+         display(Audio(waveform[0], rate=sample_rate))
+     elif num_channels == 2:
+         display(Audio((waveform[0], waveform[1]), rate=sample_rate))
+     else:
+         raise ValueError("Waveform with more than 2 channels are not supported.")
+
+ def get_rir_sample(path, resample=None, processed=False):
+     rir_raw, sample_rate = torch.load(path)
+     if not processed:
+         return rir_raw, sample_rate
+     rir = rir_raw[:, int(sample_rate*1.01):int(sample_rate*1.3)]
+     rir = rir / torch.norm(rir, p=2)
+     rir = torch.flip(rir, [1])
+     return rir, sample_rate
+
+
environment.yml CHANGED
@@ -15,6 +15,9 @@ dependencies:
  - requests
  - bidict
  - tqdm
+ - pytorch-lightning
+ - rich
  - pip
  - gradio
+ - wakepy
  prefix: /opt/homebrew/Caskroom/miniforge/base/envs/dancer-net
models/config/dance-predictor.yaml CHANGED
@@ -1,20 +1,15 @@
- weight_path: lightning_logs/version_0/checkpoints/epoch=5-step=870.ckpt
+ weight_path: models/weights/ResidualDancer/weights.ckpt
  expected_duration: 6
- threshold: 0.5
+ threshold: 0.4
  resample_frequency: 16000
  device: cpu
  labels:
  - Argentine Tango
- - Balboa
  - Bachata
- - Blues
  - Cha Cha
- - Cumbia
- - Carolina Shag
  - East Coast Swing
  - Hustle
  - Jive
- - Lindy Hop
  - Quickstep
  - Rumba
  - Slow Foxtrot
@@ -23,4 +18,3 @@ labels:
  - Slow Waltz
  - Tango (Ballroom)
  - Viennese Waltz
- - West Coast Swing
models/config/train.yaml CHANGED
@@ -1,23 +1,46 @@
- device: mps
- seed: 42
- dance_ids:
- - ATN
- - BBA
- - BCH
- - BLU
- - CHA
- - CMB
- - CSG
- - ECS
- - HST
- - JIV
- - LHP
- - QST
- - RMB
- - SFT
- - SLS
- - SMB
- - SWZ
- - TGO
- - VWZ
- - WCS
+ global:
+   device: mps
+   seed: 42
+   dance_ids:
+     - ATN
+     - BCH
+     - CHA
+     - ECS
+     - HST
+     - JIV
+     - QST
+     - RMB
+     - SFT
+     - SLS
+     - SMB
+     - SWZ
+     - TGO
+     - VWZ
+     - WCS
+ data_module:
+   batch_size: 1024
+   num_workers: 10
+   min_votes: 2
+   song_data_path: data/songs_cleaned.csv
+   song_audio_path: data/samples
+   dataset_kwargs:
+     audio_window_duration: 6
+     audio_window_jitter: 1.5
+     audio_pipeline_kwargs:
+       mask_count: 0 # Don't mask the data
+       snr_mean: 15.0 # Pretty much eliminate the noise
+       freq_mask_size: 10
+       time_mask_size: 80
+
+ trainer:
+   log_every_n_steps: 15
+   accelerator: gpu
+   max_epochs: 50
+   min_epochs: 5
+   fast_dev_run: False
+   track_grad_norm: 2
+   # gradient_clip_val: 0.5
+ training_environment:
+   learning_rate: 0.0033
+ model:
+   n_channels: 128
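As a quick reference, the nested layout above is what `get_config` in `train.py` consumes; a short sketch of which sections feed which components:

```python
import yaml

# Load the nested training config introduced in this commit.
with open("models/config/train.yaml") as f:
    config = yaml.safe_load(f)

dance_ids = config["global"]["dance_ids"]   # 15 target classes after the pruning above
dm_kwargs = config["data_module"]           # passed to DanceDataModule(target_classes=dance_ids, **dm_kwargs)
trainer_kwargs = config["trainer"]          # passed to pytorch_lightning.Trainer(**trainer_kwargs)
print(len(dance_ids), dm_kwargs["batch_size"], trainer_kwargs["max_epochs"])  # 15 1024 50
```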
models/residual.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
  import torchaudio
  import yaml
  from .utils import calculate_metrics
- from preprocessing.pipelines import AudioPipeline
+ from preprocessing.pipelines import WaveformPreprocessing, AudioToSpectrogram

  # Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py

@@ -15,6 +15,9 @@ class ResidualDancer(nn.Module):
      def __init__(self,n_channels=128, n_classes=50):
          super().__init__()

+         self.n_channels = n_channels
+         self.n_classes = n_classes
+
          # Spectrogram
          self.spec_bn = nn.BatchNorm2d(1)

@@ -33,7 +36,7 @@ class ResidualDancer(nn.Module):
          self.dense1 = nn.Linear(n_channels*4, n_channels*4)
          self.bn = nn.BatchNorm1d(n_channels*4)
          self.dense2 = nn.Linear(n_channels*4, n_classes)
-         self.dropout = nn.Dropout(0.3)
+         self.dropout = nn.Dropout(0.2)

      def forward(self, x):
          x = self.spec_bn(x)
@@ -88,34 +91,51 @@ class ResBlock(nn.Module):

  class TrainingEnvironment(pl.LightningModule):

-     def __init__(self, model: nn.Module, criterion: nn.Module, learning_rate=1e-4, *args, **kwargs):
+     def __init__(self, model: nn.Module, criterion: nn.Module, config:dict, learning_rate=1e-4, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.model = model
          self.criterion = criterion
          self.learning_rate = learning_rate
+         self.config=config
+         self.save_hyperparameters({
+             "model": type(model).__name__,
+             "loss": type(criterion).__name__,
+             "config": config,
+             **kwargs
+         })

      def training_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int) -> torch.Tensor:
          features, labels = batch
          outputs = self.model(features)
          loss = self.criterion(outputs, labels)
-         batch_metrics = calculate_metrics(outputs, labels)
-         self.log_dict(batch_metrics)
+         metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
+         self.log_dict(metrics, prog_bar=True)
+         # Log spectrograms
+         if batch_index % 100 == 0:
+             tensorboard = self.logger.experiment
+             img_index = torch.randint(0, len(features), (1,)).item()
+             img = features[img_index][0]
+             img = (img - img.min()) / (img.max() - img.min())
+             tensorboard.add_image(f"batch: {batch_index}, element: {img_index}", img, 0, dataformats='HW')
          return loss

+
      def validation_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
          x, y = batch
          preds = self.model(x)
-         metrics = calculate_metrics(preds, y, prefix="val_")
-         metrics["val_loss"] = self.criterion(preds, y)
-         self.log_dict(metrics)
+         metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
+         metrics["val/loss"] = self.criterion(preds, y)
+         self.log_dict(metrics,prog_bar=True)

      def test_step(self, batch:tuple[torch.Tensor, torch.TensorType], batch_index:int):
          x, y = batch
          preds = self.model(x)
-         self.log_dict(calculate_metrics(preds, y, prefix="test_"))
+         self.log_dict(calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True)

      def configure_optimizers(self):
-         return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+         # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') {"scheduler": scheduler, "monitor": "val/loss"}
+         return [optimizer]


  class DancePredictor:
@@ -133,7 +153,8 @@ class DancePredictor:
          self.expected_duration = expected_duration
          self.threshold = threshold
          self.resample_frequency = resample_frequency
-         self.audio_pipeline = AudioPipeline(input_freq=self.resample_frequency)
+         self.preprocess_waveform = WaveformPreprocessing(resample_frequency * expected_duration)
+         self.audio_to_spectrogram = AudioToSpectrogram(resample_frequency)
          self.labels = np.array(labels)
          self.device = device
          self.model = self.get_model(weight_path)
@@ -155,20 +176,16 @@ class DancePredictor:

      @torch.no_grad()
      def __call__(self, waveform: np.ndarray, sample_rate:int) -> dict[str,float]:
-         min_sample_len = sample_rate * self.expected_duration
-         if min_sample_len > len(waveform):
-             raise Exception("You must record for at least 6 seconds")
-         if len(waveform.shape) > 1 and waveform.shape[1] > 1:
+         if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
              waveform = waveform.transpose(1,0)
-             waveform = waveform.mean(axis=0, keepdims=True)
-         else:
+         elif len(waveform.shape) == 1:
              waveform = np.expand_dims(waveform, 0)
-         waveform = waveform[: ,:min_sample_len]
          waveform = torch.from_numpy(waveform.astype("int16"))
          waveform = torchaudio.functional.apply_codec(waveform,sample_rate, "wav", channels_first=True)

          waveform = torchaudio.functional.resample(waveform, sample_rate,self.resample_frequency)
-         spectrogram = self.audio_pipeline(waveform)
+         waveform = self.preprocess_waveform(waveform)
+         spectrogram = self.audio_to_spectrogram(waveform)
          spectrogram = spectrogram.unsqueeze(0).to(self.device)

          results = self.model(spectrogram)
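The reshaping branch that replaces the old minimum-length check can be illustrated in isolation (toy shapes; the predictor performs this inside `__call__`):

```python
import numpy as np

# Gradio-style stereo clip, shaped (samples, channels); mono input would be 1-D.
clip = np.zeros((3 * 44100, 2), dtype=np.int16)

if len(clip.shape) > 1 and clip.shape[1] < clip.shape[0]:
    clip = clip.transpose(1, 0)        # channels-last -> channels-first: (2, 132300)
elif len(clip.shape) == 1:
    clip = np.expand_dims(clip, 0)     # mono 1-D -> (1, samples)

print(clip.shape)  # (2, 132300)
```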
models/utils.py CHANGED
@@ -4,6 +4,10 @@ import numpy as np
  from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

  class LabelWeightedBCELoss(nn.Module):
+     """
+     Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
+     Allows for the weighing of each probability distribution wrt loss.
+     """
      def __init__(self, label_weights:torch.Tensor, reduction="mean"):
          super().__init__()
          self.label_weights = label_weights
@@ -22,17 +26,37 @@ class LabelWeightedBCELoss(nn.Module):
          return self.reduction(losses)


- def calculate_metrics(pred, target, threshold=0.5, prefix="") -> dict[str, torch.Tensor]:
+ # TODO: Code a onehot
+
+
+ def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
      target = target.detach().cpu().numpy()
      pred = pred.detach().cpu().numpy()
-     pred = np.array(pred > threshold, dtype=float)
+     params = {
+         "y_true": target if multi_label else target.argmax(1) ,
+         "y_pred": np.array(pred > threshold, dtype=float) if multi_label else pred.argmax(1),
+         "zero_division": 0,
+         "average":"macro"
+     }
      metrics= {
-         'precision': precision_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-         'recall': recall_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-         'f1': f1_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-         'accuracy': accuracy_score(y_true=target, y_pred=pred),
+         'precision': precision_score(**params),
+         'recall': recall_score(**params),
+         'f1': f1_score(**params),
+         'accuracy': accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
      }
-     if prefix != "":
-         metrics = {prefix + k : v for k, v in metrics.items()}
-
-     return {k: torch.tensor(v,dtype=torch.float32) for k,v in metrics.items()}
+     return {prefix + k: torch.tensor(v,dtype=torch.float32) for k,v in metrics.items()}
+
+ class EarlyStopping:
+     def __init__(self, patience=0):
+         self.patience = patience
+         self.last_measure = np.inf
+         self.consecutive_increase = 0
+
+     def step(self, val) -> bool:
+         if self.last_measure <= val:
+             self.consecutive_increase +=1
+         else:
+             self.consecutive_increase = 0
+         self.last_measure = val
+
+         return self.patience < self.consecutive_increase
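A toy invocation of the reworked `calculate_metrics`, consistent with the new signature and the `train/`, `val/`, `test/` prefixes used above:

```python
import torch
from models.utils import calculate_metrics

# Two examples, three dance classes, multi-label targets.
preds = torch.tensor([[0.9, 0.2, 0.7],
                      [0.1, 0.8, 0.4]])
labels = torch.tensor([[1., 0., 1.],
                       [0., 1., 0.]])

metrics = calculate_metrics(preds, labels, threshold=0.5, prefix="val/", multi_label=True)
print(metrics["val/accuracy"], metrics["val/f1"])  # tensor(1.) tensor(1.) for this toy batch
```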
models/weights/ResidualDancer/config.json DELETED
@@ -1,24 +0,0 @@
- {
-     "classes": [
-         "ATN",
-         "BBA",
-         "BCH",
-         "BLU",
-         "CHA",
-         "CMB",
-         "CSG",
-         "ECS",
-         "HST",
-         "JIV",
-         "LHP",
-         "QST",
-         "RMB",
-         "SFT",
-         "SLS",
-         "SMB",
-         "SWZ",
-         "TGO",
-         "VWZ",
-         "WCS"
-     ]
- }
models/weights/ResidualDancer/dancer_net.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1888558eed82a5d99ac1dab55969a9ea36455d11a9370355d1f2b984598d30ff
- size 48453416
assets/song-samples/exs_and_ohs.wav → models/weights/ResidualDancer/weights.ckpt RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:4e53fe157ff687b5464e98c7d0c03d0712527c3a7ed24b6b063a328fcf7bf608
- size 5292082
+ oid sha256:e107090ff62ac0b79f4f40271e8b1dd6c3d10d8146264ec49df3c8febe99aa23
+ size 193651217
preprocessing/dataset.py CHANGED
@@ -3,87 +3,122 @@ from torch.utils.data import Dataset, DataLoader, random_split
  import numpy as np
  import pandas as pd
  import torchaudio as ta
- from .pipelines import AudioPipeline
+ from .pipelines import AudioTrainingPipeline
  import pytorch_lightning as pl
  from .preprocess import get_examples
+ from sklearn.model_selection import train_test_split



  class SongDataset(Dataset):
      def __init__(self,
                   audio_paths: list[str],
                   dance_labels: list[np.ndarray],
                   audio_duration=30, # seconds
                   audio_window_duration=6, # seconds
+                  audio_window_jitter=0.0, # seconds
+                  audio_pipeline_kwargs={},
+                  resample_frequency=16000
      ):
          assert audio_duration % audio_window_duration == 0, "Audio window should divide duration evenly."
+         assert audio_window_duration > audio_window_jitter, "Jitter should be a small fraction of the audio window duration."

          self.audio_paths = audio_paths
          self.dance_labels = dance_labels
          audio_info = ta.info(audio_paths[0])
          self.sample_rate = audio_info.sample_rate
          self.audio_window_duration = int(audio_window_duration)
+         self.audio_window_jitter = audio_window_jitter
          self.audio_duration = int(audio_duration)

-         self.audio_pipeline = AudioPipeline(input_freq=self.sample_rate)
+         self.audio_pipeline = AudioTrainingPipeline(self.sample_rate, resample_frequency, audio_window_duration, **audio_pipeline_kwargs)

      def __len__(self):
          return len(self.audio_paths) * self.audio_duration // self.audio_window_duration

-     def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
+     def __getitem__(self, idx:int) -> tuple[torch.Tensor, torch.Tensor]:
          waveform = self._waveform_from_index(idx)
+         assert waveform.shape[1] > 10, f"No data found: {self._backtrace_audio_path(idx)}"
          spectrogram = self.audio_pipeline(waveform)

          dance_labels = self._label_from_index(idx)

-         return spectrogram, dance_labels
+         example_is_valid = self._validate_output(spectrogram, dance_labels)
+         if example_is_valid:
+             return spectrogram, dance_labels
+         else:
+             # Try the previous one
+             # This happens when some of the audio recordings are really quiet
+             # This WILL NOT leak into other data partitions because songs belong entirely to a partition
+             return self[idx-1]

+     def _convert_idx(self,idx:int) -> int:
+         return idx * self.audio_window_duration // self.audio_duration
+
+     def _backtrace_audio_path(self, index:int) -> str:
+         return self.audio_paths[self._convert_idx(index)]
+
+     def _validate_output(self,x,y):
+         is_finite = not torch.any(torch.isinf(x))
+         is_numerical = not torch.any(torch.isnan(x))
+         has_data = torch.any(x != 0.0)
+         is_binary = len(torch.unique(y)) < 3
+         return all((is_finite,is_numerical, has_data, is_binary))

      def _waveform_from_index(self, idx:int) -> torch.Tensor:
-         audio_file_idx = idx * self.audio_window_duration // self.audio_duration
-         frame_offset = idx % self.audio_duration // self.audio_window_duration
+         audio_filepath = self.audio_paths[self._convert_idx(idx)]
+         num_windows = self.audio_duration // self.audio_window_duration
+         frame_index = idx % num_windows
+         jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
+         jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
+         jitter = int(torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate)
+         frame_offset = frame_index * self.audio_window_duration * self.sample_rate + jitter
          num_frames = self.sample_rate * self.audio_window_duration
-         waveform, sample_rate = ta.load(self.audio_paths[audio_file_idx], frame_offset=frame_offset, num_frames=num_frames)
+         waveform, sample_rate = ta.load(audio_filepath, frame_offset=frame_offset, num_frames=num_frames)
          assert sample_rate == self.sample_rate, f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
          return waveform


      def _label_from_index(self, idx:int) -> torch.Tensor:
-         label_idx = idx * self.audio_window_duration // self.audio_duration
-         return torch.from_numpy(self.dance_labels[label_idx])
-
+         return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])

  class DanceDataModule(pl.LightningDataModule):
      def __init__(self,
-                  song_data_path="data/songs.csv",
+                  song_data_path="data/songs_cleaned.csv",
                   song_audio_path="data/samples",
                   test_proportion=0.15,
                   val_proportion=0.1,
                   target_classes:list[str]=None,
+                  min_votes=1,
                   batch_size:int=64,
-                  num_workers=10
+                  num_workers=10,
+                  dataset_kwargs={}
      ):
          super().__init__()
          self.song_data_path = song_data_path
          self.song_audio_path = song_audio_path
          self.val_proportion=val_proportion
          self.test_proportion=test_proportion
-         self.train_proporition= 1.-test_proportion-val_proportion
+         self.train_proportion= 1.-test_proportion-val_proportion
          self.target_classes=target_classes
          self.batch_size = batch_size
          self.num_workers = num_workers
+         self.dataset_kwargs = dataset_kwargs

-         df = pd.read_csv("data/songs.csv")
-         self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes)
-
+         df = pd.read_csv(song_data_path)
+         self.x,self.y = get_examples(df, self.song_audio_path,class_list=self.target_classes, multi_label=True, min_votes=min_votes)

      def setup(self, stage: str):
-         dataset = SongDataset(self.x,self.y)
-         self.train_ds, self.val_ds, self.test_ds = random_split(dataset, [self.train_proporition, self.val_proportion, self.test_proportion])
+         train_i, val_i, test_i = random_split(np.arange(len(self.x)), [self.train_proportion, self.val_proportion, self.test_proportion])
+         self.train_ds = self._dataset_from_indices(train_i)
+         self.val_ds = self._dataset_from_indices(val_i)
+         self.test_ds = self._dataset_from_indices(test_i)
+
+     def _dataset_from_indices(self, idx:list[int]) -> SongDataset:
+         return SongDataset(self.x[idx], self.y[idx], **self.dataset_kwargs)

-
      def train_dataloader(self):
-         return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers)
+         return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

      def val_dataloader(self):
          return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers)
@@ -92,4 +127,5 @@ class DanceDataModule(pl.LightningDataModule):
          return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers)

      def get_label_weights(self):
-         return torch.from_numpy(len(self.y) / (len(self.y[0]) * sum(self.y)))
+         n_examples, n_classes = self.y.shape
+         return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
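To make the new window/jitter indexing concrete, a small worked example of the arithmetic that `_convert_idx` and `_waveform_from_index` perform (values chosen for illustration):

```python
# 30 s song samples split into 6 s training windows -> 5 windows per song.
audio_duration = 30
audio_window_duration = 6
num_windows = audio_duration // audio_window_duration  # 5

idx = 7  # dataset index
song_index = idx * audio_window_duration // audio_duration  # _convert_idx -> 1 (second song)
frame_index = idx % num_windows                             # 2 (third window of that song)
print(song_index, frame_index)  # 1 2

# With audio_window_jitter=1.5 the load offset is then shifted by up to +/-1.5 s,
# except at a song's first/last window where the shift is clamped toward the clip.
```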
preprocessing/pipelines.py CHANGED
@@ -1,63 +1,109 @@
  import torch
+ import torchaudio
  from torchaudio import transforms as taT, functional as taF
  import torch.nn as nn

- class AudioPipeline(torch.nn.Module):
-     def __init__(
-         self,
-         input_freq=16000,
-         resample_freq=16000,
-     ):
-         super().__init__()
-         self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
-         self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
-         self.to_db = taT.AmplitudeToDB()
-
-     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-         if waveform.shape[0] > 1:
-             waveform = waveform.mean(0, keepdim=True)
-
-         waveform = (waveform - waveform.mean()) / waveform.abs().max()
-
-         waveform = self.resample(waveform)
-         spectrogram = self.spec(waveform)
-         spectrogram = self.to_db(spectrogram)
-
-         return spectrogram
-
-
- class SpectrogramAugmentationPipeline(torch.nn.Module):
-
-     def __init__(self):
-         super().__init__()
-         self.pipeline = nn.Sequential(
-             taT.FrequencyMasking(80),
-             taT.TimeMasking(80),
-             taT.TimeStretch(80)
-         )
-
-     def forward(self, spectrogram:torch.Tensor) -> torch.Tensor:
-         return self.pipeline(spectrogram)
-
-
- class WaveformAugmentationPipeline(torch.nn.Module):
-     def __init__(self):
-         super().__init__()
-
-
-     def forward(self, waveform:torch.Tensor) -> torch.Tensor:
-         taF.pitch_shift()
-
-
- class AudioTrainingPipeline(torch.nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.waveform_aug = WaveformAugmentationPipeline()
-         self.spec_aug = SpectrogramAugmentationPipeline()
-         self.audio_preprocessing = AudioPipeline()
-
-     def forward(self, waveform:torch.Tensor) -> torch.Tensor:
-         x = self.audio_preprocessing(waveform)
-         x = self.spec_aug(x)
-         return x
+ NOISE_PATH = "data/augmentation/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
+
+ class AudioTrainingPipeline(torch.nn.Module):
+     def __init__(self,
+                  input_freq=16000,
+                  resample_freq=16000,
+                  expected_duration=6,
+                  freq_mask_size=10,
+                  time_mask_size=80,
+                  mask_count = 2,
+                  snr_mean=6.0):
+         super().__init__()
+         self.input_freq = input_freq
+         self.snr_mean = snr_mean
+         self.mask_count = mask_count
+         self.noise = self.get_noise()
+         self.resample = taT.Resample(input_freq,resample_freq)
+         self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
+         self.audio_to_spectrogram = AudioToSpectrogram(
+             sample_rate=resample_freq,
+         )
+         self.freq_mask = taT.FrequencyMasking(freq_mask_size)
+         self.time_mask = taT.TimeMasking(time_mask_size)
+
+     def get_noise(self) -> torch.Tensor:
+         noise, sr = torchaudio.load(NOISE_PATH)
+         if noise.shape[0] > 1:
+             noise = noise.mean(0, keepdim=True)
+         if sr != self.input_freq:
+             noise = taF.resample(noise,sr, self.input_freq)
+         return noise
+
+     def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
+         num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
+         noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
+         noise_power = noise.norm(p=2)
+         signal_power = waveform.norm(p=2)
+         snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
+         snr = torch.exp(snr_db / 10)
+         scale = snr * noise_power / signal_power
+         noisy_waveform = (scale * waveform + noise) / 2
+         return noisy_waveform
+
+     def forward(self, waveform:torch.Tensor) -> torch.Tensor:
+         try:
+             waveform = self.resample(waveform)
+         except:
+             print("oops")
+         waveform = self.preprocess_waveform(waveform)
+         waveform = self.add_noise(waveform)
+         spec = self.audio_to_spectrogram(waveform)
+
+         # Spectrogram augmentation
+         for _ in range(self.mask_count):
+             spec = self.freq_mask(spec)
+             spec = self.time_mask(spec)
+         return spec
+
+
+ class WaveformPreprocessing(torch.nn.Module):
+
+     def __init__(self, expected_sample_length:int):
+         super().__init__()
+         self.expected_sample_length = expected_sample_length
+
+     def forward(self, waveform:torch.Tensor) -> torch.Tensor:
+         # Take out extra channels
+         if waveform.shape[0] > 1:
+             waveform = waveform.mean(0, keepdim=True)
+
+         # ensure it is the correct length
+         waveform = self._rectify_duration(waveform)
+         return waveform
+
+     def _rectify_duration(self,waveform:torch.Tensor):
+         expected_samples = self.expected_sample_length
+         sample_count = waveform.shape[1]
+         if expected_samples == sample_count:
+             return waveform
+         elif expected_samples > sample_count:
+             pad_amount = expected_samples - sample_count
+             return torch.nn.functional.pad(waveform, (0, pad_amount),mode="constant", value=0.0)
+         else:
+             return waveform[:,:expected_samples]
+
+
+ class AudioToSpectrogram(torch.nn.Module):
+     def __init__(
+         self,
+         sample_rate=16000,
+     ):
+         super().__init__()
+
+         self.spec = taT.MelSpectrogram(sample_rate=sample_rate, n_mels=128, n_fft=1024) # TODO: Change mels to 64
+         self.to_db = taT.AmplitudeToDB()
+
+     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+         spectrogram = self.spec(waveform)
+         spectrogram = self.to_db(spectrogram)
+         return spectrogram
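A quick shape check of the two deterministic stages added above (a sketch; it sidesteps `AudioTrainingPipeline` so the VOiCES noise file is not required):

```python
import torch
from preprocessing.pipelines import WaveformPreprocessing, AudioToSpectrogram

# Stand-in for a 6-second mono clip at 16 kHz.
waveform = torch.randn(1, 6 * 16000)

preprocess = WaveformPreprocessing(expected_sample_length=6 * 16000)
to_spec = AudioToSpectrogram(sample_rate=16000)

spec = to_spec(preprocess(waveform))
print(spec.shape)  # torch.Size([1, 128, 188]) with n_mels=128, n_fft=1024 (hop 512)
```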
preprocessing/preprocess.py CHANGED
@@ -4,8 +4,9 @@ import re
  import json
  from pathlib import Path
  import os
+ import torchaudio
  import torch
- import torchaudio.transforms as taT
+ from tqdm import tqdm

  def url_to_filename(url:str) -> str:
      return f"{url.split('/')[-1]}.wav"
@@ -17,6 +18,35 @@ def get_songs_with_audio(df:pd.DataFrame, audio_dir:str) -> pd.DataFrame:
      df = df[valid_audio]
      return df

+ def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
+     """
+     Tests audio urls to ensure that their file exists and the contents is valid.
+     """
+     audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
+     def is_valid(url):
+         valid_url = type(url) == str and "http" in url
+         if not valid_url:
+             return False
+         filename = url_to_filename(url)
+         if filename not in audio_files:
+             return False
+         try:
+             w, _ = torchaudio.load(os.path.join(audio_dir, filename))
+         except:
+             return False
+         contents_invalid = torch.any(torch.isnan(w)) or torch.any(torch.isinf(w)) or len(torch.unique(w)) <= 2
+         return not contents_invalid
+
+     idxs = []
+     validations = []
+     for index, url in tqdm(audio_urls.items(), total=len(audio_urls), desc="Audio URLs Validated"):
+         idxs.append(index)
+         validations.append(is_valid(url))
+
+     return pd.Series(validations, index=idxs)
+
+
+
  def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
      tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
      dance_ratings = dance_ratings.apply(lambda v : json.loads(v.replace("'", "\"")))
@@ -64,7 +94,7 @@ def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np
      probs[probs > 0.0] = 1.0
      return probs

- def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[str], list[np.ndarray]]:
+ def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None, multi_label=True, min_votes=1) -> tuple[np.ndarray, np.ndarray]:
      sampled_songs = get_songs_with_audio(df, audio_dir)
      sampled_songs.loc[:,"DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
      if class_list is not None:
@@ -74,11 +104,28 @@ def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[
          if not pd.isna(labels) and any(label in class_list and amt > 0 for label, amt in labels.items())
          else np.nan)
      sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
-     labels = sampled_songs["DanceRating"]
+     vote_mask = sampled_songs["DanceRating"].apply(lambda dances: any(votes >= min_votes for votes in dances.values()))
+     sampled_songs = sampled_songs[vote_mask]
+     labels = sampled_songs["DanceRating"].apply(lambda dances : {dance: votes for dance, votes in dances.items() if votes >= min_votes})
      unique_labels = np.array(get_unique_labels(labels))
-     labels = labels.apply(lambda i : vectorize_multi_label(i, unique_labels))
+     vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
+     labels = labels.apply(lambda i : vectorizer(i, unique_labels))

      audio_paths = [os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]]

-     return audio_paths, list(labels)
+     return np.array(audio_paths), np.stack(labels)
+
+
+ if __name__ == "__main__":
+     links = pd.read_csv("data/backup_2.csv", index_col="index")
+     df = pd.read_csv("data/songs.csv")
+     l = links["link"].str.strip()
+     l = l.apply(lambda url : url if "http" in url else np.nan)
+     l = l.dropna()
+     df["Sample"].update(l)
+     addna = lambda url : url if type(url) == str and "http" in url else np.nan
+     df["Sample"] = df["Sample"].apply(addna)
+     is_valid = validate_audio(df["Sample"],"data/samples")
+     df["valid"] = is_valid
+     df.to_csv("data/songs_validated.csv")
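The effect of the new `min_votes` filter in `get_examples` on a single song's ratings, illustrated with invented numbers:

```python
# A song's parsed DanceRating after fix_dance_rating_counts (values invented).
ratings = {"SWZ": 3, "TGO": 1}
min_votes = 2

keep_song = any(votes >= min_votes for votes in ratings.values())   # True, SWZ has 3 votes
kept_labels = {d: v for d, v in ratings.items() if v >= min_votes}  # {"SWZ": 3}
print(keep_song, kept_labels)
```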
requirements.txt CHANGED
@@ -1,6 +1,7 @@
- torchvision
  torch
+ torchvision
  torchaudio
+ pytorch-lightning
  numpy
  pandas
  seaborn
scrapers/spotify.py CHANGED
@@ -49,14 +49,14 @@ def patch_missing_songs(
          if preview_url is not None:
              row["Sample"] = preview_url
          return row
-     backup_file = open("data/backup_1.csv", "a")
      rows = []
      indices = []
+     after = 18418
+     missing_df = missing_df.iloc[after:]
      total_rows = len(missing_df)
-     for i, row in tqdm(missing_df.iloc[11121:].iterrows(),total=total_rows):
+     for i, row in tqdm(missing_df.iterrows(),total=total_rows):
          patched_row = patch_preview(row)
-         backup_file.write(f"{i}, {patched_row['Sample']}\n")
-         rows.append(patch_preview(row))
+         rows.append(patched_row)
          indices.append(i)


@@ -65,23 +65,10 @@ def patch_missing_songs(
      return df


- def download_links():
-     start = 3180
-     with open("data/backup_2.csv") as f:
+ def download_links_from_backup(backup_file:str, output_dir:str):
+     with open(backup_file) as f:
          links = [x.split(",")[1].strip() for x in f.readlines()]
-     links = links[start:]
      links = [l for l in links if "https" in l]
-     links = links[2680:]
      for link in tqdm(links, "Songs Downloaded"):
-         download_song(link, "data/spotify-samples")
+         download_song(link, output_dir)
          time.sleep(5e-3) # hopefully wont be rate limited with delay 🤞
-
-
-
-
- if __name__ == "__main__":
-     df = pd.read_csv("data/songs.csv")
-     patched = patch_missing_songs(df)
-     patched.to_csv("data/last_part.csv")
-
-
tests.py ADDED
@@ -0,0 +1,22 @@
+ import torchaudio
+ import numpy as np
+ from audio_utils import play_audio
+ from preprocessing.dataset import SongDataset
+
+ def test_audio_splitting():
+
+
+
+     audio_paths = ["data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav"]
+     labels = [np.array([1,0,1,0])]
+     whole_song, sr = torchaudio.load("data/samples/95f2df65f7450db3b1af29aa77ba7edc6ab52075?cid=7ffadeb2e136495fb5a62d1ac9be8f62.wav")
+
+     ds = SongDataset(audio_paths, labels)
+     song_parts = (ds._waveform_from_index(i) for i in range(len(ds)))
+     print("Sample Parts")
+     for part in song_parts:
+         play_audio(part,sr)
+
+
+     print("Whole Sample")
+     play_audio(whole_song,sr)
train.py CHANGED
@@ -1,196 +1,69 @@
- import datetime
- import os
- import torch
  from torch.utils.data import DataLoader
- import torch.nn as nn
- from tqdm import tqdm
  import pandas as pd
- import numpy as np
- from torch.utils.data import random_split, SubsetRandomSampler
- import json
+ from torch import nn
+ from torch.utils.data import SubsetRandomSampler
  from sklearn.model_selection import KFold
-
- from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+ import pytorch_lightning as pl
+ from pytorch_lightning import callbacks as cb
+ from models.utils import LabelWeightedBCELoss
  from preprocessing.dataset import SongDataset
  from preprocessing.preprocess import get_examples
- from models.residual import ResidualDancer
-
- DEVICE = "mps"
- SEED = 42
- TARGET_CLASSES = ['ATN',
-     'BBA',
-     'BCH',
-     'BLU',
-     'CHA',
-     'CMB',
-     'CSG',
-     'ECS',
-     'HST',
-     'JIV',
-     'LHP',
-     'QST',
-     'RMB',
-     'SFT',
-     'SLS',
-     'SMB',
-     'SWZ',
-     'TGO',
-     'VWZ',
-     'WCS']
-
- def get_timestamp() -> str:
-     return datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
-
- class EarlyStopping:
-     def __init__(self, patience=0):
-         self.patience = patience
-         self.last_measure = np.inf
-         self.consecutive_increase = 0
-
-     def step(self, val) -> bool:
-         if self.last_measure <= val:
-             self.consecutive_increase +=1
-         else:
-             self.consecutive_increase = 0
-         self.last_measure = val
-
-         return self.patience < self.consecutive_increase
-
-
-
- def calculate_metrics(pred, target, threshold=0.5, prefix=""):
-     target = target.detach().cpu().numpy()
-     pred = pred.detach().cpu().numpy()
-     pred = np.array(pred > threshold, dtype=float)
-     metrics= {
-         'precision': precision_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-         'recall': recall_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-         'f1': f1_score(y_true=target, y_pred=pred, average='macro', zero_division=0),
-         'accuracy': accuracy_score(y_true=target, y_pred=pred),
-     }
-     if prefix != "":
-         metrics = {prefix + k : v for k, v in metrics.items()}
-
-     return metrics
+ from models.residual import ResidualDancer, TrainingEnvironment
+ import yaml
+ from preprocessing.dataset import DanceDataModule
+ from wakepy import keepawake

+ def get_config(filepath:str) -> dict:
+     with open(filepath, "r") as f:
+         config = yaml.safe_load(f)
+     return config

- def evaluate(model:nn.Module, data_loader:DataLoader, criterion, device="mps") -> pd.Series:
-     val_metrics = []
-     for features, labels in (prog_bar := tqdm(data_loader)):
-         features = features.to(device)
-         labels = labels.to(device)
-         with torch.no_grad():
-             outputs = model(features)
-             loss = criterion(outputs, labels)
-         batch_metrics = calculate_metrics(outputs, labels, prefix="val_")
-         batch_metrics["val_loss"] = loss.item()
-         prog_bar.set_description(f'Validation - Loss: {batch_metrics["val_loss"]:.2f}, Accuracy: {batch_metrics["val_accuracy"]:.2f}')
-         val_metrics.append(batch_metrics)
-     return pd.DataFrame(val_metrics).mean()
-
-
-
- def train(
-     model: nn.Module,
-     data_loader: DataLoader,
-     val_loader=None,
-     epochs=3,
-     lr=1e-3,
-     device="mps"):
-     criterion = nn.BCELoss()
-     optimizer = torch.optim.Adam(model.parameters(),lr=lr)
-     early_stop = EarlyStopping(1)
-     scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr,
-                                                     steps_per_epoch=int(len(data_loader)),
-                                                     epochs=epochs,
-                                                     anneal_strategy='linear')
-     metrics = []
-     for epoch in range(1,epochs+1):
-         train_metrics = []
-         prog_bar = tqdm(data_loader)
-         for features, labels in prog_bar:
-             features = features.to(device)
-             labels = labels.to(device)
-             optimizer.zero_grad()
-             outputs = model(features)
-             loss = criterion(outputs, labels)
-             loss.backward()
-             optimizer.step()
-             scheduler.step()
-             batch_metrics = calculate_metrics(outputs, labels)
-             batch_metrics["loss"] = loss.item()
-             train_metrics.append(batch_metrics)
-             prog_bar.set_description(f'Training - Epoch: {epoch}/{epochs}, Loss: {batch_metrics["loss"]:.2f}, Accuracy: {batch_metrics["accuracy"]:.2f}')
-         train_metrics = pd.DataFrame(train_metrics).mean()
-         if val_loader is not None:
-             val_metrics = evaluate(model, val_loader, criterion)
-             if early_stop.step(val_metrics["val_f1"]):
-                 break
-             epoch_metrics = pd.concat([train_metrics, val_metrics], axis=0)
-         else:
-             epoch_metrics = train_metrics
-         metrics.append(dict(epoch_metrics))
-
-     return model, metrics
-
-
- def cross_validation(seed=42, batch_size=64, k=5, device="mps"):
+ def cross_validation(config, k=5):
      df = pd.read_csv("data/songs.csv")
-     x,y = get_examples(df, "data/samples",class_list=TARGET_CLASSES)
-
+     g_config = config["global"]
+     batch_size = config["data_module"]["batch_size"]
+     x,y = get_examples(df, "data/samples",class_list=g_config["dance_ids"])
      dataset = SongDataset(x,y)
-     splits=KFold(n_splits=k,shuffle=True,random_state=seed)
-     metrics = []
+     splits=KFold(n_splits=k,shuffle=True,random_state=g_config["seed"])
+     trainer = pl.Trainer(accelerator=g_config["device"])
      for fold, (train_idx,val_idx) in enumerate(splits.split(x,y)):
          print(f"Fold {fold+1}")
-
+         model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
+         train_env = TrainingEnvironment(model,nn.BCELoss())
          train_sampler = SubsetRandomSampler(train_idx)
          test_sampler = SubsetRandomSampler(val_idx)
          train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
          test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
-         n_classes = len(y[0])
-         model = ResidualDancer(n_classes=n_classes).to(device)
-         model, _ = train(model,train_loader, epochs=2, device=device)
-         val_metrics = evaluate(model, test_loader, nn.BCELoss())
-         metrics.append(val_metrics)
-     metrics = pd.DataFrame(metrics)
-     log_dir = os.path.join(
-         "logs", get_timestamp()
-     )
-     os.makedirs(log_dir, exist_ok=True)
-
-     metrics.to_csv(model.state_dict(), os.path.join(log_dir, "cross_val.csv"))
-
-
+         trainer.fit(train_env, train_loader)
+         trainer.test(train_env, test_loader)

- def train_model():
-
-     df = pd.read_csv("data/songs.csv")
-     x,y = get_examples(df, "data/samples",class_list=TARGET_CLASSES)
-     dataset = SongDataset(x,y)
-     train_count = int(len(dataset) * 0.9)
-     datasets = random_split(dataset, [train_count, len(dataset) - train_count], torch.Generator().manual_seed(SEED))
-     data_loaders = [DataLoader(data, batch_size=64, shuffle=True) for data in datasets]
-     train_data, val_data = data_loaders
-     example_spec, example_label = dataset[0]
-     n_classes = len(example_label)
-     model = ResidualDancer(n_classes=n_classes).to(DEVICE)
-     model, metrics = train(model,train_data, val_data, epochs=3, device=DEVICE)
-
-     log_dir = os.path.join(
-         "logs", get_timestamp()
-     )
-     os.makedirs(log_dir, exist_ok=True)
-
-     torch.save(model.state_dict(), os.path.join(log_dir, "residual_dancer.pt"))
-     metrics = pd.DataFrame(metrics)
-     metrics.to_csv(os.path.join(log_dir, "metrics.csv"))
-     config = {
-         "classes": TARGET_CLASSES
-     }
-     with open(os.path.join(log_dir, "config.json")) as f:
-         json.dump(config, f)
-     print("Training information saved!")
+ def train_model(config:dict):
+     TARGET_CLASSES = config["global"]["dance_ids"]
+     DEVICE = config["global"]["device"]
+     SEED = config["global"]["seed"]
+     pl.seed_everything(SEED, workers=True)
+     data = DanceDataModule(target_classes=TARGET_CLASSES, **config['data_module'])
+     model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config['model'])
+     label_weights = data.get_label_weights().to(DEVICE)
+     criterion = LabelWeightedBCELoss(label_weights) #nn.CrossEntropyLoss(label_weights)
+     train_env = TrainingEnvironment(model, criterion, config)
+     callbacks = [
+         # cb.LearningRateFinder(update_attr=True),
+         cb.EarlyStopping("val/loss", patience=5),
+         cb.StochasticWeightAveraging(1e-2),
+         cb.RichProgressBar()
+     ]
+     trainer = pl.Trainer(
+         callbacks=callbacks,
+         **config["trainer"]
+     )
+     trainer.fit(train_env, datamodule=data)
+     trainer.test(train_env, datamodule=data)


  if __name__ == "__main__":
-     cross_validation()
+     config = get_config("models/config/train.yaml")
+     with keepawake():
+         train_model(config)