| |
| |
| import copy |
| import time |
|
|
| import torch |
| from datasets import load_dataset |
| from transformers import ( |
| AutoProcessor, |
| WhisperForConditionalGeneration, |
| ) |
|
|
|
|
| DEVICE = "cuda" |
| DTYPE = torch.float16 |
| SAMPLING_RATE = 16_000 |
| BATCH_SIZE = 1 |
| USE_FLASH_ATTN_2 = True |
|
|
| |
| GAMMAS = [5, 7, 6, 5, 4, 3, 5] |
| COUNT = 0 |
|
|
| |
| teacher = WhisperForConditionalGeneration.from_pretrained( |
| "/home/patrick/distil_whisper/", |
| torch_dtype=DTYPE, |
| variant="fp16", |
| low_cpu_mem_usage=True, |
| use_flash_attention_2=USE_FLASH_ATTN_2, |
| ) |
| student = WhisperForConditionalGeneration.from_pretrained( |
| "/home/patrick/distil_whisper_student/", |
| torch_dtype=DTYPE, |
| variant="fp16", |
| low_cpu_mem_usage=True, |
| use_flash_attention_2=USE_FLASH_ATTN_2, |
| ) |
| |
|
|
| student.generation_config = copy.deepcopy(teacher.generation_config) |
| student.generation_config.num_assistant_tokens_schedule = "constant" |
|
|
| |
| |
| |
| |
| |
| |
| |
| teacher.to(DEVICE) |
| student.to(DEVICE) |
|
|
| processor = AutoProcessor.from_pretrained("sanchit-gandhi/large-32-2-gpu-flat-lr") |
|
|
| ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") |
|
|
| total_time_default = 0 |
| total_time_spec = 0 |
| total_time_spec_2 = 0 |
|
|
| input_values = ds[0]["audio"]["array"] |
| inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE) |
| input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE) |
|
|
| _ = teacher.generate(input_features, max_length=100) |
|
|
| end_idx = ds.shape[0] |
| for audio_idx in range(0, end_idx, BATCH_SIZE): |
| input_values = ds[audio_idx : audio_idx + BATCH_SIZE] |
| input_values = [i["array"] for i in input_values["audio"]] |
|
|
| inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE) |
| input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE) |
|
|
| start_time = time.time() |
| out = teacher.generate(input_features, max_length=100) |
| run_time = time.time() - start_time |
| print(f"Normal Decoding: {run_time}") |
| total_time_default += run_time |
|
|
| default_out = processor.batch_decode(out, skip_special_tokens=True) |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| start_time = time.time() |
| with torch.no_grad(): |
| encoder_outputs = teacher.get_encoder()(input_features) |
|
|
| out = teacher.generate( |
| assistant_model=student, |
| assistant_encoder_outputs=encoder_outputs, |
| encoder_outputs=encoder_outputs, |
| max_length=100, |
| ) |
| run_time = time.time() - start_time |
|
|
| spec_out_2 = processor.batch_decode(out, skip_special_tokens=True) |
|
|
| print(f"Speculative Decoding 2: {run_time}") |
| total_time_spec_2 += run_time |
|
|
| if spec_out_2 != default_out: |
| COUNT += 1 |
| print(f"Audio {audio_idx} does not match. Spec: {spec_out_2}, True: {default_out}") |
|
|
|
|
| print(20 * "=") |
| print("Total time", total_time_default) |
| print(f"Overall speed-up spec 2 {total_time_default / total_time_spec_2}") |
| |
|
|