| import base64 |
| import io |
| import json |
| import logging |
| import os |
| from pathlib import Path |
| from typing import Literal, Union |
|
|
| from faster_whisper import WhisperModel |
| from pydantic import BaseModel, Field, ValidationError, model_validator |
|
|
| from .base import BasicMessageReceiver, BasicMessageSender |
|
|
|
|
def setup_logger():
    """Configure root logging at INFO level and return this module's logger.

    Returns:
        The ``logging.Logger`` registered under this module's ``__name__``.
    """
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)


# Module-wide logger used by the definitions below.
logger = setup_logger()
|
|
|
|
# Request payload for a single transcription job.
# (Documented with comments rather than a class docstring so the generated
# pydantic schema description stays untouched.)
class TranscribeInputMessage(BaseModel):
    # Caller-supplied id; echoed back on the reply so it can be correlated.
    uuid: str = Field(title='Request Unique Id.')
    # Either a filesystem path to the audio, or the base64-encoded audio
    # content itself when ``using_file_content`` is true.
    audio_file: str
    # Languages accepted for transcription.
    language: Literal['en', 'zh']
    # True when ``audio_file`` carries the content instead of a path.
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure a path-style ``audio_file`` points at something that exists.

        Inline content (``using_file_content`` true) is not checked here.

        Returns:
            The validated model instance.

        Raises:
            FileNotFoundError: ``audio_file`` is a path that does not exist.
        """
        if self.using_file_content:
            return self

        if not Path(self.audio_file).exists():
            # NOTE(review): pydantic only turns ValueError/AssertionError into
            # ValidationError, so this surfaces as a plain FileNotFoundError.
            # The exception type is kept for existing callers; the message now
            # names the missing path so the failure reply is actionable.
            raise FileNotFoundError(f'Audio file does not exist: {self.audio_file}')
        return self
|
|
|
|
# Reply payload published after a transcription attempt.
class TranscribeOutputMessage(BaseModel):
    # Id of the request this reply answers.
    uuid: str
    # Whether the transcription completed.
    if_success: bool
    # Human-readable status or error description.
    msg: str
    # Transcription result; stays empty on failure replies.
    transcribed_text: str = ''
|
|
|
|
class TranscribeConsumer(BasicMessageReceiver):
    """Consume transcription requests, run Whisper on them and publish replies.

    Requests arrive on the ``transcribe-input`` queue; replies are published
    with the ``transcribe-output`` routing key. Both use the ``transcribe``
    exchange.
    """

    def __init__(self):
        """Declare exchange/queues and load the Whisper model on CUDA (float16)."""
        super().__init__()

        self.exchange_name = 'transcribe'
        self.input_queue_name = 'transcribe-input'
        self.input_routing_key = 'transcribe-input'
        self.output_queue_name = 'transcribe-output'
        self.output_routing_key = 'transcribe-output'

        self.setup_consume_parameters()
        self.setup_producer_parameters()

        logger.info('Loading model...')
        # 'WHISPER-MODEL-SIZE' is the historical name and still takes
        # precedence when set. A hyphenated name cannot be exported from a
        # POSIX shell, so 'WHISPER_MODEL_SIZE' is accepted as a fallback.
        model_size = os.environ.get(
            'WHISPER-MODEL-SIZE',
            os.environ.get('WHISPER_MODEL_SIZE', 'large-v3'),
        )
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
        logger.info('Load model finished.')

    def setup_consume_parameters(self):
        """Declare the exchange and input queue, then bind them together."""
        logger.info(
            f'Create consumer exchange: {self.exchange_name}, '
            f'routing-key: {self.input_routing_key}, '
            f'queue: {self.input_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        # NOTE(review): the meaning of max_priority=-1 is defined by
        # declare_queue in the base class, which is not visible here.
        self.declare_queue(self.input_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.input_queue_name, self.input_routing_key)

    def setup_producer_parameters(self):
        """Declare the exchange and output queue, then bind them together."""
        logger.info(
            f'Create producer exchange: {self.exchange_name}, '
            f'routing-key: {self.output_routing_key}, '
            f'queue: {self.output_queue_name}'
        )
        self.declare_exchange(self.exchange_name)
        self.declare_queue(self.output_queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.output_queue_name, self.output_routing_key)

    def send_message(self, message: Union[dict, str]):
        """Publish ``message`` with the output routing key and log what was sent.

        Args:
            message: Reply body handed to ``BasicMessageSender.send_message``.
        """
        # Reuse the configured key instead of a second hard-coded copy so the
        # publisher cannot drift from the binding made in
        # setup_producer_parameters().
        routing_key = self.output_routing_key

        # NOTE(review): a new sender is created per message, as before;
        # whether it needs explicit closing depends on BasicMessageSender.
        sender = BasicMessageSender()
        sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=None
        )
        logger.info(f'{"-" * 80}')
        logger.info(f"Send message to Exchange: {self.exchange_name}, Routing-key: {routing_key}, \n"
                    f"Messgae body: {message}")
        logger.info(f'{"-" * 80}')

    def send_success_message(self, uuid: str, transcribed_text):
        """Publish a success reply carrying the transcription for ``uuid``."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
                                          transcribed_text=transcribed_text)
        self.send_message(message.model_dump_json())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure reply for ``uuid`` with ``error`` as the message."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(message.model_dump_json())

    @staticmethod
    def _extract_uuid(body) -> str:
        """Return the request id from a decoded body as a string, or '' if absent.

        Failure replies require a ``str`` uuid. The decoded body may be a
        non-dict JSON value, or may lack or mistype ``uuid`` (often the very
        reason validation failed), so this never raises.
        """
        if not isinstance(body, dict):
            return ''
        request_id = body.get('uuid')
        if request_id is None:
            return ''
        return request_id if isinstance(request_id, str) else str(request_id)

    def consume(self, channel, method, properties, message):
        """Handle one incoming request: decode, validate, transcribe and reply.

        A failure reply is published when decoding, validation or
        transcription fails; otherwise a success reply with the joined
        segment texts is published.

        Args:
            channel: Unused here; part of the consumer callback signature.
            method: Unused here; part of the consumer callback signature.
            properties: Unused here; part of the consumer callback signature.
            message: Raw message handed to ``self.decode_message``.
        """
        logger.info(f'Recevied a message: {message}')
        try:
            body = self.decode_message(message)
        except json.JSONDecodeError as exc:
            logger.exception('Message decode failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            # No request id can be recovered from an undecodable message.
            self.send_fail_message('', f'Message decode failed, message: \n {message}')
            return

        try:
            validated_message = TranscribeInputMessage.model_validate(body)

            audio_file = validated_message.audio_file
            if validated_message.using_file_content:
                # Inline content arrives base64-encoded; hand the model a file-like object.
                audio_file = io.BytesIO(base64.b64decode(validated_message.audio_file))

            logger.info('Start transcribe input...')
            segments, _ = self.asr_model.transcribe(audio_file, language=validated_message.language)

            # Consume the segments inside the try block so that errors raised
            # while iterating them are reported as a failure reply as well.
            transcribed_text = ', '.join(segment.text for segment in segments)
            logger.info(f'Transcribed text: {transcribed_text}')
        except ValidationError as exc:
            logger.exception('Message validated failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(self._extract_uuid(body), f'{exc}')
        except Exception as exc:
            # Boundary handler: any other failure is logged and reported back
            # to the requester instead of escaping the consumer callback.
            logger.exception('Consume message failed: \n message:\n %s\n\n exception info:\n %s', message, exc)
            self.send_fail_message(self._extract_uuid(body), f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)
|
|