| import argparse |
|
|
| from transformers import pipeline |
| from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
| from datasets import load_dataset, Audio |
| import evaluate |
|
|
# Word Error Rate metric, loaded once at module import and reused for the final score.
wer_metric = evaluate.load("wer")
|
|
|
|
def is_target_text_in_range(ref):
    """Return True when *ref* is a scorable transcript.

    A transcript is kept unless it is empty/whitespace-only or equals the
    special "ignore time segment in scoring" marker used by some datasets.
    """
    text = ref.strip()
    return text != "" and text != "ignore time segment in scoring"
|
|
|
|
def get_text(sample):
    """Return the reference transcript from *sample*.

    Datasets name the transcript column differently; the first matching
    column (in priority order) wins.

    Args:
        sample: a dataset row (mapping from column name to value).

    Returns:
        The transcript string stored under the first recognised column.

    Raises:
        ValueError: if none of the recognised transcript columns is present.
    """
    # Priority order preserved from the original if/elif chain.
    for column in ("text", "sentence", "normalized_text", "transcript", "transcription"):
        if column in sample:
            return sample[column]
    # The original message had a broken f-string ("".join{...}" was a plain
    # literal) so the available columns were never shown; fixed here.
    raise ValueError(
        "Expected transcript column of either 'text', 'sentence', 'normalized_text', "
        f"'transcript' or 'transcription'. Got sample with columns: {', '.join(sample.keys())}. "
        "Ensure a text column name is present in the dataset."
    )
|
|
|
|
# Shared Whisper text normaliser, applied to both references and predictions
# so WER is computed on comparable text.
whisper_norm = BasicTextNormalizer()
|
|
|
|
def normalise(batch):
    """Attach a Whisper-normalised copy of the transcript under ``norm_text``."""
    raw_text = get_text(batch)
    batch["norm_text"] = whisper_norm(raw_text)
    return batch
|
|
|
|
def data(dataset):
    """Yield pipeline-ready samples: the raw audio fields plus the normalised reference.

    Each yielded dict merges the sample's ``audio`` mapping with a
    ``reference`` key holding ``norm_text``, which the ASR pipeline passes
    through untouched so predictions and references stay aligned.
    """
    # The original used enumerate() but never used the index.
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}
|
|
|
|
def main(args):
    """Evaluate a Whisper checkpoint on an ASR dataset and push the WER to the Hub.

    Loads the model as an ASR pipeline, streams (or downloads) the dataset,
    normalises references, transcribes every sample, computes WER on
    Whisper-normalised text, prints it, and reports it via
    ``evaluate.push_to_hub``.
    """
    batch_size = args.batch_size
    whisper_asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    # Pin the decoder to the requested language and to transcription (not
    # translation) instead of letting Whisper auto-detect from the audio.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    dataset = load_dataset(
        args.dataset,
        args.config,
        split=args.split,
        streaming=args.streaming,
        use_auth_token=True,
    )

    # Only subsample when a cap was actually requested: the original called
    # take(None) unconditionally, which fails when --max_eval_samples is
    # omitted. NOTE(review): `take` is an IterableDataset method — assumes
    # streaming when a cap is set; confirm for --streaming False.
    if args.max_eval_samples is not None:
        dataset = dataset.take(args.max_eval_samples)

    # Whisper models expect 16 kHz audio; resample lazily on access.
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # Extra keys fed into the pipeline are passed through wrapped in a list,
    # hence the [0] to unwrap the reference string. Predictions are run
    # through the same normaliser as the references before scoring.
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(whisper_norm(out["text"]))
        references.append(out["reference"][0])

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)  # report as a percentage with 2 decimals

    print("WER:", wer)
    evaluate.push_to_hub(
        model_id=args.model_id,
        metric_value=wer,
        metric_type="wer",
        metric_name="WER",
        dataset_name=args.dataset,
        dataset_type=args.dataset,
        dataset_split=args.split,
        dataset_config=args.config,
        task_type="automatic-speech-recognition",
        task_name="Automatic Speech Recognition"
    )
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="mozilla-foundation/common_voice_11_0",
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'test'`",
    )

    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--streaming",
        # `type=bool` is broken for CLI flags: bool("False") is True, so the
        # original could never be disabled from the command line. Parse the
        # string explicitly instead; the default stays True.
        type=lambda s: s.lower() in ("1", "true", "yes", "y"),
        default=True,
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Two letter language code for the transcription language, e.g. use 'en' for English.",
    )
    args = parser.parse_args()

    main(args)