Spaces:
Runtime error
Runtime error
Yi Jin committed on
Uploaded beta code
Browse files- .gitignore +2 -0
- .vscode/settings.json +0 -4
- app.py +124 -0
- iat.py +24 -4
- requirements.txt +2 -1
- tts.py +45 -26
.gitignore
CHANGED
|
@@ -177,3 +177,5 @@ pyrightconfig.json
|
|
| 177 |
|
| 178 |
config.yaml
|
| 179 |
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
config.yaml
|
| 179 |
|
| 180 |
+
|
| 181 |
+
.vscode/
|
.vscode/settings.json
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"python.analysis.autoImportCompletions": true,
|
| 3 |
-
"python.analysis.typeCheckingMode": "basic"
|
| 4 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from zhipuai import ZhipuAI
|
| 5 |
+
from zhipuai.api_resource.chat.chat import Chat
|
| 6 |
+
import yaml
|
| 7 |
+
import json
|
| 8 |
+
from iat import IATClient
|
| 9 |
+
from tts import TTSClient
|
| 10 |
+
import numpy as np
|
| 11 |
+
from scipy.signal import resample
|
| 12 |
+
|
| 13 |
+
logger.debug("Loading config")
# A non-empty CONFIG environment variable (a JSON document) takes precedence
# over the on-disk config.yaml. Note the asymmetry: the JSON document is used
# as-is, while the YAML file is expected to nest everything under "config".
config_env = os.environ.get("CONFIG", "")
if config_env:
    logger.debug("Using environment variable for config")
    config = json.loads(config_env)
else:
    logger.debug("Reading config from file")
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding.
    with open("config.yaml", "r", encoding="utf-8") as f:
        try:
            config = yaml.safe_load(f)["config"]
        except yaml.YAMLError as e:
            logger.error(e)
            # Bare raise re-raises the active exception with its traceback.
            raise

zhipuai_config = config["zhipuai"]
xfyun_config = config["xfyun"]

