Facebook's Wav2Vec2 base model, pretrained on the 10K unlabeled subset of the VoxPopuli corpus and fine-tuned on the transcribed English (en) data (see Table 1 of the paper for more information).
Authors: Changhan Wang, Morgane Riviere, Ann Lee, Anne Wu, Chaitanya Talnikar, Daniel Haziza, Mary Williamson, Juan Pino, Emmanuel Dupoux from Facebook AI
See the official website for more information.
The following example shows how to run inference with the model on a sample of the Common Voice dataset:
#!/usr/bin/env python3
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
import torchaudio
import torch
# load model & processor
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-10k-voxpopuli-ft-en")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-10k-voxpopuli-ft-en")
# load dataset
ds = load_dataset("common_voice", "en", split="validation[:1%]")
# common voice does not match target sampling rate
common_voice_sample_rate = 48000
target_sample_rate = 16000
resampler = torchaudio.transforms.Resample(common_voice_sample_rate, target_sample_rate)
# define mapping fn to read in sound file and resample
def map_to_array(batch):
    speech, _ = torchaudio.load(batch["path"])
    speech = resampler(speech)
    # keep the first channel only
    batch["speech"] = speech[0]
    return batch
# load all audio files
ds = ds.map(map_to_array)
# run inference on the first 5 data samples
inputs = processor(ds[:5]["speech"], sampling_rate=target_sample_rate, return_tensors="pt", padding=True)
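# inference (gradients are not needed)
with torch.no_grad():
    logits = model(**inputs).logits
# greedy decoding: pick the most likely token id for each frame
predicted_ids = torch.argmax(logits, dim=-1)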
print(processor.batch_decode(predicted_ids))
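The last two steps above (taking the `argmax` per frame, then calling `batch_decode`) amount to greedy CTC decoding: consecutive repeated ids are collapsed and blank tokens are dropped. The following is a minimal pure-Python sketch of that idea, not the processor's actual implementation. The function name `ctc_greedy_decode`, the toy vocabulary, and the choice of `0` as the blank id are assumptions made for illustration; the real tokenizer has its own vocabulary and also handles word delimiters and other special tokens.

```python
from itertools import groupby


def ctc_greedy_decode(ids, vocab, blank_id=0):
    """Collapse consecutive repeated ids, drop blanks, and map ids to characters."""
    # groupby yields one key per run of identical consecutive ids
    collapsed = [key for key, _ in groupby(ids)]
    return "".join(vocab[i] for i in collapsed if i != blank_id)


# Hypothetical toy vocabulary; the model's real vocabulary differs.
toy_vocab = {0: "<pad>", 1: "h", 2: "e", 3: "l", 4: "o"}

# One id per audio frame, as produced by an argmax over the logits.
frame_ids = [1, 1, 0, 2, 2, 3, 0, 3, 3, 4, 0]

print(ctc_greedy_decode(frame_ids, toy_vocab))  # hello
```

Note that the blank between the two runs of `3` is what allows the double "l" to survive the collapsing step; without it, the repeated ids would merge into a single character.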