Spaces:
Runtime error
Runtime error
| from urllib.parse import urlparse, urlencode | |
| from wsgiref.handlers import format_date_time | |
| from datetime import datetime | |
| from loguru import logger | |
| import hmac | |
| import hashlib | |
| import base64 | |
| import numpy as np | |
| import json | |
| import websockets.client | |
| from websockets.exceptions import ConnectionClosedError, InvalidStatusCode | |
| import time | |
# Values of the ``status`` field that mark where a frame sits in a stream:
# first frame, intermediate frame, and final frame respectively.
(
    STATUS_FIRST_FRAME,
    STATUS_CONTINUE_FRAME,
    STATUS_LAST_FRAME,
) = range(3)
class TTSClient:
    """Client for the xfyun WebSocket text-to-speech endpoint.

    The client signs the connection URL with HMAC-SHA256, sends the text in a
    single request frame, and collects the base64-encoded audio frames the
    server streams back into one 16-bit PCM array.
    """

    def __init__(
        self,
        app_id: str,
        api_key: str,
        api_secret: str,
        endpoint="wss://ws-api.xfyun.cn/v2/tts",
    ) -> None:
        """Store the credentials and the endpoint used for every request.

        Args:
            app_id: Application id, sent in the ``common`` section of requests.
            api_key: API key embedded in the authorization parameter.
            api_secret: Secret used as the HMAC key when signing the URL.
            endpoint: WebSocket URL of the TTS service.
        """
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.endpoint = endpoint
        self.common_args = {"app_id": self.app_id}

    def prepare_data(self, text: str, sampling_rate=16000, voice="xiaoyan"):
        """Build the JSON-serialisable request payload for ``text``.

        Args:
            text: Text to synthesise; it is UTF-8 encoded and then base64
                encoded before being placed in the payload.
            sampling_rate: Sample rate inserted into the ``auf`` audio format
                string (``audio/L16;rate=<sampling_rate>``).
            voice: Value sent as the ``vcn`` business field (the speaker name
                in the xfyun TTS API). Defaults to the previously hard-coded
                ``"xiaoyan"``.

        Returns:
            A dict with ``common``, ``business`` and ``data`` sections.
        """
        business_args = {
            "aue": "raw",
            "auf": f"audio/L16;rate={sampling_rate}",
            "vcn": voice,
            "tte": "utf8",
        }
        encoded_text = base64.b64encode(text.encode("utf-8")).decode("ascii")
        result = {
            "common": self.common_args,
            "business": business_args,
            "data": {
                # The whole text goes out in one frame, so it is also the last.
                "status": STATUS_LAST_FRAME,
                "text": encoded_text,
            },
        }
        logger.debug(f"Data: {result}")
        return result

    def create_url(self):
        """Return the endpoint URL with signed authentication query parameters.

        The signature is an HMAC-SHA256 (keyed with ``api_secret``) over the
        host, an RFC 1123 date and the request line; it is base64 encoded and
        wrapped in an authorization string that is itself base64 encoded and
        passed as the ``authorization`` query parameter.
        """
        parsed = urlparse(self.endpoint)
        host = parsed.hostname
        path = parsed.path
        # RFC 1123 timestamp in GMT. time.time() is already an epoch value, so
        # there is no local-time round trip that could be ambiguous during a
        # daylight-saving transition.
        date = format_date_time(time.time())

        sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        logger.debug(f"Sign raw string: {sign_raw_str}")
        digest = hmac.new(
            self.api_secret.encode("utf-8"),
            sign_raw_str.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        sign_sha = base64.b64encode(digest).decode("utf-8")

        auth_raw_str = (
            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{sign_sha}"'
        )
        # NOTE(review): this writes the API key and signature to the debug log;
        # make sure debug logs are not shipped anywhere untrusted.
        logger.debug(f"Authorization: {auth_raw_str}")
        auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")

        params = {
            "authorization": auth,
            "date": date,
            "host": host,
        }
        url = f"{self.endpoint}?{urlencode(params)}"
        logger.debug(f"URL: {url}")
        return url

    def parse_result(self, result: bytes) -> np.ndarray:
        """Interpret raw audio bytes as an array of native-endian int16 samples.

        A trailing unpaired byte (possible when the stream was cut short) is
        dropped instead of making ``np.frombuffer`` raise ``ValueError``.
        """
        sample_width = np.dtype(np.int16).itemsize
        usable = len(result) - (len(result) % sample_width)
        # memoryview slicing avoids copying the audio buffer.
        return np.frombuffer(memoryview(result)[:usable], dtype=np.int16)

    async def generate(self, text: str, sampling_rate=16000, voice="xiaoyan"):
        """Synthesise ``text`` and return ``(sampling_rate, samples)``.

        Args:
            text: Text to synthesise.
            sampling_rate: Sample rate requested from the service and returned
                unchanged as the first element of the result.
            voice: Speaker name forwarded to :meth:`prepare_data`.

        Returns:
            A tuple of the sampling rate and an int16 NumPy array holding the
            concatenated audio of all received frames. If the connection is
            closed abnormally mid-stream, the audio received so far is
            returned and a warning is logged.

        Raises:
            RuntimeError: The service answered with a non-zero ``code``.
            InvalidStatusCode: The WebSocket handshake was rejected.
        """
        logger.debug("Generate URL")
        url = self.create_url()
        logger.debug("Preparing Data")
        data = self.prepare_data(text, sampling_rate, voice)
        result = bytearray()
        try:
            async with websockets.client.connect(url) as ws:
                logger.debug("Sending Data")
                await ws.send(json.dumps(data))
                while True:
                    try:
                        raw_message = await ws.recv()
                    except ConnectionClosedError as e:
                        # Best effort: keep whatever audio already arrived,
                        # but do not hide the fact that the stream was cut.
                        logger.warning(
                            f"Connection closed before the last frame: {e}"
                        )
                        break
                    message = json.loads(raw_message)
                    logger.debug(f"Received message: {message}")

                    # Error responses carry a non-zero code and may have no
                    # usable "data" section; fail with the server's message
                    # rather than an opaque KeyError/TypeError.
                    code = message.get("code", 0)
                    if code != 0:
                        raise RuntimeError(
                            f"TTS request failed with code {code}: "
                            f"{message.get('message')} "
                            f"(sid={message.get('sid')})"
                        )

                    frame = message["data"]
                    audio = frame["audio"]
                    logger.debug(f"Received audio length: {len(audio)}")
                    result += base64.b64decode(audio)
                    if frame["status"] == STATUS_LAST_FRAME:
                        break
        except InvalidStatusCode as e:
            logger.error(f"Error: {e}")
            raise
        logger.success("Audio generation finished")
        return sampling_rate, self.parse_result(bytes(result))
# Names exported by ``from <module> import *``.
__all__ = [
    "TTSClient",
]