wsntxxn committed on
Commit
b32bfa6
1 Parent(s): 65654c7

Upload Cnn8RnnSoundEventDetection

Files changed (2)
  1. config.json +4 -0
  2. hf_model.py +227 -0
config.json CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "Cnn8RnnSoundEventDetection"
   ],
+  "auto_map": {
+    "AutoConfig": "hf_model.Cnn8RnnConfig",
+    "AutoModel": "hf_model.Cnn8RnnSoundEventDetection"
+  },
   "classes_num": 447,
   "torch_dtype": "float32",
   "transformers_version": "4.30.2"
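The added auto_map entries let the custom config and model classes defined in hf_model.py be resolved through the Auto classes. A minimal loading sketch, assuming trust_remote_code is enabled; the repository id below is a placeholder for the actual repo:

    from transformers import AutoConfig, AutoModel

    repo_id = "wsntxxn/<this-repo>"  # placeholder; substitute the real repository id
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # -> Cnn8RnnConfig
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # -> Cnn8RnnSoundEventDetection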
hf_model.py ADDED
@@ -0,0 +1,227 @@
import os

import torch
from torchaudio import transforms
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils.hub import cached_file


def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


def interpolate(x, ratio):
    """Interpolate data in the time domain. This is used to compensate for
    the resolution reduction caused by downsampling in a CNN.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, ratio to interpolate

    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled


def pad_framewise_output(framewise_output, frames_num):
    """Pad framewise_output to the same length as the input frames. The pad
    value is the value of the last frame.

    Args:
        framewise_output: (batch_size, frames_num, classes_num)
        frames_num: int, number of frames to pad to

    Returns:
        output: (batch_size, frames_num, classes_num)
    """
    # Repeat the last frame to fill the remaining time steps
    pad = framewise_output[:, -1:, :].repeat(
        1, frames_num - framewise_output.shape[1], 1)
    # (batch_size, frames_num, classes_num)
    output = torch.cat((framewise_output, pad), dim=1)
    return output


class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise ValueError(f'Incorrect pool_type: {pool_type}')

        return x


class LinearSoftmax(nn.Module):
    """Linear softmax pooling over the time dimension."""

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, time_decision):
        return (time_decision ** 2).sum(self.pooldim) / time_decision.sum(
            self.pooldim)

class Cnn8RnnConfig(PretrainedConfig):

    def __init__(self,
                 classes_num: int = 447,
                 **kwargs):
        self.classes_num = classes_num
        super().__init__(**kwargs)


class Cnn8RnnSoundEventDetection(PreTrainedModel):

    config_class = Cnn8RnnConfig

    def __init__(self, config: Cnn8RnnConfig):
        super().__init__(config)
        self.config = config
        self.time_resolution = 0.01
        self.interpolate_ratio = 4  # Downsampled ratio

        # Log-mel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=32000,
            n_fft=1024,
            win_length=1024,
            hop_length=320,
            f_min=50,
            f_max=14000,
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True)
        self.fc_audioset = nn.Linear(512, config.classes_num, bias=True)
        self.temporal_pooling = LinearSoftmax()

        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, waveform):
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)          # (batch_size, 1, time_steps, mel_bins)

        frames_num = x.shape[2]

        # Batch-normalize along the mel-bin dimension
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)  # (batch_size, 512, time_steps / 4, mel_bins / 16)
        x = torch.mean(x, dim=3)

        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        x, _ = self.rnn(x)
        segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.)
        clipwise_output = self.temporal_pooling(segmentwise_output)

        # Upsample segment-wise predictions back to frame resolution
        framewise_output = interpolate(segmentwise_output,
                                       self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        output_dict = {
            'framewise_output': framewise_output,
            'clipwise_output': clipwise_output
        }

        return output_dict

    def save_pretrained(self, save_directory, *args, **kwargs):
        super().save_pretrained(save_directory, *args, **kwargs)
        # Also store the class label list next to the weights
        with open(os.path.join(save_directory, "classes.txt"), "w") as f:
            for class_name in self.classes:
                f.write(class_name + "\n")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        model = super().from_pretrained(pretrained_model_name_or_path,
                                        *model_args, **kwargs)
        # Read the class label list shipped with the checkpoint
        class_file = cached_file(pretrained_model_name_or_path, "classes.txt")
        with open(class_file, "r") as f:
            model.classes = [line.strip() for line in f]
        return model
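For reference, a minimal inference sketch against the uploaded module. The repository id and audio path are placeholders; the model expects 32 kHz mono input per the MelSpectrogram settings above, and model.classes is only populated if a classes.txt file is present in the repository (it is written by the custom save_pretrained):

    import torch
    import torchaudio
    from transformers import AutoModel

    model = AutoModel.from_pretrained("wsntxxn/<this-repo>",  # placeholder repo id
                                      trust_remote_code=True)
    model.eval()

    waveform, sr = torchaudio.load("example.wav")                   # placeholder audio path
    waveform = waveform.mean(dim=0, keepdim=True)                   # mono, shape (1, num_samples)
    waveform = torchaudio.functional.resample(waveform, sr, 32000)  # extractor assumes 32 kHz

    with torch.no_grad():
        output = model(waveform)

    clip_probs = output["clipwise_output"][0]    # (classes_num,)
    frame_probs = output["framewise_output"][0]  # (time_steps, classes_num), ~10 ms per frame

    # Top-5 clip-level predictions, using the classes list read in from_pretrained
    top = clip_probs.topk(5)
    for prob, idx in zip(top.values.tolist(), top.indices.tolist()):
        print(f"{model.classes[idx]}: {prob:.3f}")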