# Service clients shared by every request handler below.
zhipuai = ZhipuAI(api_key=zhipuai_config["apikey"])
iat = IATClient(
    xfyun_config["iat"]["appid"],
    xfyun_config["iat"]["apikey"],
    xfyun_config["iat"]["apisecret"],
)
tts = TTSClient(
    xfyun_config["tts"]["appid"],
    xfyun_config["tts"]["apikey"],
    xfyun_config["tts"]["apisecret"],
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def build_zhipuai_history(history: list[list[str]]):
    """Convert the chatbot history into a chat-completion ``messages`` list.

    Args:
        history: Chat turns as ``[user_message, assistant_message]`` pairs.
            Either entry may be ``None`` and is then left out.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system prompt
        from ``config["zhipuai"]["prompt"]`` first, then the user and
        assistant messages in conversation order.
    """
    messages = [{"role": "system", "content": config["zhipuai"]["prompt"]}]
    for user_message, assistant_message in history:
        # Identity check: an empty string is still a message, only None is not.
        if user_message is not None:
            messages.append({"role": "user", "content": user_message})
        if assistant_message is not None:
            messages.append({"role": "assistant", "content": assistant_message})
    return messages
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def add_text(history, text):
    """Append the user's message as a new, unanswered chat turn.

    Returns:
        A pair of the extended history (the input list is not mutated) and a
        textbox update that empties the input and makes it non-interactive.
    """
    locked_textbox = gr.Textbox(value="", interactive=False)
    pending_turn = (text, None)
    return history + [pending_turn], locked_textbox
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def bot(history):
    """Stream the model's reply into the last chat turn.

    Args:
        history: Chat turns as ``[user_message, assistant_message]`` pairs;
            the last turn's assistant slot is overwritten in place.

    Yields:
        ``history`` after each streamed chunk, with the text received so far
        accumulated in ``history[-1][1]``.
    """
    # Build the request before touching the last turn so its still-unset
    # assistant slot is not sent to the model.
    zhipuai_history = build_zhipuai_history(history)
    stream = zhipuai.chat.completions.create(
        model="glm-4", messages=zhipuai_history, stream=True
    )
    history[-1][1] = ""
    for chunk in stream:
        # NOTE(review): a streamed delta may carry no text (content is None),
        # e.g. on a final chunk; treat that as an empty piece rather than
        # failing on str + None.
        history[-1][1] += chunk.choices[0].delta.content or ""
        yield history
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
async def generate_text(audio: tuple[int, np.ndarray]):
    """Transcribe a recorded clip with the IAT client.

    Args:
        audio: ``(sampling_rate, samples)`` pair. Presumably ``None`` when
            nothing was recorded; that case yields an empty transcript.

    Returns:
        The concatenation of all text fragments produced by ``iat.dictate``,
        or ``""`` when there is no audio to transcribe.
    """
    logger.debug("Generating text from audio")
    if audio is None:
        logger.warning("No audio provided, nothing to transcribe")
        return ""

    source_rate, samples = audio
    if len(samples) == 0:
        logger.warning("Empty audio provided, nothing to transcribe")
        return ""

    logger.debug(f"Sampling rate: {source_rate}, resampling to 16000")
    # scipy.signal.resample takes the number of OUTPUT SAMPLES, not a target
    # rate, so scale the clip length by the ratio of the two rates. Passing
    # 16000 directly would squeeze every clip into exactly one second.
    target_length = max(1, round(len(samples) * 16000 / source_rate))
    audio = (16000, resample(samples, target_length))

    result_list = []
    async for result in iat.dictate(audio):
        logger.debug(f"Result: {result}")
        result_list.append(result)
    return "".join(result_list)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
async def generate_audio(history: list[list[str]]):
    """Synthesize speech for the most recent assistant reply.

    Args:
        history: Chat turns as ``[user_message, assistant_message]`` pairs.

    Returns:
        Whatever ``tts.generate`` returns for the last turn's reply, or
        ``None`` when there is no reply text to synthesize.
    """
    logger.debug("Generating audio from text")
    # Guard against an empty chat or a turn the bot has not answered yet,
    # which would otherwise raise IndexError or pass None to the TTS client.
    if not history or not history[-1][-1]:
        logger.warning("No bot reply available to synthesize")
        return None
    text = history[-1][-1]
    return await tts.generate(text)
|
| 86 |
+
|
| 87 |
+
# Gradio UI: chat window, text input row, and a speech-to-text /
# text-to-speech row. NOTE(review): the nesting below is reconstructed from
# how the components are used — confirm against the deployed layout.
with gr.Blocks() as demo:
    title = gr.Markdown("# 老王元宇宙受害者")

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        submit_button = gr.Button(value="提交", variant="primary")

    with gr.Row():
        with gr.Column():
            user_title = gr.Markdown("## 用户语音识别")
            user_audio = gr.Audio(type="numpy", sources=['microphone'])
            user_audio_submit = gr.Button(value="上传用户语音并转换", variant="primary")
        with gr.Column():
            # Distinct name so the first column's heading is not rebound.
            bot_title = gr.Markdown("## 机器人语音合成")
            bot_audio = gr.Audio()
            bot_audio_submit = gr.Button(value="将机器人最后一个回复转换为语音", variant="primary")

    # Recognized speech is written into the textbox; it is not auto-submitted.
    user_audio_submit.click(generate_text, [user_audio], outputs=txt)
    bot_audio_submit.click(generate_audio, [chatbot], outputs=bot_audio)

    # Submitting locks the textbox, streams the bot reply, then unlocks it.
    txt_msg = submit_button.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # The placeholder tells users to press Enter, so wire the textbox's own
    # submit event to the same chain. No api_name here, to avoid clashing
    # with the "bot_response" endpoint registered above.
    txt_enter = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    )
    txt_enter.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

