Wav2Vec2-Large-XLSR-53-Japanese

Fine-tuned facebook/wav2vec2-large-xlsr-53 on Japanese using Common Voice and JSUT, the Japanese speech corpus of Saruwatari-lab, University of Tokyo. When using this model, make sure that your speech input is sampled at 16kHz.

Usage

The model can be used directly (without a language model) as follows:

!pip install mecab-python3
!pip install unidic-lite
!pip install pykakasi
!python -m unidic download
import re

import librosa
import MeCab
import pykakasi
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# config
# MeCab in "wakati" output mode returns the sentence as space-separated tokens.
wakati = MeCab.Tagger("-Owakati")
# Punctuation removed from the reference transcripts before comparison:
# ASCII comma plus the Japanese/full-width marks 、 。 ． 「 」 … ？ ・
chars_to_ignore_regex = r'[,、。．「」…？・]'
# pykakasi converter (legacy setMode API): kanji ("J") and katakana ("K") are
# both rewritten as hiragana ("H"), so references end up in hiragana only.
kakasi = pykakasi.kakasi()
kakasi.setMode("J", "H")
kakasi.setMode("K", "H")
kakasi.setMode("r", "Hepburn")
conv = kakasi.getConverter()

# load data, processor and model
# Only the first 2% of the Common Voice Japanese test split is needed for a demo.
test_dataset = load_dataset("common_voice", "ja", split="test[:2%]")
processor = Wav2Vec2Processor.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese-hiragana")
model = Wav2Vec2ForCTC.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese-hiragana")


def resampler(sr, y):
    """Resample the torch waveform ``y`` from ``sr`` Hz to 16 kHz.

    The tensor is converted to a NumPy array and squeezed before resampling.
    The sample rates are passed by keyword because recent librosa releases
    reject them as positional arguments; the keyword form also works on
    older releases.
    """
    return librosa.resample(y.numpy().squeeze(), orig_sr=sr, target_sr=16_000)
# Preprocessing the datasets.
def speech_file_to_array_fn(batch):
    """Normalize one example's transcript and load its audio at 16 kHz.

    ``batch`` is a single dataset example with ``"sentence"`` and ``"path"``
    entries. The sentence is tokenized with MeCab, passed through the kakasi
    converter, and stripped of the characters in ``chars_to_ignore_regex``.
    The audio file at ``batch["path"]`` is loaded and resampled, and the
    result is stored under ``batch["speech"]``. The same ``batch`` object is
    returned.
    """
    tokenized = wakati.parse(batch["sentence"]).strip()
    converted = conv.do(tokenized)
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', converted).strip()

    waveform, source_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(source_rate, waveform).squeeze()
    return batch


# Apply the preprocessing to every example of the test set.
test_dataset = test_dataset.map(speech_file_to_array_fn)
# Transcribe the first two preprocessed clips and show them next to their
# reference transcripts.
sample_speech = test_dataset["speech"][:2]
inputs = processor(
    sample_speech,
    sampling_rate=16_000,
    return_tensors="pt",
    padding=True,
)

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)
    logits = outputs.logits

# Pick the highest-scoring id along the last logits dimension, then decode.
predicted_ids = logits.argmax(dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])

Evaluation

The model can be evaluated as follows on the Japanese test data of Common Voice.

!pip install mecab-python3
!pip install unidic-lite
!pip install pykakasi
!python -m unidic download
import re

import librosa
import MeCab
import pykakasi
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# config
# MeCab in "wakati" output mode returns the sentence as space-separated tokens.
wakati = MeCab.Tagger("-Owakati")
# Punctuation removed from the reference transcripts before scoring:
# ASCII comma plus the Japanese/full-width marks 、 。 ． 「 」 … ？ ・
chars_to_ignore_regex = r'[,、。．「」…？・]'
# pykakasi converter (legacy setMode API): kanji ("J") and katakana ("K") are
# both rewritten as hiragana ("H"), so references end up in hiragana only.
kakasi = pykakasi.kakasi()
kakasi.setMode("J", "H")
kakasi.setMode("K", "H")
kakasi.setMode("r", "Hepburn")
conv = kakasi.getConverter()

# load data, processor and model
test_dataset = load_dataset("common_voice", "ja", split="test")
wer = load_metric("wer")
cer = load_metric("cer")
processor = Wav2Vec2Processor.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese-hiragana")
model = Wav2Vec2ForCTC.from_pretrained("vumichien/wav2vec2-large-xlsr-japanese-hiragana")
# The evaluation below sends its inputs to "cuda", so the model must live there too.
model.to("cuda")


def resampler(sr, y):
    """Resample the torch waveform ``y`` from ``sr`` Hz to 16 kHz.

    The tensor is converted to a NumPy array and squeezed before resampling.
    The sample rates are passed by keyword because recent librosa releases
    reject them as positional arguments; the keyword form also works on
    older releases.
    """
    return librosa.resample(y.numpy().squeeze(), orig_sr=sr, target_sr=16_000)
# Preprocessing the datasets.
def speech_file_to_array_fn(batch):
    """Normalize one example's transcript and load its audio at 16 kHz.

    ``batch`` is a single dataset example with ``"sentence"`` and ``"path"``
    entries. The sentence is tokenized with MeCab, passed through the kakasi
    converter, and stripped of the characters in ``chars_to_ignore_regex``.
    The audio file at ``batch["path"]`` is loaded and resampled, and the
    result is stored under ``batch["speech"]``. The same ``batch`` object is
    returned.
    """
    tokenized = wakati.parse(batch["sentence"]).strip()
    converted = conv.do(tokenized)
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', converted).strip()

    waveform, source_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(source_rate, waveform).squeeze()
    return batch


# Apply the preprocessing to every example of the test set.
test_dataset = test_dataset.map(speech_file_to_array_fn)
# evaluate function
def evaluate(batch):
    """Transcribe a batch of preprocessed clips and attach the predictions.

    Reads the waveforms in ``batch["speech"]``, runs the model on the
    ``"cuda"`` device without gradient tracking, takes the argmax over the
    last logits dimension, and stores the decoded strings under
    ``batch["pred_strings"]``. The same ``batch`` object is returned.
    """
    device = "cuda"
    features = processor(
        batch["speech"],
        sampling_rate=16_000,
        return_tensors="pt",
        padding=True,
    )
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)

    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    batch["pred_strings"] = processor.batch_decode(logits.argmax(dim=-1))
    return batch
# Run the model over the whole test set, eight examples per call.
result = test_dataset.map(evaluate, batched=True, batch_size=8)
# The metric values are scaled by 100 and reported as percentages with two
# decimal places (".2f" sets the precision; a bare "2f" only sets a minimum
# field width and would print six decimals).
print("WER: {:.2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["pred_strings"], references=result["sentence"])))

Test Result

WER: 24.74%, CER: 10.99%

Training

The Common Voice train and validation splits and the JSUT Japanese speech corpus were used for training.

Downloads last month
581
Safetensors
Model size
0.3B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Dataset used to train vumichien/wav2vec2-large-xlsr-japanese-hiragana

Spaces using vumichien/wav2vec2-large-xlsr-japanese-hiragana 4

Evaluation results