maolin.liu committed on
Commit ·
85378a6
1
Parent(s): 2278032
[feature]Support RabbitMQ.
Browse files- cli.py +10 -0
- consumer/__init__.py +0 -0
- consumer/asr.py +104 -0
- consumer/base.py +235 -0
- requirements.txt +1 -0
cli.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from consumer.asr import TranscribeConsumer
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def cli():
    """Create the transcription consumer and start consuming its input queue.

    ``consume_messages`` is handed the consumer's own queue name and its
    ``consume`` method as the per-message callback.
    """
    consumer = TranscribeConsumer()
    consumer.consume_messages(queue=consumer.queue_name, callback=consumer.consume)


if __name__ == '__main__':
    cli()
|
consumer/__init__.py
ADDED
|
File without changes
|
consumer/asr.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import io
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import uuid
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Literal
|
| 8 |
+
|
| 9 |
+
from faster_whisper import WhisperModel
|
| 10 |
+
from pydantic import BaseModel, Field, ValidationError, model_validator
|
| 11 |
+
|
| 12 |
+
from .base import BasicMessageReceiver, BasicMessageSender, Headers, Priority
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TranscribeInputMessage(BaseModel):
    """Schema of a transcription request.

    ``audio_file`` is checked as a filesystem path unless
    ``using_file_content`` is set, in which case no path check is performed.
    """

    uuid: str = Field(title='Request Unique Id.')
    audio_file: str
    language: Literal['en', 'zh',]
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Reject path-style requests whose audio file does not exist.

        Raises:
            FileNotFoundError: ``using_file_content`` is false and
                ``audio_file`` does not name an existing path.
        """
        is_path_request = not self.using_file_content
        if is_path_request and not Path(self.audio_file).exists():
            raise FileNotFoundError('Audio file not exists.')
        return self
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TranscribeOutputMessage(BaseModel):
    """Schema of the result published for a transcription request."""

    # Identifier of the request this result belongs to.
    uuid: str
    # True when transcription finished, False when it failed.
    if_success: bool
    # Human-readable status or error description.
    msg: str
    # Recognised text; stays empty when nothing is supplied.
    transcribed_text: str = ''
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TranscribeConsumer(BasicMessageReceiver):
    """RabbitMQ consumer that transcribes audio with a faster-whisper model.

    Requests are consumed from the ``transcribe-input`` queue bound to the
    ``transcribe`` exchange. For every request a ``TranscribeOutputMessage``
    is published to the same exchange with routing key ``transcribe-output``,
    reporting either the transcribed text or the failure reason.
    """

    # Upper bound on how much of a raw payload is written to the log; inline
    # audio travels as base64 text and can be very large.
    _LOG_PREVIEW_LENGTH = 512

    def __init__(self):
        super().__init__()

        self.exchange_name = 'transcribe'
        self.queue_name = 'transcribe-input'
        self.routing_key = 'transcribe-input'

        self.setup_consume_parameters()
        self.setup_message_sender()

        # 'WHISPER-MODEL-SIZE' is the original variable name and still wins.
        # The underscore spelling is accepted as well because names containing
        # '-' cannot be exported from a POSIX shell.
        model_size = (
            os.environ.get('WHISPER-MODEL-SIZE')
            or os.environ.get('WHISPER_MODEL_SIZE')
            or 'large-v3'
        )
        # Run on GPU with FP16
        self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')

    def setup_consume_parameters(self):
        """Declare the exchange and input queue, then bind them together."""
        self.declare_exchange(self.exchange_name)
        # A non-positive max_priority makes declare_queue omit the
        # "x-max-priority" argument.
        self.declare_queue(self.queue_name, max_priority=-1)
        self.bind_queue(self.exchange_name, self.queue_name, self.routing_key)

    def setup_message_sender(self):
        """Create the sender used for publishing results."""
        self.sender = BasicMessageSender()

    def send_message(self, message: dict):
        """Publish ``message`` to the exchange with key ``transcribe-output``.

        Args:
            message: Payload to publish; encoded by the sender.
        """
        routing_key = 'transcribe-output'
        # The priority is given by name because the Headers validator resolves
        # it with a name lookup on the Priority enum.
        headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL.name)
        self.sender.send_message(
            exchange_name=self.exchange_name,
            routing_key=routing_key,
            body=message,
            headers=headers,
        )

    def send_success_message(self, uuid: str, transcribed_text):
        """Publish a success result carrying ``transcribed_text``."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
                                          transcribed_text=transcribed_text)
        self.send_message(message.model_dump())

    def send_fail_message(self, uuid: str, error: str):
        """Publish a failure result carrying the ``error`` description."""
        message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
        self.send_message(message.model_dump())

    @classmethod
    def _preview(cls, message):
        """Return a bounded prefix of a raw payload for logging."""
        if isinstance(message, (bytes, str)):
            return message[:cls._LOG_PREVIEW_LENGTH]
        return message

    def consume(self, channel, method, properties, message):
        """Handle one delivery: decode, validate, transcribe and report.

        Any failure (undecodable payload, invalid request, missing audio
        file, model error) is logged and reported through
        ``send_fail_message`` instead of escaping from the callback.

        Args:
            channel: Channel the message was delivered on (unused).
            method: Delivery metadata (unused).
            properties: Message properties (unused).
            message: Raw message body.
        """
        body = None
        try:
            # Decoding happens inside the try block so that a malformed
            # payload is reported like any other failure.
            body = self.decode_message(message)
            validated_message = TranscribeInputMessage.model_validate(body)

            audio_file = validated_message.audio_file
            if validated_message.using_file_content:
                audio_file = io.BytesIO(base64.b64decode(validated_message.audio_file))

            segments, _ = self.asr_model.transcribe(audio_file, language=validated_message.language)

            # Concatenate every segment; keeping only the first one would
            # drop the rest of any audio that yields several segments.
            transcribed_text = ''.join(segment.text for segment in segments)
        except Exception as exc:
            logging.exception('Consume message failed: \n message: %s\n\n exception info: %s',
                              self._preview(message), exc)
            # The body may be missing, not a mapping, or carry a non-string
            # uuid; the output schema needs a string, so normalise it here.
            request_uuid = body.get('uuid') if isinstance(body, dict) else None
            self.send_fail_message('' if request_uuid is None else str(request_uuid), f'{exc}')
        else:
            self.send_success_message(validated_message.uuid, transcribed_text)