demo.launch()
|
iat.py
CHANGED
|
@@ -7,6 +7,7 @@ import base64
|
|
| 7 |
import numpy as np
|
| 8 |
import json
|
| 9 |
import websockets.client
|
|
|
|
| 10 |
from websockets.exceptions import ConnectionClosedError
|
| 11 |
import time
|
| 12 |
|
|
@@ -52,29 +53,35 @@ class IATClient:
|
|
| 52 |
path = parse_result.path
|
| 53 |
|
| 54 |
sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
|
|
|
|
| 55 |
sign_sha = hmac.new(
|
| 56 |
self.api_secret.encode("utf-8"),
|
| 57 |
sign_raw_str.encode("utf-8"),
|
| 58 |
digestmod=hashlib.sha256,
|
| 59 |
).digest()
|
| 60 |
-
|
|
|
|
| 61 |
self.api_key,
|
| 62 |
"hmac-sha256",
|
| 63 |
"host date request-line",
|
| 64 |
sign_sha,
|
| 65 |
)
|
| 66 |
-
|
|
|
|
| 67 |
params = {
|
| 68 |
"authorization": auth,
|
| 69 |
"date": date,
|
| 70 |
"host": host,
|
| 71 |
}
|
| 72 |
url = f"{self.endpoint}?{urlencode(params)}"
|
|
|
|
| 73 |
return url
|
| 74 |
|
| 75 |
def prepare_data(self, audio: bytes, chunk_size=1280, sampling_rate=16000):
|
| 76 |
status = STATUS_FIRST_FRAME
|
|
|
|
| 77 |
for i in range(0, len(audio), chunk_size):
|
|
|
|
| 78 |
chunk = audio[i : i + chunk_size]
|
| 79 |
if i + chunk_size >= len(audio):
|
| 80 |
status = STATUS_LAST_FRAME
|
|
@@ -91,16 +98,29 @@ class IATClient:
|
|
| 91 |
yield payload
|
| 92 |
status = STATUS_CONTINUE_FRAME
|
| 93 |
|
| 94 |
-
async def dictate(self, audio: tuple[int, np.ndarray], interval=0.
|
|
|
|
| 95 |
url = self.create_url()
|
|
|
|
| 96 |
sampling_rate, source = audio
|
| 97 |
pcm = self.encode_pcm(source)
|
| 98 |
async with websockets.client.connect(url) as ws:
|
| 99 |
for payload in self.prepare_data(pcm, sampling_rate=sampling_rate):
|
|
|
|
| 100 |
await ws.send(json.dumps(payload))
|
| 101 |
time.sleep(interval)
|
| 102 |
try:
|
| 103 |
async for message in ws:
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
except ConnectionClosedError as e:
|
| 106 |
print(f"Connection closed: {e.code} {e.reason}")
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import json
|
| 9 |
import websockets.client
|
| 10 |
+
from loguru import logger
|
| 11 |
from websockets.exceptions import ConnectionClosedError
|
| 12 |
import time
|
| 13 |
|
|
|
|
| 53 |
path = parse_result.path
|
| 54 |
|
| 55 |
sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
|
| 56 |
+
logger.debug(f"Sign raw string: {sign_raw_str}")
|
| 57 |
sign_sha = hmac.new(
|
| 58 |
self.api_secret.encode("utf-8"),
|
| 59 |
sign_raw_str.encode("utf-8"),
|
| 60 |
digestmod=hashlib.sha256,
|
| 61 |
).digest()
|
| 62 |
+
sign_sha = base64.b64encode(sign_sha).decode("utf-8")
|
| 63 |
+
auth_raw_str = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
|
| 64 |
self.api_key,
|
| 65 |
"hmac-sha256",
|
| 66 |
"host date request-line",
|
| 67 |
sign_sha,
|
| 68 |
)
|
| 69 |
+
logger.debug(f"Authorization: {auth_raw_str}")
|
| 70 |
+
auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")
|
| 71 |
params = {
|
| 72 |
"authorization": auth,
|
| 73 |
"date": date,
|
| 74 |
"host": host,
|
| 75 |
}
|
| 76 |
url = f"{self.endpoint}?{urlencode(params)}"
|
| 77 |
+
logger.debug(f"URL: {url}")
|
| 78 |
return url
|
| 79 |
|
| 80 |
def prepare_data(self, audio: bytes, chunk_size=1280, sampling_rate=16000):
|
| 81 |
status = STATUS_FIRST_FRAME
|
| 82 |
+
logger.debug(f"Total audio length: {len(audio)}")
|
| 83 |
for i in range(0, len(audio), chunk_size):
|
| 84 |
+
logger.debug(f"Processing chunk {i} to {i + chunk_size}")
|
| 85 |
chunk = audio[i : i + chunk_size]
|
| 86 |
if i + chunk_size >= len(audio):
|
| 87 |
status = STATUS_LAST_FRAME
|
|
|
|
| 98 |
yield payload
|
| 99 |
status = STATUS_CONTINUE_FRAME
|
| 100 |
|
| 101 |
+
async def dictate(self, audio: tuple[int, np.ndarray], interval=0.04):
    """Stream audio to the IAT websocket endpoint and yield recognized text.

    Args:
        audio: ``(sampling_rate, samples)`` pair; the samples are converted
            with ``self.encode_pcm`` before being sent.
        interval: Seconds to pause after sending each frame.

    Yields:
        One text fragment per server message. If a message has no ``"data"``
        field, a single empty string is yielded and iteration stops.
    """
    # Local import: the module's import block does not provide asyncio.
    import asyncio

    logger.debug("Generate URL")
    url = self.create_url()
    logger.debug("Encoding audio to PCM")
    sampling_rate, source = audio
    pcm = self.encode_pcm(source)
    async with websockets.client.connect(url) as ws:
        for payload in self.prepare_data(pcm, sampling_rate=sampling_rate):
            logger.debug('Sending payload')
            await ws.send(json.dumps(payload))
            # Pace the upload without blocking the event loop; time.sleep
            # here would freeze every other coroutine for the whole upload.
            await asyncio.sleep(interval)
        try:
            async for message in ws:
                data: dict = json.loads(message)
                logger.debug(f"Received data: {data}")
                if "data" not in data:
                    # Surface the reply at error level instead of dropping
                    # it with only a debug trace.
                    logger.error(f"Response without data: {data}")
                    yield ''
                    break
                is_end = data["data"]["status"] == STATUS_LAST_FRAME
                ws_list = data["data"]["result"]["ws"]
                # Flatten every candidate word of every segment in order.
                text = ''.join(
                    cw["w"] for segment in ws_list for cw in segment["cw"]
                )
                yield text
                if is_end:
                    break
        except ConnectionClosedError as e:
            logger.error(f"Connection closed: {e.code} {e.reason}")
