import os
import argparse
import ast
from time import time

import torch
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available

from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES


TRANSCRIPTOR_WHISPER = "openai/whisper-large-v3-turbo"
TRANSCRIPTOR_DISTIL_WHISPER = "distil-whisper/distil-large-v3"
TRANSCRIPTOR = TRANSCRIPTOR_DISTIL_WHISPER
def get_language_dict():
    """Map each language name to its Whisper transcriber code and translator code."""
    language_dict = {}
    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # Whisper expects the base code: the part before any underscore, lowercased
        lang_code = language_code.split('_')[0].lower()
        # Keep only the languages Whisper supports
        if lang_code in WHISPER_LANGUAGES:
            language_dict[language_name] = {
                "transcriber": lang_code,
                "translator": language_code,
            }
    return language_dict
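
# Illustrative shape of the returned mapping (hypothetical entries; the real
# names and codes come from lang_list):
#   {"Spanish": {"transcriber": "es", "translator": "es_XX"}}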


def transcription_to_dict(transcription):
    """
    Convert a transcription in string format into a structured dictionary.

    Args:
        transcription (str or dict): Transcription with timestamps, either as a
            string representation or as the dict returned by the pipeline

    Returns:
        dict: Dictionary with the full text and the timestamped chunks,
            or None if the transcription could not be parsed
    """
    try:
        if isinstance(transcription, str):
            # Parse the string safely; ast.literal_eval only accepts Python
            # literals, unlike eval, which would execute arbitrary code
            transcription_dict = ast.literal_eval(transcription)
        else:
            transcription_dict = transcription

        if not isinstance(transcription_dict, dict):
            raise ValueError("The transcription does not have the expected format")

        if 'text' not in transcription_dict or 'chunks' not in transcription_dict:
            raise ValueError("The transcription is missing required fields (text and chunks)")

        # Keep only chunks with text and a complete (start, end) timestamp pair
        cleaned_chunks = []
        for chunk in transcription_dict['chunks']:
            if (chunk.get('text') and
                    isinstance(chunk.get('timestamp'), (list, tuple)) and
                    len(chunk['timestamp']) == 2 and
                    chunk['timestamp'][0] is not None and
                    chunk['timestamp'][1] is not None):
                cleaned_chunks.append({
                    'start': float(chunk['timestamp'][0]),
                    'end': float(chunk['timestamp'][1]),
                    'text': chunk['text'].strip()
                })

        return {
            'text': transcription_dict['text'],
            'chunks': cleaned_chunks
        }

    except Exception as e:
        print(f"Error processing the transcription: {e}")
        return None
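
# Example round trip (illustrative values), assuming the pipeline's usual
# output shape {'text': ..., 'chunks': [{'timestamp': (start, end), 'text': ...}]}:
#   transcription_to_dict({'text': ' Hello world.',
#                          'chunks': [{'timestamp': (0.0, 1.4), 'text': ' Hello world.'}]})
#   -> {'text': ' Hello world.',
#       'chunks': [{'start': 0.0, 'end': 1.4, 'text': 'Hello world.'}]}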


def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
    """
    Transcribe an audio file using a Whisper model.

    Args:
        audio_file (str): Path to the audio file
        language (str): Language code for transcription
        device (str): Device to use for inference ('cuda' or 'cpu')
        chunk_length_s (int): Length of audio chunks in seconds
        stride_length_s (int): Stride length between chunks in seconds

    Returns:
        tuple: The raw pipeline output and the cleaned transcription dict
    """
    output_folder = "transcriptions"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Destination path for the .srt subtitle file (computed here, not written
    # by this function)
    audio_filename = os.path.basename(audio_file)
    filename_without_ext = os.path.splitext(audio_filename)[0]
    output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")

    device = torch.device(device)
    # Pick the dtype from the requested device, not from CUDA availability:
    # float16 on CPU would be very slow or unsupported
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

    model_id = TRANSCRIPTOR
    t0 = time()

    print(f"Using Flash Attention 2: {is_flash_attn_2_available()}")
    # Use Flash Attention 2 for the whisper-large model when available,
    # otherwise fall back to PyTorch's SDPA implementation
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        attn = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
        model_kwargs = {"attn_implementation": attn}
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            **model_kwargs
        )
    else:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    # Request word-level timestamps for distil-whisper and segment-level
    # timestamps for whisper
    timestamp = "word" if TRANSCRIPTOR == TRANSCRIPTOR_DISTIL_WHISPER else True

    # Build the ASR pipeline; the whisper branch additionally sets a batch size
    # and forwards the attention implementation chosen above
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
            batch_size=24,
            model_kwargs=model_kwargs
        )
    else:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
        )

    # Run inference; both branches use greedy decoding (num_beams=1) with the
    # KV cache enabled
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            batch_size=24,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )
    else:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )

    t = time()
    print(f"Time to transcribe: {t - t0:.2f} seconds")

    # The pipeline returns a dict with 'text' and 'chunks'; normalize it
    transcription_dict = transcription_to_dict(result)

    return result, transcription_dict
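

# transcribe() computes an .srt output path but never writes it. The helpers
# below are a minimal sketch (hypothetical, not part of the original pipeline)
# of how the cleaned chunks could be serialized to that path.

def format_srt_time(seconds):
    """Format seconds as an SRT timestamp, e.g. 3.5 -> '00:00:03,500'."""
    ms = int(round(seconds * 1000))
    hours, ms = divmod(ms, 3_600_000)
    minutes, ms = divmod(ms, 60_000)
    secs, ms = divmod(ms, 1_000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"


def write_srt(chunks, output_file):
    """Write chunks (as produced by transcription_to_dict) as numbered SRT cues."""
    with open(output_file, 'w', encoding='utf-8') as f:
        for i, chunk in enumerate(chunks, start=1):
            start = format_srt_time(chunk['start'])
            end = format_srt_time(chunk['end'])
            f.write(f"{i}\n{start} --> {end}\n{chunk['text']}\n\n")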


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='File listing the input audio files, one per line')
    parser.add_argument('language', help='Language of the audio files')
    parser.add_argument('num_speakers', help='Number of speakers in the audio file')
    parser.add_argument('device', help='Device to use for PyTorch inference')
    args = parser.parse_args()

    chunks_folder = "chunks"

    with open(args.input_files, 'r') as f:
        inputs = f.read().splitlines()

    # Resolve the transcriber code once; all chunks share the same language
    language_dict = get_language_dict()
    language = language_dict[args.language]["transcriber"]

    progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    for input_path in inputs:
        # Locate the chunk by its base name; chunks are stored as .mp3
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        file = f'{chunks_folder}/{input_name}.mp3'
        # Pass the device as the third argument; num_speakers is parsed for
        # interface compatibility but transcribe() does not use it
        transcribe(file, language, args.device)
        progress_bar.update(1)
    progress_bar.close()
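
# Example invocation (script and file names are illustrative):
#   python transcribe.py inputs.txt English 2 cuda
# where inputs.txt lists one audio path per line, e.g. chunks/episode_01.mp3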