Commit ba35f85 by waidhoferj
Parent: 9f53273
updated production weights

Files changed:
- app.py (+27 -15)
- assets/song-samples/besame_mucho.wav (+3 -0)
- models/config/production.yaml (+20 -0)
- models/config/train_local.yaml (+6 -6)
- models/residual.py (+2 -2)
- models/weights/ResidualDancer/weights.ckpt (+2 -2)
- preprocessing/dataset.py (+6 -2)
- preprocessing/pipelines.py (+15 -0)
app.py
CHANGED
@@ -2,18 +2,21 @@ from pathlib import Path
 import gradio as gr
 import numpy as np
 import os
+import pandas as pd
 from functools import cache
 from pathlib import Path
-from models.
+from models.residual import ResidualDancer
 from models.training_environment import TrainingEnvironment
+from preprocessing.pipelines import SpectrogramProductionPipeline
 import torch
 from torch import nn
 import yaml
 import torchaudio
 
-CONFIG_FILE = Path("models/config/
-MODEL_CLS =
-
+CONFIG_FILE = Path("models/config/production.yaml")
+MODEL_CLS = ResidualDancer
+
+DANCE_MAPPING_FILE = Path("data/dance_mapping.csv")
 
 
 class DancePredictor:
@@ -22,7 +25,7 @@ class DancePredictor:
         weight_path: str,
         labels: list[str],
         expected_duration=6,
-        threshold=0.
+        threshold=0.1,
         resample_frequency=16000,
         device="cpu",
     ):
@@ -35,11 +38,13 @@ class DancePredictor:
         self.labels = np.array(labels)
         self.device = device
         self.model = self.get_model(weight_path)
-        self.extractor =
+        self.extractor = SpectrogramProductionPipeline()
 
     def get_model(self, weight_path: str) -> nn.Module:
         weights = torch.load(weight_path, map_location=self.device)["state_dict"]
-
+        n_classes = len(self.labels)
+        # NOTE: Channels are not taken into account
+        model = ResidualDancer(n_classes=n_classes).to(self.device)
         for key in list(weights):
             weights[
                 key.replace(
@@ -56,10 +61,12 @@ class DancePredictor:
            config = yaml.safe_load(f)
         weight_path = config["checkpoint"]
         labels = sorted(config["dance_ids"])
-
-
-
-
+        dance_mapping = get_dance_mapping(DANCE_MAPPING_FILE)
+        labels = [dance_mapping[label] for label in labels]
+        expected_duration = config.get("expected_duration", 6)
+        threshold = config.get("threshold", 0.1)
+        resample_frequency = config.get("resample_frequency", 16000)
+        device = config.get("device", "cpu")
         return DancePredictor(
             weight_path,
             labels,
@@ -81,9 +88,6 @@ class DancePredictor:
         waveform = torchaudio.functional.resample(
             waveform, sample_rate, self.resample_frequency
         )
-        waveform = waveform[
-            :, : self.resample_frequency * self.expected_duration
-        ]  # TODO PAD
         features = self.extractor(waveform)
         features = features.unsqueeze(0).to(self.device)
         results = self.model(features)
@@ -103,7 +107,15 @@ def get_model(config_path: str) -> DancePredictor:
     return model
 
 
+@cache
+def get_dance_mapping(mapping_file: str) -> dict[str, str]:
+    mapping_df = pd.read_csv(mapping_file)
+    return {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
+
+
 def predict(audio: tuple[int, np.ndarray]) -> list[str]:
+    if audio is None:
+        return "Dance Not Found"
     sample_rate, waveform = audio
 
     model = get_model(CONFIG_FILE)
@@ -116,7 +128,7 @@ def demo():
     description = "What should I dance to this song? Pass some audio to the Dance Classifier find out!"
     song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
    example_audio = [
-        str(song) for song in song_samples.iterdir() if song.name
+        str(song) for song in song_samples.iterdir() if not song.name.startswith(".")
    ]
     all_dances = get_model(CONFIG_FILE).labels
 
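For orientation only (not part of the commit): a minimal sketch of exercising the updated predict entry point outside of Gradio, using the sample track added in this commit. The torchaudio loading step and the channel layout handed to predict are illustrative assumptions; in the hosted app, Gradio supplies the (sample_rate, ndarray) tuple itself.

import torchaudio

from app import predict

# Load the new sample track and mimic Gradio's (sample_rate, ndarray) audio format.
waveform, sample_rate = torchaudio.load("assets/song-samples/besame_mucho.wav")
result = predict((sample_rate, waveform.numpy()))
print(result)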
assets/song-samples/besame_mucho.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14ccffab50d9119ec5250fc84e09542dbbf350450102c108ab61846a3c3031c8
+size 5290062
models/config/production.yaml
ADDED
@@ -0,0 +1,20 @@
+checkpoint: models/weights/ResidualDancer/weights.ckpt
+device: cpu
+seed: 42
+dance_ids: &dance_ids
+  - BCH
+  - CHA
+  - JIV
+  - ECS
+  - QST
+  - RMB
+  - SFT
+  - SLS
+  - SMB
+  - SWZ
+  - TGO
+  - VWZ
+  - WCS
+
+model:
+  n_channels: 128
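For reference (not part of the diff), this config is consumed by the from-config path in app.py shown above, roughly as sketched below; optional keys fall back to the defaults used there when absent from the file.

import yaml

with open("models/config/production.yaml") as f:
    config = yaml.safe_load(f)

labels = sorted(config["dance_ids"])                  # BCH, CHA, ..., WCS
threshold = config.get("threshold", 0.1)              # absent here, so the default applies
resample_frequency = config.get("resample_frequency", 16000)
device = config.get("device", "cpu")                  # "cpu" in this file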
models/config/train_local.yaml
CHANGED
@@ -1,5 +1,5 @@
-training_fn:
-checkpoint: lightning_logs/
+training_fn: residual.train_residual_dancer
+checkpoint: lightning_logs/version_176/checkpoints/epoch=12-step=40404.ckpt
 device: mps
 seed: 42
 dance_ids: &dance_ids
@@ -24,10 +24,10 @@ data_module:
   test_proportion: 0.2
 
 datasets:
-
-
-
-
+  preprocessing.dataset.BestBallroomDataset:
+    audio_dir: data/ballroom-songs
+    class_list: *dance_ids
+    audio_window_jitter: 0.7
 
   preprocessing.dataset.Music4DanceDataset:
     song_data_path: data/songs_cleaned.csv
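Illustrative sketch only: each key under datasets reads as a dotted class path whose nested mapping is passed as constructor kwargs, so the new BestBallroomDataset entry corresponds roughly to the call below. The class_list values are shown with the IDs from production.yaml for concreteness; the actual values come from this file's *dance_ids anchor.

from preprocessing.dataset import BestBallroomDataset

dataset = BestBallroomDataset(
    audio_dir="data/ballroom-songs",
    class_list=["BCH", "CHA", "JIV", "ECS", "QST", "RMB", "SFT",
                "SLS", "SMB", "SWZ", "TGO", "VWZ", "WCS"],
    audio_window_jitter=0.7,  # jitter applied when sampling audio windows
)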
models/residual.py
CHANGED
@@ -110,7 +110,7 @@ def train_residual_dancer(config: dict):
     TARGET_CLASSES = config["dance_ids"]
     DEVICE = config["device"]
     SEED = config["seed"]
-    torch.set_float32_matmul_precision(
+    torch.set_float32_matmul_precision("medium")
     pl.seed_everything(SEED, workers=True)
     feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
     dataset = get_datasets(config["datasets"], feature_extractor)
@@ -123,7 +123,7 @@ def train_residual_dancer(config: dict):
     train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
         # cb.LearningRateFinder(update_attr=True),
-        cb.EarlyStopping("val/loss", patience=
+        cb.EarlyStopping("val/loss", patience=1),
         cb.StochasticWeightAveraging(1e-2),
         cb.RichProgressBar(),
         cb.DeviceStatsMonitor(),
models/weights/ResidualDancer/weights.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:90a58841ce4f40f2981227b63dd848e474e8868795a57da84053e3281c4889c7
+size 193643085
preprocessing/dataset.py
CHANGED
@@ -78,8 +78,8 @@ class SongDataset(Dataset):
             return waveform, dance_labels
         else:
             # WARNING: Could cause train/test split leak
-
-
+            print("Invalid output, trying next index...")
+            return self[idx - 1]
 
     def _idx2audio_idx(self, idx: int) -> int:
         return self._get_audio_loc_from_idx(idx)[0]
@@ -424,3 +424,7 @@ def record_audio_durations(folder: str):
 
     with open(os.path.join(folder, "audio_durations.json"), "w") as f:
         json.dump(durations, f)
+
+
+class GTZAN:
+    pass
preprocessing/pipelines.py
CHANGED
@@ -74,6 +74,21 @@ class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
         return spec
 
 
+class SpectrogramProductionPipeline(torch.nn.Module):
+    def __init__(self, sample_rate=16000, expected_duration=6, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.preprocess_waveform = WaveformPreprocessing(
+            sample_rate * expected_duration
+        )
+        self.audio_to_spectrogram = AudioToSpectrogram(
+            sample_rate=sample_rate,
+        )
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        waveform = self.preprocess_waveform(waveform)
+        return self.audio_to_spectrogram(waveform)
+
+
 class WaveformPreprocessing(torch.nn.Module):
     def __init__(self, expected_sample_length: int):
         super().__init__()
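A rough usage sketch of the new production pipeline (not part of the commit), assuming a mono 16 kHz input tensor; WaveformPreprocessing presumably normalizes the clip to sample_rate * expected_duration samples, and the spectrogram shape depends on AudioToSpectrogram's settings.

import torch

from preprocessing.pipelines import SpectrogramProductionPipeline

pipeline = SpectrogramProductionPipeline(sample_rate=16000, expected_duration=6)
waveform = torch.randn(1, 16000 * 10)  # 10 s dummy clip, longer than the expected 6 s
spectrogram = pipeline(waveform)       # length-normalized waveform -> spectrogram
print(spectrogram.shape)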