| """ |
| Initialise a student Whisper model from a pre-trained teacher model for |
| teacher-student distillation. |
| """ |

import argparse
import copy
import logging

import jax
import numpy as np
from flax.core import freeze, unfreeze
from transformers import GenerationConfig, WhisperFeatureExtractor, WhisperProcessor

from distil_whisper import FlaxWhisperForConditionalGeneration


logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
    )
    parser.add_argument(
        "--teacher_checkpoint",
        type=str,
        required=True,
        help="The HF Hub ID of the teacher checkpoint.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default="",
        help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
        "can specify the folder name here.",
    )
    parser.add_argument(
        "--encoder_layers",
        type=int,
        default=None,
        help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
    )
    parser.add_argument(
        "--decoder_layers",
        type=int,
        default=2,
        help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
    )
    parser.add_argument(
        "--max_source_positions",
        type=int,
        default=None,
        help="The maximum sequence length of log-mel filter-bank features that this model might ever be used with. Can "
        "be used to create a student model with a shorter context length than the teacher model. Defaults to the number "
        "of source positions in the teacher model (1500).",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Where to save the student weights and processor.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the student weights and processor to the Hub.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where to store the pretrained models downloaded from huggingface.co.",
    )

    args = parser.parse_args()
    return args


def init_student_model_from_teacher(
    teacher_checkpoint,
    encoder_layers=None,
    decoder_layers=2,
    max_source_positions=None,
    save_dir=None,
    push_to_hub=False,
    cache_dir=None,
    subfolder="",
):
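    # _do_init=False returns the model and its parameters separately as a
    # (model, params) tuple, skipping the (unused) random weight initialisation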
    teacher_model, teacher_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        teacher_checkpoint,
        _do_init=False,
        cache_dir=cache_dir,
        subfolder=subfolder,
    )
    processor = WhisperProcessor.from_pretrained(teacher_checkpoint)
    generation_config = GenerationConfig.from_pretrained(teacher_checkpoint)

    teacher_config = teacher_model.config
    teacher_encoder_layers = teacher_config.encoder_layers
    teacher_decoder_layers = teacher_config.decoder_layers

    student_config = copy.deepcopy(teacher_config)
    student_config.update(
        {
            "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
            "decoder_layers": decoder_layers,
            "max_source_positions": (
                max_source_positions if max_source_positions is not None else student_config.max_source_positions
            ),
        }
    )

    encoder_mapping = np.linspace(0, teacher_encoder_layers - 1, student_config.encoder_layers, dtype=int)
    encoder_mapping[-1] = teacher_encoder_layers - 1

    encoder_map = {}
    for student_layer, teacher_layer in enumerate(encoder_mapping):
        encoder_map[str(teacher_layer)] = str(student_layer)

    decoder_mapping = np.linspace(0, teacher_decoder_layers - 1, student_config.decoder_layers, dtype=int)
    decoder_mapping[-1] = teacher_decoder_layers - 1

    decoder_map = {}
    for student_layer, teacher_layer in enumerate(decoder_mapping):
        decoder_map[str(teacher_layer)] = str(student_layer)
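    # the linspace endpoints are inclusive, so the student always keeps the first
    # and last teacher layers; e.g. a 32-layer teacher with a 2-layer student
    # decoder gives np.linspace(0, 31, 2, dtype=int) = [0, 31], i.e. the student
    # decoder is initialised from the first and last teacher decoder layers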

    # initialise the student params from the teacher model
    student_params = unfreeze(teacher_params)
    student_params["model"]["decoder"]["layers"] = {}

    for layer in teacher_params["model"]["decoder"]["layers"]:
        if layer in decoder_map:
            # copy the teacher layer weights across, re-indexed to the student layer id
            student_params["model"]["decoder"]["layers"][decoder_map[layer]] = teacher_params["model"]["decoder"][
                "layers"
            ][layer]

    if encoder_layers is not None:
        student_params["model"]["encoder"]["layers"] = {}
        for layer in teacher_params["model"]["encoder"]["layers"]:
            if layer in encoder_map:
                # copy the teacher layer weights across, re-indexed to the student layer id
                student_params["model"]["encoder"]["layers"][encoder_map[layer]] = teacher_params["model"]["encoder"][
                    "layers"
                ][layer]

    if max_source_positions is not None:
        # slice the teacher's positional embeddings down to the student's context length
        student_params["model"]["encoder"]["embed_positions"]["embedding"] = teacher_params["model"]["encoder"][
            "embed_positions"
        ]["embedding"][: student_config.max_source_positions, :]
        # update the feature extractor to match the new context length (chunk_length is in seconds)
        chunk_length = int(student_config.max_source_positions * 2 / 100)
        processor.feature_extractor = WhisperFeatureExtractor(chunk_length=chunk_length)

    # delete the teacher params and model to free up memory
    del teacher_params, teacher_model

    # save the converted params and model
    student_params = freeze(student_params)
    student_model = FlaxWhisperForConditionalGeneration(student_config, _do_init=False)

    if save_dir is not None:
        student_model.save_pretrained(save_dir, params=student_params)
        # we also need to save the processor and generation config
        processor.save_pretrained(save_dir)
        generation_config.save_pretrained(save_dir)
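        # save_dir now contains the student weights (flax_model.msgpack) alongside
        # the config, processor files and generation config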

    # check we can do a forward pass with the saved model: first, load the weights and processor
    logger.info("Checking we can load the saved model...")
    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        save_dir,
        _do_init=False,
    )
    processor = WhisperProcessor.from_pretrained(save_dir)

    # define some dummy inputs: one second of ones at 16 kHz, which the feature
    # extractor pads to the model's chunk length
    input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="np").input_features
    decoder_start_token_id = student_model.config.decoder_start_token_id
    decoder_input_ids = np.ones((input_features.shape[0], 1)) * decoder_start_token_id

    # do a forward pass: the outputs of the freshly initialised model are not
    # meaningful, so we only check that the pass runs without error
    logger.info("Checking we can run the converted model forward...")
    _ = student_model(input_features, decoder_input_ids=decoder_input_ids, params=student_params).logits
    logger.info("Conversion successful!")

    if push_to_hub:
        student_model.push_to_hub(save_dir, params=student_params)
        processor.push_to_hub(save_dir)
        generation_config.push_to_hub(save_dir)


if __name__ == "__main__":
    args = parse_args()

    # set the verbosity to INFO: we only want one process per machine to log to the screen
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.INFO)
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)

    init_student_model_from_teacher(
        teacher_checkpoint=args.teacher_checkpoint,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        max_source_positions=args.max_source_positions,
        save_dir=args.save_dir,
        push_to_hub=args.push_to_hub,
        cache_dir=args.cache_dir,
        subfolder=args.subfolder,
    )