|
requirements.txt
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
gradio
|
| 2 |
jupyter
|
| 3 |
requests
|
| 4 |
-
|
| 5 |
zhipuai
|
| 6 |
loguru
|
| 7 |
numpy
|
|
|
|
|
|
| 1 |
gradio
|
| 2 |
jupyter
|
| 3 |
requests
|
| 4 |
+
websockets
|
| 5 |
zhipuai
|
| 6 |
loguru
|
| 7 |
numpy
|
| 8 |
+
scipy
|
tts.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
from urllib.parse import urlparse, urlencode
|
| 2 |
from wsgiref.handlers import format_date_time
|
| 3 |
from datetime import datetime
|
|
|
|
| 4 |
import hmac
|
| 5 |
import hashlib
|
| 6 |
import base64
|
| 7 |
import numpy as np
|
| 8 |
import json
|
| 9 |
import websockets.client
|
| 10 |
-
from websockets.exceptions import ConnectionClosedError
|
| 11 |
import time
|
| 12 |
|
| 13 |
STATUS_FIRST_FRAME = 0
|
|
@@ -16,7 +17,6 @@ STATUS_LAST_FRAME = 2
|
|
| 16 |
|
| 17 |
|
| 18 |
class TTSClient:
|
| 19 |
-
|
| 20 |
def __init__(
|
| 21 |
self,
|
| 22 |
app_id: str,
|
|
@@ -29,22 +29,24 @@ class TTSClient:
|
|
| 29 |
self.api_secret = api_secret
|
| 30 |
self.endpoint = endpoint
|
| 31 |
self.common_args = {"app_id": self.app_id}
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
"aue": "raw",
|
| 34 |
-
"auf": "audio/L16;rate=
|
| 35 |
"vcn": "xiaoyan",
|
| 36 |
"tte": "utf8",
|
| 37 |
}
|
| 38 |
-
|
| 39 |
-
def prepare_data(self, text: str):
|
| 40 |
-
return {
|
| 41 |
"common": self.common_args,
|
| 42 |
-
"business":
|
| 43 |
"data": {
|
| 44 |
"status": 2,
|
| 45 |
"text": str(base64.b64encode(text.encode("utf-8")), "UTF8"),
|
| 46 |
},
|
| 47 |
}
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def create_url(self):
|
| 50 |
parse_result = urlparse(self.endpoint)
|
|
@@ -54,45 +56,62 @@ class TTSClient:
|
|
| 54 |
path = parse_result.path
|
| 55 |
|
| 56 |
sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
|
|
|
|
| 57 |
sign_sha = hmac.new(
|
| 58 |
self.api_secret.encode("utf-8"),
|
| 59 |
sign_raw_str.encode("utf-8"),
|
| 60 |
digestmod=hashlib.sha256,
|
| 61 |
).digest()
|
| 62 |
-
|
|
|
|
| 63 |
self.api_key,
|
| 64 |
"hmac-sha256",
|
| 65 |
"host date request-line",
|
| 66 |
sign_sha,
|
| 67 |
)
|
| 68 |
-
|
|
|
|
| 69 |
params = {
|
| 70 |
"authorization": auth,
|
| 71 |
"date": date,
|
| 72 |
"host": host,
|
| 73 |
}
|
| 74 |
url = f"{self.endpoint}?{urlencode(params)}"
|
|
|
|
| 75 |
return url
|
| 76 |
|
| 77 |
def parse_result(self, result: bytes) -> np.ndarray:
|
| 78 |
return np.frombuffer(result, dtype=np.int16)
|
| 79 |
|
| 80 |
-
async def generate(self, text: str):
|
|
|
|
| 81 |
url = self.create_url()
|
| 82 |
-
|
|
|
|
| 83 |
result = bytearray()
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
break
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from urllib.parse import urlparse, urlencode
|
| 2 |
from wsgiref.handlers import format_date_time
|
| 3 |
from datetime import datetime
|
| 4 |
+
from loguru import logger
|
| 5 |
import hmac
|
| 6 |
import hashlib
|
| 7 |
import base64
|
| 8 |
import numpy as np
|
| 9 |
import json
|
| 10 |
import websockets.client
|
| 11 |
+
from websockets.exceptions import ConnectionClosedError, InvalidStatusCode
|
| 12 |
import time
|
| 13 |
|
| 14 |
STATUS_FIRST_FRAME = 0
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class TTSClient:
|
|
|
|
| 20 |
def __init__(
|
| 21 |
self,
|
| 22 |
app_id: str,
|
|
|
|
| 29 |
self.api_secret = api_secret
|
| 30 |
self.endpoint = endpoint
|
| 31 |
self.common_args = {"app_id": self.app_id}
|
| 32 |
+
|
| 33 |
+
def prepare_data(self, text: str, sampling_rate=16000):
    """Assemble the request payload for one synthesis call.

    Args:
        text: Text to synthesize; sent base64-encoded from its UTF-8 bytes.
        sampling_rate: Rate written into the ``auf`` audio-format field.

    Returns:
        A dict with the ``common``, ``business`` and ``data`` sections.
    """
    encoded_text = base64.b64encode(text.encode("utf-8")).decode("UTF8")
    payload = {
        "common": self.common_args,
        "business": {
            "aue": "raw",
            "auf": f"audio/L16;rate={sampling_rate}",
            "vcn": "xiaoyan",
            "tte": "utf8",
        },
        "data": {
            "status": 2,
            "text": encoded_text,
        },
    }
    logger.debug(f"Data: {payload}")
    return payload
|
| 50 |
|
| 51 |
def create_url(self):
|
| 52 |
parse_result = urlparse(self.endpoint)
|
|
|
|
| 56 |
path = parse_result.path
|
| 57 |
|
| 58 |
sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
|
| 59 |
+
logger.debug(f"Sign raw string: {sign_raw_str}")
|
| 60 |
sign_sha = hmac.new(
|
| 61 |
self.api_secret.encode("utf-8"),
|
| 62 |
sign_raw_str.encode("utf-8"),
|
| 63 |
digestmod=hashlib.sha256,
|
| 64 |
).digest()
|
| 65 |
+
sign_sha = base64.b64encode(sign_sha).decode("utf-8")
|
| 66 |
+
auth_raw_str = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
|
| 67 |
self.api_key,
|
| 68 |
"hmac-sha256",
|
| 69 |
"host date request-line",
|
| 70 |
sign_sha,
|
| 71 |
)
|
| 72 |
+
logger.debug(f"Authorization: {auth_raw_str}")
|
| 73 |
+
auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")
|
| 74 |
params = {
|
| 75 |
"authorization": auth,
|
| 76 |
"date": date,
|
| 77 |
"host": host,
|
| 78 |
}
|
| 79 |
url = f"{self.endpoint}?{urlencode(params)}"
|
| 80 |
+
logger.debug(f"URL: {url}")
|
| 81 |
return url
|
| 82 |
|
| 83 |
def parse_result(self, result: bytes) -> np.ndarray:
    """Reinterpret a raw byte buffer as a 1-D array of 16-bit signed integers."""
    samples = np.frombuffer(result, dtype=np.int16)
    return samples
