Yi Jin commited on
Commit
177a062
·
unverified ·
1 Parent(s): a7c4935

Uploaded beta code

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. .vscode/settings.json +0 -4
  3. app.py +124 -0
  4. iat.py +24 -4
  5. requirements.txt +2 -1
  6. tts.py +45 -26
.gitignore CHANGED
@@ -177,3 +177,5 @@ pyrightconfig.json
177
 
178
  config.yaml
179
 
 
 
 
177
 
178
  config.yaml
179
 
180
+
181
+ .vscode/
.vscode/settings.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "python.analysis.autoImportCompletions": true,
3
- "python.analysis.typeCheckingMode": "basic"
4
- }
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from loguru import logger
4
+ from zhipuai import ZhipuAI
5
+ from zhipuai.api_resource.chat.chat import Chat
6
+ import yaml
7
+ import json
8
+ from iat import IATClient
9
+ from tts import TTSClient
10
+ import numpy as np
11
+ from scipy.signal import resample
12
+
13
+ logger.debug("Loading config")
14
+ config_env = os.environ.get("CONFIG", "")
15
+ if config_env:
16
+ logger.debug("Using environment variable for config")
17
+ config = json.loads(config_env)
18
+ else:
19
+ logger.debug("Reading config from file")
20
+ with open("config.yaml", "r") as f:
21
+ try:
22
+ config = yaml.safe_load(f)["config"]
23
+ except yaml.YAMLError as e:
24
+ logger.error(e)
25
+ raise e
26
+
27
+ zhipuai_config = config["zhipuai"]
28
+ xfyun_config = config["xfyun"]
29
+
30
+ zhipuai = ZhipuAI(api_key=zhipuai_config["apikey"])
31
+ iat = IATClient(
32
+ xfyun_config["iat"]["appid"],
33
+ xfyun_config["iat"]["apikey"],
34
+ xfyun_config["iat"]["apisecret"],
35
+ )
36
+ tts = TTSClient(
37
+ xfyun_config["tts"]["appid"],
38
+ xfyun_config["tts"]["apikey"],
39
+ xfyun_config["tts"]["apisecret"],
40
+ )
41
+
42
+
43
+ def build_zhipuai_history(history: list[list[str]]):
44
+ result = [{"role": "system", "content": config["zhipuai"]["prompt"]}]
45
+ for history_element in history:
46
+ user_message, assistant_message = history_element
47
+ if user_message != None:
48
+ result += [{"role": "user", "content": user_message}]
49
+ if assistant_message != None:
50
+ result += [{"role": "assistant", "content": assistant_message}]
51
+ return result
52
+
53
+
54
+ def add_text(history, text):
55
+ history = history + [(text, None)]
56
+ return history, gr.Textbox(value="", interactive=False)
57
+
58
+
59
+ def bot(history):
60
+ zhipuai_history = build_zhipuai_history(history)
61
+ res = zhipuai.chat.completions.create(
62
+ model="glm-4", messages=zhipuai_history, stream=True
63
+ )
64
+ history[-1][1] = ""
65
+ for chunk in res:
66
+ history[-1][1] += chunk.choices[0].delta.content
67
+ yield history
68
+
69
+
70
+ async def generate_text(audio: tuple[int, np.ndarray]):
71
+ logger.debug(f"Generating text from audio")
72
+ logger.debug(f"Sampling rate: {audio[0]}, resampling to 16000")
73
+ audio = (16000, resample(audio[1], 16000))
74
+ result_list = []
75
+ async for result in iat.dictate(audio):
76
+ logger.debug(f"Result: {result}")
77
+ result_list.append(result)
78
+ return "".join(result_list)
79
+
80
+
81
+ async def generate_audio(history: list[list[str]]):
82
+ logger.debug(f"Generating audio from text")
83
+ text = history[-1][-1]
84
+ result = await tts.generate(text)
85
+ return result
86
+
87
+ with gr.Blocks() as demo:
88
+ title = gr.Markdown("# 老王元宇宙受害者")
89
+
90
+ chatbot = gr.Chatbot(
91
+ [],
92
+ elem_id="chatbot",
93
+ bubble_full_width=False,
94
+ )
95
+
96
+ with gr.Row():
97
+ txt = gr.Textbox(
98
+ scale=4,
99
+ show_label=False,
100
+ placeholder="Enter text and press enter",
101
+ container=False,
102
+ )
103
+ submit_button = gr.Button(value="提交", variant="primary")
104
+
105
+ with gr.Row():
106
+ with gr.Column():
107
+ user_title = gr.Markdown("## 用户语音识别")
108
+ user_audio = gr.Audio(type="numpy", sources=['microphone'])
109
+ user_audio_submit = gr.Button(value="上传用户语音并转换", variant="primary")
110
+ with gr.Column():
111
+ user_title = gr.Markdown("## 机器人语音合成")
112
+ bot_audio = gr.Audio()
113
+ bot_audio_submit = gr.Button(value="将机器人最后一个回复转换为语音", variant="primary")
114
+
115
+ user_audio_submit.click(generate_text, [user_audio], outputs=txt)
116
+ bot_audio_submit.click(generate_audio, [chatbot], outputs=bot_audio)
117
+
118
+ txt_msg = submit_button.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
119
+ bot, chatbot, chatbot, api_name="bot_response"
120
+ )
121
+
122
+ txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
123
+
124
+ demo.launch()
iat.py CHANGED
@@ -7,6 +7,7 @@ import base64
7
  import numpy as np
