waidhoferj committed
Commit: 51f4763
Parent: 17a2a7d

updated production build to use multiple overlapping samples

app.py CHANGED
@@ -7,7 +7,7 @@ from functools import cache
from pathlib import Path
from models.residual import ResidualDancer
from models.training_environment import TrainingEnvironment
-from preprocessing.pipelines import SpectrogramProductionPipeline
+from preprocessing.pipelines import SpectrogramProductionPipeline, WaveformPreprocessing
import torch
from torch import nn
import yaml
@@ -17,6 +17,8 @@ CONFIG_FILE = Path("models/weights/ResidualDancer/multilabel/config.yaml")

DANCE_MAPPING_FILE = Path("data/dance_mapping.csv")

+MIN_DURATION = 3.0
+

class DancePredictor:
    def __init__(
@@ -37,6 +39,9 @@ class DancePredictor:
        self.labels = np.array(labels)
        self.device = device
        self.model = self.get_model(weight_path)
+        self.process_waveform = WaveformPreprocessing(
+            resample_frequency * expected_duration
+        )
        self.extractor = SpectrogramProductionPipeline()

    def get_model(self, weight_path: str) -> nn.Module:
@@ -87,10 +92,21 @@ class DancePredictor:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, self.resample_frequency
        )
-        features = self.extractor(waveform)
-        features = features.unsqueeze(0).to(self.device)
+        window_size = self.resample_frequency * self.expected_duration
+        n_preds = int(waveform.shape[1] // (window_size / 2))
+        step_size = int(waveform.shape[1] / n_preds)
+
+        inputs = [
+            waveform[:, i * step_size : i * step_size + window_size]
+            for i in range(n_preds)
+        ]
+        features = [self.extractor(window) for window in inputs]
+        features = torch.stack(features).to(self.device)
        results = self.model(features)
-        results = nn.functional.softmax(results.squeeze(0), dim=0)
+        # Convert to probabilities
+        results = nn.functional.softmax(results, dim=1)
+        # Take average prediction over all of the windows
+        results = results.mean(dim=0)
        results = results.detach().cpu().numpy()

        result_mask = results > self.threshold
@@ -116,6 +132,9 @@ def predict(audio: tuple[int, np.ndarray]) -> list[str]:
    if audio is None:
        return "Dance Not Found"
    sample_rate, waveform = audio
+    duration = len(waveform) / sample_rate
+    if duration < MIN_DURATION:
+        return f"Please record at least {MIN_DURATION} seconds of audio"

    model = get_model(CONFIG_FILE)
    results = model(waveform, sample_rate)
@@ -133,7 +152,6 @@ def demo():

    recording_interface = gr.Interface(
        fn=predict,
-        description="Record at least **6 seconds** of the song.",
        inputs=gr.Audio(source="microphone", label="Song Recording"),
        outputs=gr.Label(label="Dances"),
        examples=example_audio,
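
The inference change above replaces a single fixed-length sample with several half-overlapping windows whose class probabilities are averaged, which is what the commit message refers to. The sketch below isolates that windowing logic; the function name windowed_mean_prediction and the extractor/model placeholders are illustrative assumptions, not part of the repository.

import torch
from torch import nn


def windowed_mean_prediction(
    waveform: torch.Tensor,  # (channels, samples), already resampled
    extractor,               # feature pipeline, e.g. a spectrogram transform (placeholder)
    model: nn.Module,        # trained classifier (placeholder)
    sample_rate: int,
    duration: float,
) -> torch.Tensor:
    window_size = int(sample_rate * duration)
    # One prediction roughly every half window, i.e. ~50% overlap between windows.
    n_preds = int(waveform.shape[1] // (window_size / 2))
    step_size = int(waveform.shape[1] / n_preds)
    windows = [
        waveform[:, i * step_size : i * step_size + window_size]
        for i in range(n_preds)
    ]
    features = torch.stack([extractor(w) for w in windows])  # (n_preds, ...)
    probs = nn.functional.softmax(model(features), dim=1)    # per-window class probabilities
    return probs.mean(dim=0)  # average over windows -> one probability per class

The new MIN_DURATION check in predict() guards against recordings too short to yield even one window.
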
models/config/train_local.yaml CHANGED
@@ -1,12 +1,15 @@
training_fn: residual.train_residual_dancer
-checkpoint: lightning_logs/version_176/checkpoints/epoch=12-step=40404.ckpt
device: mps
seed: 42
dance_ids: &dance_ids
  - BCH
+  - BOL
  - CHA
-  - JIV
  - ECS
+  - HST
+  - LHP
+  - NC2
+  - JIV
  - QST
  - RMB
  - SFT
@@ -20,8 +23,7 @@ dance_ids: &dance_ids
data_module:
  batch_size: 128
  num_workers: 10
-  # data_subset: 0.001
-  test_proportion: 0.001
+  test_proportion: 0.15

datasets:
  preprocessing.dataset.BestBallroomDataset:
@@ -31,7 +33,7 @@ datasets:

  preprocessing.dataset.Music4DanceDataset:
    song_data_path: data/songs_cleaned.csv
-    song_audio_path: data/samples # data/samples
+    song_audio_path: data/samples
    class_list: *dance_ids
    multi_label: True
    min_votes: 1
@@ -56,7 +58,4 @@ trainer:
  # overfit_batches: 1

training_environment:
-  learning_rate: 0.000053
-  # loggers:
-  #   models.training_environment.SpectrogramLogger:
-  #     frequency: 100
+  learning_rate: 0.00053

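For context, configs like train_local.yaml are consumed as plain dictionaries (the &dance_ids anchor and *dance_ids alias are expanded at parse time). Below is a minimal sketch of loading the file and dispatching on training_fn; the run_training helper and the "models." module prefix are assumptions for illustration, not necessarily how this repo's entry point works.

import importlib
import yaml


def run_training(config_path: str = "models/config/train_local.yaml"):
    with open(config_path) as f:
        config = yaml.safe_load(f)  # plain dict; YAML anchors/aliases already expanded
    # "residual.train_residual_dancer" -> module name + function name
    module_name, fn_name = config["training_fn"].rsplit(".", 1)
    train_fn = getattr(importlib.import_module(f"models.{module_name}"), fn_name)
    train_fn(config)  # e.g. models.residual.train_residual_dancer(config)
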
models/residual.py CHANGED
@@ -119,14 +119,11 @@ def train_residual_dancer(config: dict):
    data = DanceDataModule(dataset, **config["data_module"])
    model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
    label_weights = data.get_label_weights().to(DEVICE)
-    criterion = LabelWeightedBCELoss(
-        label_weights
-    )  # nn.CrossEntropyLoss(label_weights)
+    criterion = LabelWeightedBCELoss(label_weights)

    train_env = TrainingEnvironment(model, criterion, config)
    callbacks = [
-        # cb.LearningRateFinder(update_attr=True),
-        cb.EarlyStopping("val/loss", patience=1),
+        cb.EarlyStopping("val/loss", patience=2),
        cb.StochasticWeightAveraging(1e-2),
        cb.RichProgressBar(),
    ]
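
LabelWeightedBCELoss is defined elsewhere in the repo and its exact weighting scheme is not shown in this diff. Purely as an illustration of the general idea of pairing per-class weights from the data module with a multi-label criterion (and assuming the model emits one logit per class), a sketch might look like:

import torch
from torch import nn


class WeightedBCESketch(nn.Module):
    """Illustrative only; not the repository's LabelWeightedBCELoss."""

    def __init__(self, label_weights: torch.Tensor):
        super().__init__()
        self.register_buffer("label_weights", label_weights)  # e.g. inverse class frequencies

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Per-element BCE, then weight each class column before reducing.
        loss = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        return (loss * self.label_weights).mean()
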
preprocessing/dataset.py CHANGED
@@ -424,11 +424,7 @@ def record_audio_durations(folder: str):
    music_files = iglob(os.path.join(folder, "**", "*.wav"), recursive=True)
    for file in music_files:
        meta = ta.info(file)
-        durations[file] = meta.num_frames / meta.sample_rate
+        durations[os.path.relpath(file, folder)] = meta.num_frames / meta.sample_rate

    with open(os.path.join(folder, "audio_durations.json"), "w") as f:
        json.dump(durations, f)
-
-
-class GTZAN:
-    pass

preprocessing/pipelines.py CHANGED
@@ -95,23 +95,27 @@ class WaveformPreprocessing(torch.nn.Module):
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        c_dim = 1 if len(waveform.shape) == 3 else 0
        # Take out extra channels
-        if waveform.shape[0] > 1:
-            waveform = waveform.mean(0, keepdim=True)
+        if waveform.shape[c_dim] > 1:
+            waveform = waveform.mean(c_dim, keepdim=True)

        # ensure it is the correct length
-        waveform = self._rectify_duration(waveform)
+        waveform = self._rectify_duration(waveform, c_dim)
        return waveform

-    def _rectify_duration(self, waveform: torch.Tensor):
+    def _rectify_duration(self, waveform: torch.Tensor, channel_dim: int):
        expected_samples = self.expected_sample_length
-        sample_count = waveform.shape[1]
+        sample_count = waveform.shape[channel_dim + 1]
        if expected_samples == sample_count:
            return waveform
        elif expected_samples > sample_count:
            pad_amount = expected_samples - sample_count
            return torch.nn.functional.pad(
-                waveform, (0, pad_amount), mode="constant", value=0.0
+                waveform,
+                (channel_dim + 1) * [0] + [pad_amount],
+                mode="constant",
+                value=0.0,
            )
        else:
            return waveform[:, :expected_samples]
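
A note on the dynamically built pad argument above: torch.nn.functional.pad reads its pad sequence in (left, right) pairs starting from the last dimension, so for an unbatched (channels, samples) tensor the list [0, pad_amount] appends pad_amount zero samples to the end of the time axis. A small self-contained example (values chosen arbitrarily):

import torch
import torch.nn.functional as F

mono = torch.randn(1, 44_100)                    # (channels, samples): one second at 44.1 kHz
padded = F.pad(mono, (0, 100), mode="constant", value=0.0)
print(padded.shape)                              # torch.Size([1, 44200])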