| """ Fine-tuning a 🤗 Transformers Whisper model for automatic speech recognition""" |
|
|
import logging
import os
from dataclasses import dataclass, field
from pprint import pprint
from typing import Any, Dict, List, Optional, Union

import evaluate
import torch
from datasets import Audio, load_dataset

from transformers import (
    HfArgumentParser,
    TrainingArguments,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    Seq2SeqTrainer,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint


logger = logging.getLogger(__name__)


def list_field(default=None, metadata=None):
    # Dataclass fields with mutable defaults (lists) need a default_factory.
    return field(default_factory=lambda: default, metadata=metadata)


# Defined locally (instead of importing transformers.Seq2SeqTrainingArguments)
# so that an extra `xla` flag can be added to the standard generation options.
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    """
    Args:
        sortish_sampler (`bool`, *optional*, defaults to `False`):
            Whether to use a *sortish sampler* or not. Only possible if the underlying datasets are
            *Seq2SeqDataset* for now but will become generally available in the near future.
            It sorts the inputs according to lengths in order to minimize the padding size, with a bit of
            randomness for the training set.
        predict_with_generate (`bool`, *optional*, defaults to `False`):
            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
        generation_max_length (`int`, *optional*):
            The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default
            to the `max_length` value of the model configuration.
        generation_num_beams (`int`, *optional*):
            The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default
            to the `num_beams` value of the model configuration.
    """

    sortish_sampler: bool = field(
        default=False, metadata={"help": "Whether to use SortishSampler or not."}
    )
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    generation_max_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
                "to the `max_length` value of the model configuration."
            )
        },
    )
    generation_num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
                "to the `num_beams` value of the model configuration."
            )
        },
    )
    xla: bool = field(
        default=False, metadata={"help": "Whether to activate XLA compilation or not."}
    )


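# A minimal invocation sketch (the model and dataset names below are
# placeholders, not defaults of this script):
#
#   python run_whisper_finetuning.py \
#       --model_name_or_path=openai/whisper-small \
#       --language=Norwegian \
#       --task=transcribe \
#       --dataset_name=mozilla-foundation/common_voice_11_0 \
#       --dataset_config_name=nn-NO \
#       --output_dir=./whisper-finetuned \
#       --do_train \
#       --per_device_train_batch_size=16 \
#       --max_steps=5000 \
#       --predict_with_generate

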
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    language: str = field(
        metadata={"help": "Whisper-specific language"}
    )
    task: str = field(
        metadata={"help": "Whisper-specific task, i.e., 'transcribe' or 'translate'"}
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    freeze_feature_encoder: bool = field(
        default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
    )
    attention_dropout: float = field(
        default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
    )
    feat_proj_dropout: float = field(
        default=0.0, metadata={"help": "The dropout ratio for the projected features."}
    )
    hidden_dropout: float = field(
        default=0.0,
        metadata={
            "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
        },
    )
    final_dropout: float = field(
        default=0.0,
        metadata={"help": "The dropout probability for the final projection layer."},
    )
    mask_time_prob: float = field(
        default=0.05,
        metadata={
            "help": "Probability of each feature vector along the time axis to be chosen as the start of the "
            "vector span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` "
            "feature vectors will be masked along the time axis."
        },
    )
    mask_time_length: int = field(
        default=10,
        metadata={"help": "Length of vector span to mask along the time axis."},
    )
    mask_feature_prob: float = field(
        default=0.0,
        metadata={
            "help": "Probability of each feature vector along the feature axis to be chosen as the start of the "
            "vector span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` "
            "feature bins will be masked along the feature axis."
        },
    )
    mask_feature_length: int = field(
        default=10,
        metadata={"help": "Length of vector span to mask along the feature axis."},
    )
    layerdrop: float = field(
        default=0.0, metadata={"help": "The LayerDrop probability."}
    )
    ctc_loss_reduction: Optional[str] = field(
        default="mean", metadata={"help": "The way the CTC loss should be reduced. Should be one of 'mean' or 'sum'."}
    )
    ctc_zero_infinity: Optional[bool] = field(
        default=False, metadata={"help": "If True, will try to avoid the CTC loss going to infinity."}
    )


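# Several of the fields above (dropout probabilities, masking, and CTC
# options) are carried over from wav2vec2-style fine-tuning scripts and are
# not consumed by the Whisper training loop below; they are kept only for
# command-line compatibility.

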
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to
    specify them on the command line.
    """

    dataset_name: str = field(
        metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: str = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training dataset split to use (via the datasets library). Defaults to 'train'."
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={
            "help": "The name of the evaluation dataset split to use (via the datasets library). Defaults to 'test'."
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'."},
    )
    text_column_name: str = field(
        default="sentence",
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'sentence'."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    chars_to_ignore: Optional[List[str]] = list_field(
        default=None,
        metadata={"help": "A list of characters to remove from the transcripts."},
    )
    eval_metrics: List[str] = list_field(
        default=["wer"],
        metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={"help": "Filter out audio files that are longer than `max_duration_in_seconds` seconds."},
    )
    min_duration_in_seconds: float = field(
        default=0.0,
        metadata={"help": "Filter out audio files that are shorter than `min_duration_in_seconds` seconds."},
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": "Whether to only do data preprocessing and skip training. "
            "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
            "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
            "so that the cached datasets can consequently be loaded in distributed training."
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "If `True`, will use the token generated when running "
            "`transformers-cli login` as HTTP bearer authorization for remote files."
        },
    )
    unk_token: str = field(
        default="[UNK]",
        metadata={"help": "The unk token for the tokenizer."},
    )
    pad_token: str = field(
        default="[PAD]",
        metadata={"help": "The padding token for the tokenizer."},
    )
    word_delimiter_token: str = field(
        default="|",
        metadata={"help": "The word delimiter token for the tokenizer."},
    )
    phoneme_language: Optional[str] = field(
        default=None,
        metadata={
            "help": "The target language to be passed to the tokenizer for tokenization. Note that this is only "
            "relevant if the model maps the input audio to a sequence of phonemes."
        },
    )
    print_training_arguments: bool = field(
        default=True,
        metadata={"help": "Print the training arguments. For debugging."},
    )


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels since they have to be of different lengths
        # and need different padding methods.
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")

        # Get the tokenized label sequences and pad them to the max length
        # in the batch.
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")

        # Replace padding with -100 so that the padded positions are ignored
        # by the loss.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        # If a bos token was prepended during tokenization, cut it here since
        # it is appended again during training.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


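# For illustration, a collated batch is expected to look roughly like this
# (the shapes assume the default Whisper feature extractor and a batch of 2):
#
#   batch["input_features"]  # FloatTensor, e.g. (2, 80, 3000) log-Mel frames
#   batch["labels"]          # LongTensor, e.g. (2, max_label_len), padding = -100

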
def main():
    # Parse model, data, and training arguments from the command line.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    def compute_metrics(pred):
        pred_ids = pred.predictions
        label_ids = pred.label_ids

        # Replace -100 with the pad token id so the labels can be decoded.
        label_ids[label_ids == -100] = tokenizer.pad_token_id

        # We do not want to group tokens when computing the metrics.
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        wer = 100 * metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

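    # WER = (substitutions + insertions + deletions) / number of reference
    # words. For example, reference "dette er en test" vs. prediction
    # "dette er test" has one deletion over four reference words, i.e. 25.0
    # on the percentage scale returned above.
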
    def prepare_dataset(batch):
        # Load (and, if needed, resample) the audio data.
        audio = batch["audio"]

        # Compute log-Mel input features from the audio array.
        batch["input_features"] = feature_extractor(
            audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

        # Encode the target text to label ids.
        batch["labels"] = tokenizer(batch["sentence"]).input_ids
        return batch

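    # Note: the Whisper feature extractor pads or truncates every example to
    # 30 seconds of audio before computing the log-Mel spectrogram, so no
    # manual length bucketing is needed at this stage.
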
    def print_training_arguments(model_args, data_args, training_args):
        print("Starting with the following parameters:")
        print("\n* Model arguments:")
        pprint(vars(model_args), indent=2)
        print("\n* Data arguments:")
        pprint(vars(data_args), indent=2)
        print("\n* Training arguments:")
        pprint(vars(training_args), indent=2)

    if data_args.print_training_arguments:
        print_training_arguments(model_args, data_args, training_args)

    # Load the feature extractor, tokenizer, and processor, and build the
    # data collator.
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        model_args.model_name_or_path)
    tokenizer = WhisperTokenizer.from_pretrained(
        model_args.model_name_or_path, language=model_args.language, task=model_args.task)
    processor = WhisperProcessor.from_pretrained(
        model_args.model_name_or_path, language=model_args.language, task=model_args.task)
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Save the processor and tokenizer to the output directory up front so
    # they ship with every checkpoint.
    processor.save_pretrained(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)

    # Load the datasets in streaming mode, using the configured split names.
    train_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
                                 split=data_args.train_split_name, streaming=True,
                                 use_auth_token=data_args.use_auth_token)
    eval_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
                                split=data_args.eval_split_name, streaming=True,
                                use_auth_token=data_args.use_auth_token)

    column_names = list(train_dataset.info.features)

    # Whisper expects 16 kHz audio; resample on the fly if necessary.
    train_dataset = train_dataset.cast_column(data_args.audio_column_name, Audio(sampling_rate=16000))
    eval_dataset = eval_dataset.cast_column(data_args.audio_column_name, Audio(sampling_rate=16000))

    # Normalize the audio and text column names so that prepare_dataset can
    # rely on "audio" and "sentence", keeping column_names in sync with the
    # renames so the originals are dropped after mapping.
    if data_args.audio_column_name != "audio":
        train_dataset = train_dataset.rename_column(
            data_args.audio_column_name, "audio")
        eval_dataset = eval_dataset.rename_column(
            data_args.audio_column_name, "audio")
        column_names = ["audio" if c == data_args.audio_column_name else c
                        for c in column_names]

    if data_args.text_column_name != "sentence":
        train_dataset = train_dataset.rename_column(
            data_args.text_column_name, "sentence")
        eval_dataset = eval_dataset.rename_column(
            data_args.text_column_name, "sentence")
        column_names = ["sentence" if c == data_args.text_column_name else c
                        for c in column_names]

    # Map the raw examples to model inputs, dropping all original columns.
    train_dataset = train_dataset.map(prepare_dataset, remove_columns=column_names)
    eval_dataset = eval_dataset.map(prepare_dataset, remove_columns=column_names)

    metric = evaluate.load("wer")

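    # Because the datasets are loaded with streaming=True they are iterable
    # and have no known length, so training must be bounded with --max_steps
    # rather than --num_train_epochs.
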
    # Detect the last checkpoint in the output directory, if any.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    if training_args.do_train:
        # Prefer a checkpoint found in output_dir; otherwise fall back to a
        # local model directory passed on the command line.
        if last_checkpoint is not None:
            print("*** Found a checkpoint!")
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            print("*** Loading checkpoint from parameters")
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None

        model = WhisperForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path, use_cache=False)

        # Let the model predict language and task tokens freely during
        # training instead of forcing them at decode time.
        model.config.forced_decoder_ids = None
        model.config.suppress_tokens = []

        set_seed(training_args.seed)

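        # Truncate the evaluation stream if requested, then build the
        # trainer. Passing the feature extractor as `tokenizer` makes the
        # Trainer save it alongside each checkpoint.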
        eval_dataset = eval_dataset.with_format("torch")
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.take(data_args.max_eval_samples)

        trainer = Seq2SeqTrainer(
            args=training_args,
            model=model,
            train_dataset=train_dataset.with_format("torch"),
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            tokenizer=processor.feature_extractor,
        )

        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        trainer.save_state()

        # Build the model card / hub metadata.
        config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
        kwargs = {
            "finetuned_from": model_args.model_name_or_path,
            "tasks": "automatic-speech-recognition",
            "tags": ["hf-asr-leaderboard", "automatic-speech-recognition", data_args.dataset_name],
            "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, "
                            f"Eval split: {data_args.eval_split_name}",
            "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
        }

        if training_args.push_to_hub:
            trainer.push_to_hub(**kwargs)
        else:
            trainer.create_model_card(**kwargs)

        return train_result

def _mp_fn(index):
    # Entry point for TPU training via torch_xla's xla_spawn; each spawned
    # process calls main() with its own index.
    print("XLA training is initiated")
    main()


if __name__ == "__main__":
    main()