|
| 85 |
|
| 86 |
+
async def generate(self, text: str, sampling_rate=16000):
    """Synthesize ``text`` through the TTS websocket endpoint.

    Args:
        text: Text to synthesize.
        sampling_rate: Rate requested from the service and echoed back in the
            return value.

    Returns:
        ``(sampling_rate, samples)`` where ``samples`` is the received audio
        decoded by ``self.parse_result``.

    Raises:
        InvalidStatusCode: If the websocket handshake is rejected.
        RuntimeError: If a server reply carries no ``data`` payload.
    """
    logger.debug("Generate URL")
    url = self.create_url()
    logger.debug("Preparing Data")
    data = self.prepare_data(text, sampling_rate)
    result = bytearray()
    try:
        async with websockets.client.connect(url) as ws:
            logger.debug("Sending Data")
            await ws.send(json.dumps(data))
            while True:
                try:
                    raw_message = await ws.recv()
                except ConnectionClosedError as e:
                    # Keep whatever audio arrived, but leave a trace that the
                    # stream ended before the last frame.
                    logger.warning(f"Connection closed before last frame: {e}")
                    break
                message = json.loads(raw_message)
                logger.debug(f"Received message: {message}")
                frame = message.get("data")
                if not frame:
                    # Fail with the server's reply instead of an opaque
                    # TypeError/KeyError from indexing a missing payload.
                    raise RuntimeError(f"TTS response without data: {message}")
                audio = frame["audio"]
                logger.debug(f"Received audio length: {len(audio)}")
                result += base64.b64decode(audio)
                if frame["status"] == STATUS_LAST_FRAME:
                    break
    except InvalidStatusCode as e:
        logger.error(f"Error: {e}")
        # Bare raise re-raises the active exception with its traceback.
        raise
    logger.success("Audio generation finished")
    return sampling_rate, self.parse_result(bytes(result))
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
__all__ = ["TTSClient"]
|