File size: 3,951 Bytes
a7c4935
 
 
177a062
a7c4935
 
 
 
 
 
177a062
a7c4935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177a062
 
 
a7c4935
177a062
a7c4935
 
 
177a062
a7c4935
177a062
a7c4935
 
 
 
 
177a062
 
a7c4935
 
 
 
 
 
 
 
 
177a062
a7c4935
 
 
 
 
177a062
 
a7c4935
 
 
 
 
177a062
 
a7c4935
 
 
 
 
 
177a062
a7c4935
 
 
 
 
177a062
 
a7c4935
177a062
 
a7c4935
177a062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c4935
177a062
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from urllib.parse import urlparse, urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
from loguru import logger
import hmac
import hashlib
import base64
import numpy as np
import json
import websockets.client
from websockets.exceptions import ConnectionClosedError, InvalidStatusCode
import time

# Frame status codes carried in the request/response "status" field:
# 0 = first frame, 1 = continuation frame, 2 = last frame.
STATUS_FIRST_FRAME, STATUS_CONTINUE_FRAME, STATUS_LAST_FRAME = range(3)


class TTSClient:
    """Async client for the iFlytek (xfyun) WebSocket text-to-speech API.

    Each :meth:`generate` call opens a freshly signed WebSocket connection,
    sends the whole text in a single request frame and concatenates the raw
    16-bit PCM audio chunks the service streams back.
    """

    def __init__(
        self,
        app_id: str,
        api_key: str,
        api_secret: str,
        endpoint="wss://ws-api.xfyun.cn/v2/tts",
    ) -> None:
        """Store credentials and the WebSocket endpoint.

        Args:
            app_id: Application id, sent in the ``common`` section of requests.
            api_key: API key embedded in the authorization parameter.
            api_secret: Secret used as the HMAC-SHA256 key when signing URLs.
            endpoint: WebSocket URL of the TTS service.
        """
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.endpoint = endpoint
        self.common_args = {"app_id": self.app_id}

    def prepare_data(self, text: str, sampling_rate=16000, voice: str = "xiaoyan"):
        """Build the JSON-serialisable request payload for ``text``.

        Args:
            text: Text to synthesise; sent base64-encoded as UTF-8.
            sampling_rate: Sample rate requested in the ``auf`` field.
            voice: Speaker name sent as ``vcn`` (default ``"xiaoyan"``).

        Returns:
            dict with ``common``, ``business`` and ``data`` sections.
        """
        business_args = {
            "aue": "raw",
            "auf": f"audio/L16;rate={sampling_rate}",
            "vcn": voice,
            "tte": "utf8",
        }
        result = {
            "common": self.common_args,
            "business": business_args,
            "data": {
                # The whole text is sent in one frame, so it is also the last.
                "status": STATUS_LAST_FRAME,
                "text": base64.b64encode(text.encode("utf-8")).decode("utf-8"),
            },
        }
        logger.debug(f"Data: {result}")
        return result

    def create_url(self):
        """Return ``self.endpoint`` with HMAC-SHA256 authorization query params.

        The signature covers the ``host``, ``date`` and request line; the
        base64-encoded authorization string, date and host are appended as
        ``authorization``, ``date`` and ``host`` query parameters.
        """
        parsed = urlparse(self.endpoint)
        host = parsed.hostname
        path = parsed.path
        # RFC 1123 timestamp (GMT). time.time() is already an epoch value;
        # round-tripping datetime.now() through time.mktime() makes mktime
        # guess the DST flag and can skew the signed date by an hour.
        date = format_date_time(time.time())

        sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        logger.debug(f"Sign raw string: {sign_raw_str}")
        sign_sha = hmac.new(
            self.api_secret.encode("utf-8"),
            sign_raw_str.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(sign_sha).decode("utf-8")
        auth_raw_str = (
            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        logger.debug(f"Authorization: {auth_raw_str}")
        auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")
        params = {
            "authorization": auth,
            "date": date,
            "host": host,
        }
        url = f"{self.endpoint}?{urlencode(params)}"
        logger.debug(f"URL: {url}")
        return url

    def parse_result(self, result: bytes) -> np.ndarray:
        """Interpret raw PCM bytes as 16-bit signed samples (native byte order).

        The array shares memory with ``result`` (read-only when ``result`` is
        ``bytes``); ``len(result)`` must be a multiple of 2.
        """
        return np.frombuffer(result, dtype=np.int16)

    @staticmethod
    def _decode_frame(message: dict):
        """Return ``(pcm_bytes, status)`` from one JSON-decoded server frame.

        The service reports failures through a non-zero ``code`` field (see
        the iFlytek TTS WebSocket API docs); such frames are turned into a
        descriptive RuntimeError instead of failing on a missing ``data``.
        """
        code = message.get("code", 0)
        if code != 0:
            raise RuntimeError(
                f"TTS request failed with code {code} "
                f"(sid={message.get('sid')}): {message.get('message')}"
            )
        payload = message["data"]
        return base64.b64decode(payload["audio"]), payload["status"]

    async def generate(self, text: str, sampling_rate=16000, voice: str = "xiaoyan"):
        """Synthesise ``text`` and return ``(sampling_rate, samples)``.

        Args:
            text: Text to synthesise.
            sampling_rate: Requested sample rate; also returned unchanged.
            voice: Speaker name forwarded to :meth:`prepare_data`.

        Returns:
            Tuple of ``sampling_rate`` and the int16 sample array.

        Raises:
            InvalidStatusCode: if the WebSocket handshake is rejected.
            RuntimeError: if the service answers with a non-zero ``code``.
        """
        logger.debug("Generate URL")
        url = self.create_url()
        logger.debug("Preparing Data")
        data = self.prepare_data(text, sampling_rate, voice)
        result = bytearray()
        try:
            async with websockets.client.connect(url) as ws:
                logger.debug("Sending Data")
                await ws.send(json.dumps(data))
                while True:
                    try:
                        raw = await ws.recv()
                    except ConnectionClosedError as e:
                        # Best-effort: keep what was received, but make the
                        # possible truncation visible instead of silent.
                        logger.warning(
                            f"Connection closed before the last frame; "
                            f"audio may be truncated: {e}"
                        )
                        break
                    message = json.loads(raw)
                    logger.debug(f"Received message: {message}")
                    audio, status = self._decode_frame(message)
                    logger.debug(f"Received audio length: {len(audio)}")
                    result += audio
                    if status == STATUS_LAST_FRAME:
                        break
        except InvalidStatusCode as e:
            logger.error(f"Error: {e}")
            raise
        logger.success("Audio generation finished")
        return sampling_rate, self.parse_result(bytes(result))


# Only the client class is part of this module's public interface.
__all__ = [
    "TTSClient",
]