| #!/bin/bash |
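# Launches knowledge distillation of the distil-whisper/large-32-2 student from
# the openai/whisper-large-v2 teacher on an interleaved mixture of the three
# LibriSpeech training splits and GigaSpeech-L, streamed from the Hugging Face
# Hub, across 2 GPUs with bf16 mixed precision.
#
# `+`-separated values pair up per-dataset settings position by position.
# Assuming the distil-whisper convention, --train_dataset_samples lists the
# approximate number of examples in each split (here in units of 10k) and is
# used to set the sampling probabilities when the datasets are interleaved.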
|
|
accelerate launch --multi_gpu --mixed_precision=bf16 --num_processes=2 run_distillation_pt.py \
  --model_name_or_path distil-whisper/large-32-2 \
  --teacher_model_name_or_path openai/whisper-large-v2 \
  --train_dataset_name librispeech_asr+librispeech_asr+librispeech_asr+gigaspeech-l \
  --train_dataset_config_name all+all+all+l \
  --train_split_name train.clean.100+train.clean.360+train.other.500+train \
  --train_dataset_samples 2.9+10.4+14.9+226.6 \
  --eval_dataset_name librispeech_asr+librispeech_asr+gigaspeech-l \
  --eval_dataset_config_name all+all+l \
  --eval_split_name validation.clean+validation.other+validation \
  --eval_text_column_name text+text+text \
  --eval_steps 2500 \
  --save_steps 2500 \
  --warmup_steps 50 \
  --learning_rate 0.0001 \
  --lr_scheduler_type constant_with_warmup \
  --logging_steps 25 \
  --save_total_limit 1 \
  --max_steps 10000 \
  --wer_threshold 10 \
  --per_device_train_batch_size 64 \
  --gradient_accumulation_steps 2 \
  --per_device_eval_batch_size 64 \
  --dataloader_num_workers 16 \
  --cache_dir /fsx/sanchit/cache \
  --dataset_cache_dir /fsx/sanchit/cache \
  --dtype bfloat16 \
  --output_dir ./ \
  --wandb_project distil-whisper-training \
  --do_train \
  --do_eval \
  --gradient_checkpointing \
  --overwrite_output_dir \
  --predict_with_generate \
  --freeze_encoder \
  --streaming
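
# Effective batch size: 64 per device x 2 processes x 2 gradient-accumulation
# steps = 256 examples per optimiser step, so the 10k-step schedule sees
# roughly 2.56M training examples.
#
# Assumed semantics of two of the flags above: --wer_threshold 10 filters out
# training utterances whose transcription WER against the normalised reference
# exceeds 10%, and --freeze_encoder keeps the student's encoder weights fixed
# so that only the decoder is trained.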
|
|