Spaces:
Runtime error
Runtime error
| from urllib.parse import urlparse, urlencode | |
| from wsgiref.handlers import format_date_time | |
| from datetime import datetime | |
| import hmac | |
| import hashlib | |
| import base64 | |
| import numpy as np | |
| import json | |
| import websockets.client | |
| from loguru import logger | |
| from websockets.exceptions import ConnectionClosedError | |
| import time | |
# Values for the "status" field of every audio frame sent to the service.
# Responses are compared against STATUS_LAST_FRAME to detect the final result.
STATUS_FIRST_FRAME = 0
STATUS_CONTINUE_FRAME = 1
STATUS_LAST_FRAME = 2


class IATClient:
    """XFYun's IAT Client, which is used to convert speech to text.

    The client signs a WebSocket URL with the account credentials, streams
    PCM audio to the endpoint in small base64-encoded JSON frames and yields
    the text carried by each response message.
    """

    def __init__(
        self,
        app_id: str,
        api_key: str,
        api_secret: str,
        endpoint="wss://ws-api.xfyun.cn/v2/iat",
    ) -> None:
        """Store the credentials and the per-session request arguments.

        Args:
            app_id: Application id, sent in the ``common`` section of the
                first frame.
            api_key: Key embedded in the authorization string.
            api_secret: Secret used as the HMAC key when signing the URL.
            endpoint: WebSocket URL of the service.
        """
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.endpoint = endpoint
        # "common" and "business" are attached to the first frame only.
        self.common_args = {"app_id": self.app_id}
        # NOTE(review): the meaning of these options is defined by the
        # service, not by this file -- confirm against the IAT API docs
        # before changing them.
        self.business_args = {
            "domain": "iat",
            "language": "zh_cn",
            "accent": "mandarin",
            "vinfo": 1,
            "vad_eos": 10000,
        }

    def encode_pcm(self, source: np.ndarray):
        """Convert ndarray audio data to PCM-style bytes.

        Gradio's Audio module returns a tuple of (sampling_rate, np.ndarray);
        the ndarray is expected to already be in the 16-bit PCM range
        (-32768 to 32767). Values are cast to ``int16``, not rescaled.

        Args:
            source: Audio samples.

        Returns:
            The samples as raw 16-bit bytes in native byte order.
        """
        return source.astype(np.int16).tobytes()

    def create_url(self):
        """Build the signed WebSocket URL used to open a session.

        The string ``host: <host>\\ndate: <date>\\nGET <path> HTTP/1.1`` is
        signed with HMAC-SHA256 keyed by ``api_secret``. The base64 signature
        is embedded in an authorization string, which is base64-encoded in
        turn and appended to the endpoint together with the date and host.

        Returns:
            The endpoint URL with ``authorization``, ``date`` and ``host``
            query parameters.
        """
        parse_result = urlparse(self.endpoint)
        host = parse_result.hostname
        path = parse_result.path
        # RFC 1123 timestamp. time.time() is already a UTC-based epoch value;
        # going through naive local time and time.mktime() can be off by an
        # hour around DST transitions, which would skew the signed date.
        date = format_date_time(time.time())

        sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        logger.debug(f"Sign raw string: {sign_raw_str}")
        digest = hmac.new(
            self.api_secret.encode("utf-8"),
            sign_raw_str.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode("utf-8")

        auth_raw_str = (
            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        logger.debug(f"Authorization: {auth_raw_str}")
        auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")

        params = {
            "authorization": auth,
            "date": date,
            "host": host,
        }
        url = f"{self.endpoint}?{urlencode(params)}"
        logger.debug(f"URL: {url}")
        return url

    @staticmethod
    def _build_frame(status: int, chunk: bytes, sampling_rate: int) -> dict:
        """Return the payload skeleton for one audio frame."""
        return {
            "data": {
                "status": status,
                "format": f"audio/L16;rate={sampling_rate}",
                "audio": base64.b64encode(chunk).decode("utf-8"),
                "encoding": "raw",
            }
        }

    def prepare_data(self, audio: bytes, chunk_size=1280, sampling_rate=16000):
        """Split PCM bytes into the JSON-serialisable frames of one session.

        The first frame always has ``STATUS_FIRST_FRAME`` and carries the
        ``common`` and ``business`` arguments; the final frame has
        ``STATUS_LAST_FRAME``. When the whole audio fits into a single chunk
        the first frame cannot also be the last one (the session arguments
        would be lost), so an extra empty last frame closes the stream.

        Args:
            audio: Raw 16-bit PCM bytes.
            chunk_size: Number of bytes per frame.
            sampling_rate: Sampling rate reported in the frame format.

        Yields:
            One payload dict per frame. Nothing is yielded for empty audio.
        """
        total = len(audio)
        logger.debug(f"Total audio length: {total}")
        for start in range(0, total, chunk_size):
            end = start + chunk_size
            logger.debug(f"Processing chunk {start} to {end}")
            chunk = audio[start:end]
            is_last = end >= total

            if start == 0:
                payload = self._build_frame(
                    STATUS_FIRST_FRAME, chunk, sampling_rate
                )
                payload["common"] = self.common_args
                payload["business"] = self.business_args
                yield payload
                if is_last:
                    # Single-chunk audio: terminate with an empty last frame.
                    yield self._build_frame(
                        STATUS_LAST_FRAME, b"", sampling_rate
                    )
            elif is_last:
                yield self._build_frame(STATUS_LAST_FRAME, chunk, sampling_rate)
            else:
                yield self._build_frame(
                    STATUS_CONTINUE_FRAME, chunk, sampling_rate
                )

    async def dictate(self, audio: tuple[int, np.ndarray], interval=0.04):
        """Stream audio to the service and yield the recognised text.

        Args:
            audio: ``(sampling_rate, samples)`` tuple.
            interval: Seconds to wait after sending each frame. With the
                default 1280-byte chunks of 16-bit audio at 16 kHz, one frame
                holds 0.04 s of sound.

        Yields:
            The text of each response message. An empty string is yielded
            once, and iteration stops, when a response carries no ``data``
            section. Nothing is yielded for empty audio.
        """
        # Local import keeps this block self-contained; the file's import
        # header does not provide asyncio.
        import asyncio

        logger.debug("Encoding audio to PCM")
        sampling_rate, source = audio
        pcm = self.encode_pcm(source)
        if not pcm:
            # Nothing to send: do not open a connection only to wait for the
            # server to time out.
            logger.warning("Audio is empty, nothing to dictate")
            return

        logger.debug("Generate URL")
        url = self.create_url()
        async with websockets.client.connect(url) as ws:
            for payload in self.prepare_data(pcm, sampling_rate=sampling_rate):
                logger.debug('Sending payload')
                await ws.send(json.dumps(payload))
                # time.sleep() here would block the whole event loop.
                await asyncio.sleep(interval)
            try:
                async for message in ws:
                    data: dict = json.loads(message)
                    logger.debug(f"Received data: {data}")
                    body = data.get("data")
                    if not body:
                        # No data section (missing or null): report an empty
                        # result and stop listening.
                        yield ''
                        break
                    is_end = body["status"] == STATUS_LAST_FRAME
                    segments = body["result"]["ws"]
                    text = ''.join(
                        word["w"]
                        for segment in segments
                        for word in segment["cw"]
                    )
                    yield text
                    if is_end:
                        break
            except ConnectionClosedError as e:
                logger.warning(f"Connection closed: {e.code} {e.reason}")