# Spaces: Runtime error
# File size: 3,951 Bytes
from urllib.parse import urlparse, urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
from loguru import logger
import hmac
import hashlib
import base64
import numpy as np
import json
import websockets.client
from websockets.exceptions import ConnectionClosedError, InvalidStatusCode
import time
# Frame status codes. In this module only STATUS_LAST_FRAME is read: it is
# compared against the "status" field of each received message to detect the
# final audio frame. The other two are defined for completeness.
STATUS_FIRST_FRAME = 0
STATUS_CONTINUE_FRAME = 1
STATUS_LAST_FRAME = 2
class TTSClient:
    """Text-to-speech client that streams PCM audio over a WebSocket.

    The default endpoint is the xfyun v2 TTS service. Each request is
    authenticated with an HMAC-SHA256 signature carried in the URL query
    string (see :meth:`create_url`).
    """

    def __init__(
        self,
        app_id: str,
        api_key: str,
        api_secret: str,
        endpoint="wss://ws-api.xfyun.cn/v2/tts",
    ) -> None:
        """Store the credentials and the service endpoint.

        Args:
            app_id: Application id, sent in the ``common`` section of requests.
            api_key: Key placed in the authorization string.
            api_secret: Secret used as the HMAC key when signing requests.
            endpoint: WebSocket URL of the TTS service.
        """
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.endpoint = endpoint
        self.common_args = {"app_id": self.app_id}

    def prepare_data(self, text: str, sampling_rate=16000, voice="xiaoyan"):
        """Build the JSON-serialisable request payload for one utterance.

        Args:
            text: Text to synthesise; sent UTF-8 encoded and base64 wrapped.
            sampling_rate: Sample rate requested in the ``auf`` field.
            voice: Value of the ``vcn`` field (defaults to the previously
                hard-coded ``"xiaoyan"``).

        Returns:
            A dict with ``common``, ``business`` and ``data`` sections.
        """
        business_args = {
            "aue": "raw",
            "auf": f"audio/L16;rate={sampling_rate}",
            "vcn": voice,
            "tte": "utf8",
        }
        encoded_text = base64.b64encode(text.encode("utf-8")).decode("utf-8")
        result = {
            "common": self.common_args,
            "business": business_args,
            # The whole text is sent in a single frame, hence status 2.
            "data": {"status": 2, "text": encoded_text},
        }
        logger.debug(f"Data: {result}")
        return result

    def create_url(self):
        """Return the endpoint URL with signed authentication parameters.

        The signature is an HMAC-SHA256 (keyed with ``api_secret``) over the
        host, an RFC 1123 date and the request line; it is base64 encoded and
        passed, together with the date and host, as query parameters.
        """
        parsed = urlparse(self.endpoint)
        host = parsed.hostname
        path = parsed.path

        # RFC 1123 timestamp in GMT. time.time() is used directly instead of
        # converting a naive local datetime through mktime(), which is
        # ambiguous during the DST fall-back hour and could yield a date that
        # is off by one hour.
        date = format_date_time(time.time())

        sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        logger.debug(f"Sign raw string: {sign_raw_str}")

        digest = hmac.new(
            self.api_secret.encode("utf-8"),
            sign_raw_str.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode("utf-8")

        auth_raw_str = (
            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")

        params = {
            "authorization": auth,
            "date": date,
            "host": host,
        }
        url = f"{self.endpoint}?{urlencode(params)}"
        # The authorization string and signed URL carry the API key and a
        # valid signature, so they are deliberately kept out of the logs.
        logger.debug(f"URL: {self.endpoint}?<signed query redacted>")
        return url

    def parse_result(self, result: bytes) -> np.ndarray:
        """Interpret raw audio bytes as an array of 16-bit signed samples."""
        return np.frombuffer(result, dtype=np.int16)

    async def generate(self, text: str, sampling_rate=16000, voice="xiaoyan"):
        """Synthesise ``text`` and return the audio.

        Args:
            text: Text to synthesise.
            sampling_rate: Requested sample rate.
            voice: Voice name forwarded to :meth:`prepare_data`.

        Returns:
            A ``(sampling_rate, samples)`` tuple where ``samples`` is an
            ``int16`` NumPy array. If the connection drops mid-stream, the
            audio received so far is returned.

        Raises:
            InvalidStatusCode: If the WebSocket handshake is rejected.
            RuntimeError: If the service answers with a non-zero ``code``.
        """
        logger.debug("Generate URL")
        url = self.create_url()
        logger.debug("Preparing Data")
        data = self.prepare_data(text, sampling_rate, voice)
        result = bytearray()
        try:
            async with websockets.client.connect(url) as ws:
                logger.debug("Sending Data")
                await ws.send(json.dumps(data))
                while True:
                    try:
                        raw_message = await ws.recv()
                    except ConnectionClosedError:
                        # Best effort: keep whatever audio already arrived,
                        # but make the truncation visible in the logs.
                        logger.warning(
                            "Connection closed before the last frame; "
                            "returning partial audio"
                        )
                        break
                    message = json.loads(raw_message)
                    logger.debug(f"Received message: {message}")

                    # Error replies carry a non-zero code and may have no
                    # usable "data" section; report them explicitly instead
                    # of failing on a missing key.
                    code = message.get("code", 0)
                    if code != 0:
                        raise RuntimeError(
                            f"TTS request failed: code={code}, "
                            f"message={message.get('message')!r}, "
                            f"sid={message.get('sid')!r}"
                        )

                    payload = message.get("data") or {}
                    audio = payload.get("audio")
                    if audio:
                        logger.debug(f"Received audio length: {len(audio)}")
                        result += base64.b64decode(audio)
                    if payload.get("status") == STATUS_LAST_FRAME:
                        break
        except InvalidStatusCode as e:
            logger.error(f"Error: {e}")
            raise
        logger.success("Audio generation finished")
        return sampling_rate, self.parse_result(bytes(result))
# Public API of this module.
__all__ = ["TTSClient"]