maolin.liu committed on
Commit
85378a6
·
1 Parent(s): 2278032

[feature]Support RabbitMQ.

Browse files
Files changed (5) hide show
  1. cli.py +10 -0
  2. consumer/__init__.py +0 -0
  3. consumer/asr.py +104 -0
  4. consumer/base.py +235 -0
  5. requirements.txt +1 -0
cli.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from consumer.asr import TranscribeConsumer
2
+
3
+
4
+ def cli():
5
+ transcribe_consumer = TranscribeConsumer()
6
+ transcribe_consumer.consume_messages(transcribe_consumer.queue_name, transcribe_consumer.consume)
7
+
8
+
9
+ if __name__ == '__main__':
10
+ cli()
consumer/__init__.py ADDED
File without changes
consumer/asr.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import logging
4
+ import os
5
+ import uuid
6
+ from pathlib import Path
7
+ from typing import Literal
8
+
9
+ from faster_whisper import WhisperModel
10
+ from pydantic import BaseModel, Field, ValidationError, model_validator
11
+
12
+ from .base import BasicMessageReceiver, BasicMessageSender, Headers, Priority
13
+
14
+
15
+ class TranscribeInputMessage(BaseModel):
16
+ uuid: str = Field(title='Request Unique Id.')
17
+ audio_file: str
18
+ language: Literal['en', 'zh',]
19
+ using_file_content: bool
20
+
21
+ @model_validator(mode='after')
22
+ def check_audio_file(self):
23
+ if self.using_file_content:
24
+ return self
25
+
26
+ if not Path(self.audio_file).exists():
27
+ raise FileNotFoundError(f'Audio file not exists.')
28
+ return self
29
+
30
+
31
+ class TranscribeOutputMessage(BaseModel):
32
+ uuid: str
33
+ if_success: bool
34
+ msg: str
35
+ transcribed_text: str = Field(default='')
36
+
37
+
38
+ class TranscribeConsumer(BasicMessageReceiver):
39
+
40
+ def __init__(self):
41
+ super().__init__()
42
+
43
+ self.exchange_name = 'transcribe'
44
+ self.queue_name = 'transcribe-input'
45
+ self.routing_key = 'transcribe-input'
46
+
47
+ self.setup_consume_parameters()
48
+ self.setup_message_sender()
49
+
50
+ model_size = os.environ.get('WHISPER-MODEL-SIZE', 'large-v3')
51
+ # Run on GPU with FP16
52
+ self.asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
53
+
54
+ def setup_consume_parameters(self):
55
+ self.declare_exchange(self.exchange_name)
56
+ self.declare_queue(self.queue_name, max_priority=-1)
57
+ self.bind_queue(self.exchange_name, self.queue_name, self.routing_key)
58
+
59
+ def setup_message_sender(self):
60
+ self.sender = BasicMessageSender()
61
+
62
+ def send_message(self, message: dict):
63
+ routing_key = 'transcribe-output'
64
+ headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL)
65
+ self.sender.send_message(
66
+ exchange_name=self.exchange_name,
67
+ routing_key=routing_key,
68
+ body=message,
69
+ headers=None
70
+ )
71
+
72
+ def send_success_message(self, uuid: str, transcribed_text):
73
+ message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
74
+ transcribed_text=transcribed_text)
75
+ self.send_message(message.model_dump())
76
+
77
+ def send_fail_message(self, uuid: str, error: str):
78
+ message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
79
+ self.send_message(message.model_dump())
80
+
81
+ def consume(self, channel, method, properties, message):
82
+ body = self.decode_message(message)
83
+
84
+ try:
85
+ validated_message = TranscribeInputMessage.model_validate(body)
86
+
87
+ audio_file = validated_message.audio_file
88
+ if validated_message.using_file_content:
89
+ audio_file = io.BytesIO(base64.b64decode(validated_message.audio_file))
90
+
91
+ segments, _ = self.asr_model.transcribe(audio_file, language=validated_message.language)
92
+
93
+ transcribed_text = ''
94
+ for segment in segments:
95
+ transcribed_text = segment.text
96
+ break
97
+ except ValidationError as exc:
98
+ logging.exception('Consume message failed: \n message: %s\n\n exception info: %s', message, exc)
99
+ self.send_fail_message(body.get('uuid'), f'{exc}')
100
+ except Exception as exc:
101
+ logging.exception('Consume message failed: \n message: %s\n\n exception info: %s', message, exc)
102
+ self.send_fail_message(body.get('uuid'), f'{exc}')
103
+ else:
104
+ self.send_success_message(validated_message.uuid, transcribed_text)
consumer/base.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import functools
3
+ import json
4
+ import logging
5
+ import os
6
+ import ssl
7
+ import time
8
+ from enum import Enum
9
+ from typing import Dict, Optional, Literal
10
+
11
+ import msgpack
12
+ import pika
13
+ from pika.exceptions import AMQPConnectionError
14
+ from pydantic import BaseModel, field_validator
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def sync(f):
20
+ @functools.wraps(f)
21
+ def wrapper(*args, **kwargs):
22
+ return asyncio.get_event_loop().run_until_complete(f(*args, **kwargs))
23
+
24
+ return wrapper
25
+
26
+
27
+ class Priority(Enum):
28
+ LOW = 1
29
+ NORMAL = 5
30
+ HIGH = 10
31
+
32
+
33
+ class Headers(BaseModel):
34
+ job_id: str
35
+ priority: Priority
36
+ task_type: Optional[str] = None
37
+
38
+ @field_validator('priority', mode='before')
39
+ @classmethod
40
+ def _convert_priority(cls, value):
41
+ return Priority[value]
42
+
43
+
44
+ class RabbitMQConfig(BaseModel):
45
+ host: str
46
+ port: int
47
+ username: str
48
+ password: str
49
+ protocol: str
50
+
51
+
52
+ class BasicPikaClient:
53
+ def __init__(self):
54
+ self.username = os.environ.get('RABBITMQ_USER', '')
55
+ self.password = os.environ.get('RABBITMQ_PASSWD', '')
56
+ self.host = os.environ.get('RABBITMQ_HOST', 'localhost')
57
+ self.port = os.environ.get('RABBITMQ_PORT', 5672)
58
+ self.protocol = "amqp"
59
+
60
+ self._init_connection_parameters()
61
+ self._connect()
62
+
63
+ def _connect(self):
64
+ tries = 0
65
+ while True:
66
+ try:
67
+ self.connection = pika.BlockingConnection(self.parameters)
68
+ self.channel = self.connection.channel()
69
+ if self.connection.is_open:
70
+ break
71
+ except (AMQPConnectionError, Exception) as e:
72
+ time.sleep(5)
73
+ tries += 1
74
+ if tries == 20:
75
+ raise AMQPConnectionError(e)
76
+
77
+ def _init_connection_parameters(self):
78
+ if any([self.username, self.password]):
79
+ self.credentials = pika.PlainCredentials(self.username, self.password)
80
+ self.parameters = pika.ConnectionParameters(
81
+ host=self.host,
82
+ port=int(self.port),
83
+ virtual_host="/",
84
+ credentials=self.credentials,
85
+ )
86
+ else:
87
+ self.parameters = pika.ConnectionParameters(
88
+ self.host,
89
+ int(self.port),
90
+ "/",
91
+ )
92
+
93
+ if self.protocol == "amqps":
94
+ # SSL Context for TLS configuration of Amazon MQ for RabbitMQ
95
+ ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
96
+ ssl_context.set_ciphers("ECDHE+AESGCM:!ECDSA")
97
+ self.parameters.ssl_options = pika.SSLOptions(context=ssl_context)
98
+
99
+ def check_connection(self):
100
+ if not self.connection or self.connection.is_closed:
101
+ self._connect()
102
+
103
+ def close(self):
104
+ self.channel.close()
105
+ self.connection.close()
106
+
107
+ def declare_queue(
108
+ self, queue_name, exclusive: bool = False, max_priority: int = 10
109
+ ):
110
+ self.check_connection()
111
+ logger.debug(f"Trying to declare queue({queue_name})...")
112
+
113
+ self.channel.queue_declare(
114
+ queue=queue_name,
115
+ exclusive=exclusive,
116
+ durable=True,
117
+ arguments={"x-max-priority": max_priority} if max_priority > 0 else None
118
+ )
119
+
120
+ def declare_exchange(self, exchange_name: str, exchange_type: str = "direct"):
121
+ self.check_connection()
122
+ self.channel.exchange_declare(
123
+ exchange=exchange_name, exchange_type=exchange_type
124
+ )
125
+
126
+ def bind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
127
+ self.check_connection()
128
+ self.channel.queue_bind(
129
+ exchange=exchange_name, queue=queue_name, routing_key=routing_key
130
+ )
131
+
132
+ def unbind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
133
+ self.channel.queue_unbind(
134
+ queue=queue_name, exchange=exchange_name, routing_key=routing_key
135
+ )
136
+
137
+
138
+ class BasicMessageSender(BasicPikaClient):
139
+ message_encoding_type: Literal['bytes', 'json'] = 'json'
140
+
141
+ def encode_message(self, body: Dict, encoding_type: str = "bytes"):
142
+ if encoding_type == "bytes":
143
+ return msgpack.packb(body)
144
+ elif encoding_type == "json":
145
+ return json.dumps(body)
146
+ else:
147
+ raise NotImplementedError
148
+
149
+ def send_message(
150
+ self,
151
+ exchange_name: str,
152
+ routing_key: str,
153
+ body: Dict,
154
+ headers: Optional[Headers],
155
+ ):
156
+ body = self.encode_message(body=body, encoding_type=self.message_encoding_type)
157
+
158
+ properties = pika.BasicProperties(delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE,
159
+ priority=headers.priority.value,
160
+ headers=headers.model_dump() if headers else None)
161
+ self.channel.basic_publish(
162
+ exchange=exchange_name,
163
+ routing_key=routing_key,
164
+ body=body,
165
+ properties=properties,
166
+ )
167
+ logger.debug(
168
+ f"Sent message. Exchange: {exchange_name}, Routing Key: {routing_key}, Body: {body[:128]}"
169
+ )
170
+
171
+
172
+ class BasicMessageReceiver(BasicPikaClient):
173
+ def __init__(self):
174
+ super().__init__()
175
+ self.channel_tag = None
176
+
177
+ def decode_message(self, body):
178
+ if type(body) == bytes:
179
+ return msgpack.unpackb(body)
180
+ elif type(body) == str:
181
+ return json.loads(body)
182
+ else:
183
+ raise NotImplementedError
184
+
185
+ def get_message(self, queue_name: str, auto_ack: bool = False):
186
+ method_frame, header_frame, body = self.channel.basic_get(
187
+ queue=queue_name, auto_ack=auto_ack
188
+ )
189
+ if method_frame:
190
+ logger.debug(f"{method_frame}, {header_frame}, {body}")
191
+ return method_frame, header_frame, body
192
+ else:
193
+ logger.debug("No message returned")
194
+ return None
195
+
196
+ def consume_messages(self, queue, callback):
197
+ self.check_connection()
198
+ self.channel_tag = self.channel.basic_consume(
199
+ queue=queue, on_message_callback=callback, auto_ack=True
200
+ )
201
+ logger.debug(" [*] Waiting for messages. To exit press CTRL+C")
202
+ self.channel.start_consuming()
203
+
204
+ def cancel_consumer(self):
205
+ if self.channel_tag is not None:
206
+ self.channel.basic_cancel(self.channel_tag)
207
+ self.channel_tag = None
208
+ else:
209
+ logger.error("Do not cancel a non-existing job")
210
+
211
+
212
+ class ExtractFaceFramesConsumer(BasicMessageReceiver):
213
+ @sync
214
+ async def consume(self, channel, method, properties, body):
215
+ body = self.decode_message(body=body)
216
+ file_content = await self._download_image(img_url=body["url"])
217
+ # consume message logic ...
218
+
219
+ async def _download_image(self, img_url):
220
+ # do some async stuff here
221
+ pass
222
+
223
+
224
+ def create_consumer():
225
+ worker = ExtractFaceFramesConsumer()
226
+ worker.declare_queue(queue_name="myqueue")
227
+ worker.declare_exchange(exchange_name="myexchange")
228
+ worker.bind_queue(
229
+ exchange_name="myexchange", queue_name="myqueue", routing_key="randomkey"
230
+ )
231
+ worker.consume_messages(queue="myqueue", callback=worker.consume)
232
+
233
+
234
+ if __name__ == "__main__":
235
+ create_consumer()
requirements.txt CHANGED
@@ -3,3 +3,4 @@ fastapi==0.115.5
3
  python-multipart==0.0.17
4
  websockets==14.1
5
  uvicorn==0.32.1
 
 
3
  python-multipart==0.0.17
4
  websockets==14.1
5
  uvicorn==0.32.1
6
+ pika==1.3.2