8
  import json
9
  import websockets.client
 
10
  from websockets.exceptions import ConnectionClosedError
11
  import time
12
 
@@ -52,29 +53,35 @@ class IATClient:
52
  path = parse_result.path
53
 
54
  sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
 
55
  sign_sha = hmac.new(
56
  self.api_secret.encode("utf-8"),
57
  sign_raw_str.encode("utf-8"),
58
  digestmod=hashlib.sha256,
59
  ).digest()
60
- auth_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
 
61
  self.api_key,
62
  "hmac-sha256",
63
  "host date request-line",
64
  sign_sha,
65
  )
66
- auth = base64.b64encode(auth_origin.encode("utf-8")).decode("utf-8")
 
67
  params = {
68
  "authorization": auth,
69
  "date": date,
70
  "host": host,
71
  }
72
  url = f"{self.endpoint}?{urlencode(params)}"
 
73
  return url
74
 
75
  def prepare_data(self, audio: bytes, chunk_size=1280, sampling_rate=16000):
76
  status = STATUS_FIRST_FRAME
 
77
  for i in range(0, len(audio), chunk_size):
 
78
  chunk = audio[i : i + chunk_size]
79
  if i + chunk_size >= len(audio):
80
  status = STATUS_LAST_FRAME
@@ -91,16 +98,29 @@ class IATClient:
91
  yield payload
92
  status = STATUS_CONTINUE_FRAME
93
 
94
- async def dictate(self, audio: tuple[int, np.ndarray], interval=0.4):
 
95
  url = self.create_url()
 
96
  sampling_rate, source = audio
97
  pcm = self.encode_pcm(source)
98
  async with websockets.client.connect(url) as ws:
99
  for payload in self.prepare_data(pcm, sampling_rate=sampling_rate):
 
100
  await ws.send(json.dumps(payload))
101
  time.sleep(interval)
102
  try:
103
  async for message in ws:
104
- yield message
 
 
 
 
 
 
 
 
 
 
105
  except ConnectionClosedError as e:
106
  print(f"Connection closed: {e.code} {e.reason}")
 
7
  import numpy as np
8
  import json
9
  import websockets.client
10
+ from loguru import logger
11
  from websockets.exceptions import ConnectionClosedError
12
  import time
13
 
 
53
  path = parse_result.path
54
 
55
  sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
56
+ logger.debug(f"Sign raw string: {sign_raw_str}")
57
  sign_sha = hmac.new(
58
  self.api_secret.encode("utf-8"),
59
  sign_raw_str.encode("utf-8"),
60
  digestmod=hashlib.sha256,
61
  ).digest()
62
+ sign_sha = base64.b64encode(sign_sha).decode("utf-8")
63
+ auth_raw_str = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
64
  self.api_key,
65
  "hmac-sha256",
66
  "host date request-line",
67
  sign_sha,
68
  )
69
+ logger.debug(f"Authorization: {auth_raw_str}")
70
+ auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")
71
  params = {
72
  "authorization": auth,
73
  "date": date,
74
  "host": host,
75
  }
76
  url = f"{self.endpoint}?{urlencode(params)}"
77
+ logger.debug(f"URL: {url}")
78
  return url
79
 
80
  def prepare_data(self, audio: bytes, chunk_size=1280, sampling_rate=16000):
81
  status = STATUS_FIRST_FRAME