|
consumer/base.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import functools
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import ssl
|
| 7 |
+
import time
|
| 8 |
+
from enum import Enum
|
| 9 |
+
from typing import Dict, Optional, Literal
|
| 10 |
+
|
| 11 |
+
import msgpack
|
| 12 |
+
import pika
|
| 13 |
+
from pika.exceptions import AMQPConnectionError
|
| 14 |
+
from pydantic import BaseModel, field_validator
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def sync(f):
    """Wrap coroutine function ``f`` so it can be called synchronously.

    The wrapper runs the coroutine to completion on the current thread's
    event loop and returns its result. When the thread has no usable loop
    (none set, or the set one is closed) a fresh loop is created and
    installed, which also makes the wrapper usable from worker threads.

    Args:
        f: Coroutine function to wrap.

    Returns:
        A plain callable with the same signature that returns ``f``'s result.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Raised when the current thread has no event loop.
            loop = None

        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        return loop.run_until_complete(f(*args, **kwargs))

    return wrapper
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Priority(Enum):
    """Message priority levels; a larger value means a higher priority level."""

    LOW = 1
    NORMAL = 5
    HIGH = 10
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Headers(BaseModel):
    """Metadata attached to a published message."""

    job_id: str
    priority: Priority
    task_type: Optional[str] = None

    @field_validator('priority', mode='before')
    @classmethod
    def _convert_priority(cls, value):
        """Normalise the raw ``priority`` input before enum validation.

        Accepts a ``Priority`` member, a member name such as ``'NORMAL'``,
        or any other value, which is passed through for the regular enum
        validation of the field.

        Raises:
            ValueError: ``value`` is a string that is not a member name.
                A ``ValueError`` is used so that pydantic reports it as a
                validation error instead of letting a bare ``KeyError``
                escape.
        """
        if isinstance(value, Priority):
            return value
        if isinstance(value, str):
            try:
                return Priority[value]
            except KeyError:
                raise ValueError(f'Unknown priority name: {value!r}') from None
        return value
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RabbitMQConfig(BaseModel):
    """Connection settings for a RabbitMQ broker."""

    # Broker address.
    host: str
    port: int

    # Login credentials.
    username: str
    password: str

    # Connection scheme name.
    protocol: str
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class BasicPikaClient:
    """Blocking pika connection/channel pair with topology helpers.

    Connection settings are read from the ``RABBITMQ_USER``,
    ``RABBITMQ_PASSWD``, ``RABBITMQ_HOST`` and ``RABBITMQ_PORT`` environment
    variables. The connection is opened in the constructor.
    """

    # Retry policy used by ``_connect``.
    CONNECT_MAX_TRIES = 20
    CONNECT_RETRY_DELAY_SECONDS = 5

    def __init__(self):
        self.username = os.environ.get('RABBITMQ_USER', '')
        self.password = os.environ.get('RABBITMQ_PASSWD', '')
        self.host = os.environ.get('RABBITMQ_HOST', 'localhost')
        self.port = os.environ.get('RABBITMQ_PORT', 5672)
        self.protocol = "amqp"

        # Defined up front so check_connection()/close() never hit a missing
        # attribute, even when connecting fails part-way.
        self.connection = None
        self.channel = None

        self._init_connection_parameters()
        self._connect()

    def _connect(self):
        """Open the connection and a channel, retrying on failure.

        Up to ``CONNECT_MAX_TRIES`` attempts are made, waiting
        ``CONNECT_RETRY_DELAY_SECONDS`` between them.

        Raises:
            AMQPConnectionError: Every attempt failed; the last error is
                attached as the cause.
        """
        last_error = None
        for attempt in range(1, self.CONNECT_MAX_TRIES + 1):
            try:
                self.connection = pika.BlockingConnection(self.parameters)
                self.channel = self.connection.channel()
                if self.connection.is_open:
                    return
                # Count a not-open connection as a failed attempt so the loop
                # cannot spin forever without sleeping.
                last_error = AMQPConnectionError('Connection is not open after connecting.')
            except Exception as e:
                # Any failure is retried; it only surfaces once all attempts
                # are used up.
                last_error = e

            logger.warning(
                "Connecting to RabbitMQ failed (attempt %d/%d): %s",
                attempt, self.CONNECT_MAX_TRIES, last_error,
            )
            if attempt < self.CONNECT_MAX_TRIES:
                time.sleep(self.CONNECT_RETRY_DELAY_SECONDS)

        raise AMQPConnectionError(last_error) from last_error

    def _init_connection_parameters(self):
        """Build ``self.parameters`` from the configured host and credentials."""
        connection_kwargs = {
            'host': self.host,
            'port': int(self.port),
            'virtual_host': "/",
        }
        # Credentials are only passed when at least one of them is configured.
        if any([self.username, self.password]):
            self.credentials = pika.PlainCredentials(self.username, self.password)
            connection_kwargs['credentials'] = self.credentials

        self.parameters = pika.ConnectionParameters(**connection_kwargs)

        if self.protocol == "amqps":
            # SSL Context for TLS configuration of Amazon MQ for RabbitMQ
            # NOTE(review): a context built this way does not verify the
            # server certificate by default - confirm that is intended.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
            ssl_context.set_ciphers("ECDHE+AESGCM:!ECDSA")
            self.parameters.ssl_options = pika.SSLOptions(context=ssl_context)

    def check_connection(self):
        """Make sure an open connection and an open channel are available.

        Reconnects when the connection is missing or closed; otherwise opens
        a new channel if the current one is missing or closed.
        """
        connection = getattr(self, 'connection', None)
        if connection is None or connection.is_closed:
            self._connect()
            return

        channel = getattr(self, 'channel', None)
        if channel is None or channel.is_closed:
            self.channel = connection.channel()

    def close(self):
        """Close the channel and the connection if they are open.

        The connection is closed even when closing the channel fails, and
        calling this on an already-closed client does nothing.
        """
        channel = getattr(self, 'channel', None)
        connection = getattr(self, 'connection', None)
        try:
            if channel is not None and channel.is_open:
                channel.close()
        finally:
            if connection is not None and connection.is_open:
                connection.close()

    def declare_queue(
        self, queue_name, exclusive: bool = False, max_priority: int = 10
    ):
        """Declare a durable queue.

        Args:
            queue_name: Name of the queue.
            exclusive: Passed through to the queue declaration.
            max_priority: Value for the ``x-max-priority`` queue argument;
                the argument is omitted when this is not positive.
        """
        self.check_connection()
        logger.debug("Trying to declare queue(%s)...", queue_name)

        arguments = {"x-max-priority": max_priority} if max_priority > 0 else None
        self.channel.queue_declare(
            queue=queue_name,
            exclusive=exclusive,
            durable=True,
            arguments=arguments,
        )

    def declare_exchange(self, exchange_name: str, exchange_type: str = "direct"):
        """Declare an exchange of the given type."""
        self.check_connection()
        self.channel.exchange_declare(
            exchange=exchange_name, exchange_type=exchange_type
        )

    def bind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
        """Bind ``queue_name`` to ``exchange_name`` with ``routing_key``."""
        self.check_connection()
        self.channel.queue_bind(
            exchange=exchange_name, queue=queue_name, routing_key=routing_key
        )

    def unbind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
        """Remove the binding between ``queue_name`` and ``exchange_name``."""
        # Same connection guard as the other topology helpers.
        self.check_connection()
        self.channel.queue_unbind(
            queue=queue_name, exchange=exchange_name, routing_key=routing_key
        )
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class BasicMessageSender(BasicPikaClient):
    """Publishes encoded messages to an exchange."""

    # Encoding applied by send_message(); see encode_message().
    message_encoding_type: Literal['bytes', 'json'] = 'json'

    def encode_message(self, body: Dict, encoding_type: str = "bytes"):
        """Serialise ``body`` for publishing.

        Args:
            body: Payload to encode.
            encoding_type: ``"bytes"`` for msgpack, ``"json"`` for JSON text.

        Returns:
            The msgpack bytes or the JSON string.

        Raises:
            NotImplementedError: ``encoding_type`` is not supported.
        """
        if encoding_type == "bytes":
            return msgpack.packb(body)
        if encoding_type == "json":
            return json.dumps(body)
        raise NotImplementedError

    def send_message(
        self,
        exchange_name: str,
        routing_key: str,
        body: Dict,
        headers: Optional[Headers] = None,
    ):
        """Publish ``body`` as a persistent message.

        Args:
            exchange_name: Exchange to publish to.
            routing_key: Routing key for the message.
            body: Payload; encoded with ``message_encoding_type``.
            headers: Optional metadata. When given, its priority becomes the
                message priority and its fields become the message headers.
        """
        self.check_connection()
        body = self.encode_message(body=body, encoding_type=self.message_encoding_type)

        if headers is not None:
            priority = headers.priority.value
            # mode='json' turns the Priority member into its plain value so
            # the headers hold only simple types.
            header_table = headers.model_dump(mode='json')
        else:
            priority = None
            header_table = None

        properties = pika.BasicProperties(delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE,
                                          priority=priority,
                                          headers=header_table)
        self.channel.basic_publish(
            exchange=exchange_name,
            routing_key=routing_key,
            body=body,
            properties=properties,
        )
        logger.debug(
            "Sent message. Exchange: %s, Routing Key: %s, Body: %s",
            exchange_name, routing_key, body[:128],
        )
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class BasicMessageReceiver(BasicPikaClient):
    """Receives messages by polling or by registering a consumer callback."""

    def __init__(self):
        super().__init__()
        # Consumer tag returned by basic_consume; None while not consuming.
        self.channel_tag = None

    def decode_message(self, body):
        """Deserialise a message body.

        A ``bytes`` body is tried as msgpack first and, if that fails, as
        JSON, so that messages written with ``BasicMessageSender``'s default
        ``'json'`` encoding can be read as well. A ``str`` body is parsed as
        JSON.

        Raises:
            ValueError: The body could not be decoded. For ``bytes`` input
                the msgpack error is the one that is raised.
            NotImplementedError: ``body`` is neither ``bytes`` nor ``str``.
        """
        if isinstance(body, str):
            return json.loads(body)

        if isinstance(body, bytes):
            try:
                return msgpack.unpackb(body)
            except ValueError as unpack_error:
                try:
                    return json.loads(body)
                except ValueError:
                    # Neither format matched; report the msgpack failure as
                    # before the JSON fallback existed.
                    raise unpack_error from None

        raise NotImplementedError

    def get_message(self, queue_name: str, auto_ack: bool = False):
        """Fetch a single message from ``queue_name`` without blocking.

        Returns:
            ``(method_frame, header_frame, body)`` when a message was
            available, otherwise ``None``.
        """
        self.check_connection()
        method_frame, header_frame, body = self.channel.basic_get(
            queue=queue_name, auto_ack=auto_ack
        )
        if not method_frame:
            logger.debug("No message returned")
            return None

        logger.debug("%s, %s, %s", method_frame, header_frame, body)
        return method_frame, header_frame, body

    def consume_messages(self, queue, callback):
        """Register ``callback`` on ``queue`` and block while consuming.

        Args:
            queue: Name of the queue to consume from.
            callback: Called as ``callback(channel, method, properties, body)``.
        """
        self.check_connection()
        # NOTE(review): auto_ack=True means a message is not redelivered when
        # the callback fails - confirm this at-most-once behaviour is wanted.
        self.channel_tag = self.channel.basic_consume(
            queue=queue, on_message_callback=callback, auto_ack=True
        )
        logger.debug(" [*] Waiting for messages. To exit press CTRL+C")
        self.channel.start_consuming()

    def cancel_consumer(self):
        """Cancel the consumer registered by ``consume_messages``, if any."""
        if self.channel_tag is None:
            logger.error("Do not cancel a non-existing job")
            return

        self.channel.basic_cancel(self.channel_tag)
        self.channel_tag = None
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class ExtractFaceFramesConsumer(BasicMessageReceiver):
    """Skeleton consumer whose ``async`` callback is driven through ``sync``."""

    @sync
    async def consume(self, channel, method, properties, body):
        """Decode the delivery and download the image named by its ``url`` key."""
        payload = self.decode_message(body=body)
        image_url = payload["url"]
        file_content = await self._download_image(img_url=image_url)
        # consume message logic ...

    async def _download_image(self, img_url):
        """Placeholder for the asynchronous download; returns ``None``."""
        # do some async stuff here
        return None
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def create_consumer():
    """Declare the example queue and exchange, bind them, and start consuming."""
    queue_name = "myqueue"
    exchange_name = "myexchange"

    worker = ExtractFaceFramesConsumer()
    worker.declare_queue(queue_name=queue_name)
    worker.declare_exchange(exchange_name=exchange_name)
    worker.bind_queue(
        exchange_name=exchange_name, queue_name=queue_name, routing_key="randomkey"
    )
    worker.consume_messages(queue=queue_name, callback=worker.consume)


if __name__ == "__main__":
    create_consumer()
|
requirements.txt
CHANGED
|
@@ -3,3 +3,4 @@ fastapi==0.115.5
|
|
| 3 |
python-multipart==0.0.17
|
| 4 |
websockets==14.1
|
| 5 |
uvicorn==0.32.1
|
|
|
|
|
|
| 3 |
python-multipart==0.0.17
|
| 4 |
websockets==14.1
|
| 5 |
uvicorn==0.32.1
|
| 6 |
+
pika==1.3.2
|