82
+ logger.debug(f"Total audio length: {len(audio)}")
83
  for i in range(0, len(audio), chunk_size):
84
+ logger.debug(f"Processing chunk {i} to {i + chunk_size}")
85
  chunk = audio[i : i + chunk_size]
86
  if i + chunk_size >= len(audio):
87
  status = STATUS_LAST_FRAME
 
98
  yield payload
99
  status = STATUS_CONTINUE_FRAME
100
 
101
+ async def dictate(self, audio: tuple[int, np.ndarray], interval=0.04):
102
+ logger.debug(f"Generate URL")
103
  url = self.create_url()
104
+ logger.debug("Encoding audio to PCM")
105
  sampling_rate, source = audio
106
  pcm = self.encode_pcm(source)
107
  async with websockets.client.connect(url) as ws:
108
  for payload in self.prepare_data(pcm, sampling_rate=sampling_rate):
109
+ logger.debug('Sending payload')
110
  await ws.send(json.dumps(payload))
111
  time.sleep(interval)
112
  try:
113
  async for message in ws:
114
+ data: dict = json.loads(message)
115
+ logger.debug(f"Received data: {data}")
116
+ if not 'data' in data.keys():
117
+ yield ''
118
+ break
119
+ is_end = data["data"]["status"] == STATUS_LAST_FRAME
120
+ ws_list = data["data"]["result"]["ws"]
121
+ text = ''.join([cw["w"] for cw in sum([ws["cw"] for ws in ws_list], [])])
122
+ yield text
123
+ if is_end:
124
+ break
125
  except ConnectionClosedError as e:
126
  print(f"Connection closed: {e.code} {e.reason}")
requirements.txt CHANGED
@@ -1,7 +1,8 @@
1
  gradio
2
  jupyter
3
  requests
4
- websocket-client
5
  zhipuai
6
  loguru
7
  numpy
 
 
1
  gradio
2
  jupyter
3
  requests
4
+ websockets
5
  zhipuai
6
  loguru
7
  numpy
8
+ scipy
tts.py CHANGED
@@ -1,13 +1,14 @@
1
  from urllib.parse import urlparse, urlencode
2
  from wsgiref.handlers import format_date_time
3
  from datetime import datetime
 
4
  import hmac
5
  import hashlib
6
  import base64
7
  import numpy as np
8
  import json
9
  import websockets.client
10
- from websockets.exceptions import ConnectionClosedError
11
  import time
12
 
13
  STATUS_FIRST_FRAME = 0
@@ -16,7 +17,6 @@ STATUS_LAST_FRAME = 2
16
 
17
 
18
  class TTSClient:
19
-
20
  def __init__(
21
  self,
22
  app_id: str,
@@ -29,22 +29,24 @@ class TTSClient:
29
  self.api_secret = api_secret
30
  self.endpoint = endpoint
31
  self.common_args = {"app_id": self.app_id}
32
- self.business_args = {
 
 
33
  "aue": "raw",
34
- "auf": "audio/L16;rate=16000",
35
  "vcn": "xiaoyan",
36
  "tte": "utf8",
37
  }
38
-
39
- def prepare_data(self, text: str):
40
- return {
41
  "common": self.common_args,
42
- "business": self.business_args,
43
  "data": {
44
  "status": 2,
45
  "text": str(base64.b64encode(text.encode("utf-8")), "UTF8"),
46
  },
47
  }
 
 
48
 
49
  def create_url(self):
50
  parse_result = urlparse(self.endpoint)
@@ -54,45 +56,62 @@ class TTSClient:
54
  path = parse_result.path
55
 
56
  sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
 
57
  sign_sha = hmac.new(
58
  self.api_secret.encode("utf-8"),
59
  sign_raw_str.encode("utf-8"),
60
  digestmod=hashlib.sha256,
61
  ).digest()
62
- auth_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
 
63
  self.api_key,
64
  "hmac-sha256",
65
  "host date request-line",
66
  sign_sha,
67
  )
68
- auth = base64.b64encode(auth_origin.encode("utf-8")).decode("utf-8")
 
69
  params = {
70
  "authorization": auth,
71
  "date": date,
72
  "host": host,
73
  }
74
  url = f"{self.endpoint}?{urlencode(params)}"
 
75
  return url
76
 
77
  def parse_result(self, result: bytes) -> np.ndarray:
78
  return np.frombuffer(result, dtype=np.int16)
79
 
80
- async def generate(self, text: str):
 
81
  url = self.create_url()
82
- data = self.prepare_data(text)
 
83
  result = bytearray()
84
- async with websockets.client.connect(url) as ws:
85
- await ws.send(json.dumps(data))
86
- while True:
87
- try:
88
- message = await ws.recv()
89
- message = json.loads(message)
90
- audio = message["data"]["audio"]
91
- audio = base64.b64decode(audio)
92
- status = message["data"]["status"]
93
- result += audio
94
- if status == STATUS_LAST_FRAME:
 
 
 
 
 
 
95
  break
96
- except ConnectionClosedError:
97
- break
98
- return self.parse_result(bytes(result))
 
 
 
 
 
 
1
  from urllib.parse import urlparse, urlencode
2
  from wsgiref.handlers import format_date_time
3
  from datetime import datetime
4
+ from loguru import logger
5
  import hmac
6
  import hashlib
7
  import base64
8
  import numpy as np
9
  import json
10
  import websockets.client
11
+ from websockets.exceptions import ConnectionClosedError, InvalidStatusCode
12
  import time
13
 
14
  STATUS_FIRST_FRAME = 0
 
17
 
18
 
19
  class TTSClient:
 
20
  def __init__(
21
  self,
22
  app_id: str,
 
29
  self.api_secret = api_secret
30
  self.endpoint = endpoint
31
  self.common_args = {"app_id": self.app_id}
32
+
33
+ def prepare_data(self, text: str, sampling_rate=16000):
34
+ business_args = {
35
  "aue": "raw",
36
+ "auf": f"audio/L16;rate={sampling_rate}",
37
  "vcn": "xiaoyan",
38
  "tte": "utf8",
39
  }
40
+ result = {
 
 
41
  "common": self.common_args,
42
+ "business": business_args,
43
  "data": {
44
  "status": 2,
45
  "text": str(base64.b64encode(text.encode("utf-8")), "UTF8"),
46
  },
47
  }
48
+ logger.debug(f"Data: {result}")
49
+ return result
50
 
51
  def create_url(self):
52
  parse_result = urlparse(self.endpoint)
 
56
  path = parse_result.path
57
 
58
  sign_raw_str = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
59
+ logger.debug(f"Sign raw string: {sign_raw_str}")
60
  sign_sha = hmac.new(
61
  self.api_secret.encode("utf-8"),
62
  sign_raw_str.encode("utf-8"),
63
  digestmod=hashlib.sha256,
64
  ).digest()
65
+ sign_sha = base64.b64encode(sign_sha).decode("utf-8")
66
+ auth_raw_str = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
67
  self.api_key,
68
  "hmac-sha256",
69
  "host date request-line",
70
  sign_sha,
71
  )
72
+ logger.debug(f"Authorization: {auth_raw_str}")
73
+ auth = base64.b64encode(auth_raw_str.encode("utf-8")).decode("utf-8")
74
  params = {
75
  "authorization": auth,
76
  "date": date,
77
  "host": host,
78
  }
79
  url = f"{self.endpoint}?{urlencode(params)}"
80
+ logger.debug(f"URL: {url}")
81
  return url
82
 
83
  def parse_result(self, result: bytes) -> np.ndarray:
84
  return np.frombuffer(result, dtype=np.int16)
85
 
86
+ async def generate(self, text: str, sampling_rate=16000):
87
+ logger.debug("Generate URL")
88
  url = self.create_url()
89
+ logger.debug("Preparing Data")
90
+ data = self.prepare_data(text, sampling_rate)
91
  result = bytearray()
92
+ try:
93
+ async with websockets.client.connect(url) as ws:
94
+ logger.debug("Sending Data")
95
+ await ws.send(json.dumps(data))
96
+ while True:
97
+ try:
98
+ message = await ws.recv()
99
+ message = json.loads(message)
100
+ logger.debug(f"Received message: {message}")
101
+ audio = message["data"]["audio"]
102
+ logger.debug(f"Received audio length: {len(audio)}")
103
+ audio = base64.b64decode(audio)
104
+ status = message["data"]["status"]
105
+ result += audio
106
+ if status == STATUS_LAST_FRAME:
107
+ break
108
+ except ConnectionClosedError:
109
  break
110
+ except InvalidStatusCode as e:
111
+ logger.error(f"Error: {e}")
112
+ raise e
113
+ logger.success("Audio generation finished")
114
+ return sampling_rate, self.parse_result(bytes(result))
115
+
116
+
117
+ __all__ = ["TTSClient"]