import asyncio
import base64
import re
import time
import traceback
from collections.abc import AsyncGenerator

from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.components import BaseMessageComponent, Json, Plain
from astrbot.core.message.message_event_result import (
    MessageChain,
    MessageEventResult,
    ResultContentType,
)
from astrbot.core.persona_error_reply import (
    extract_persona_custom_error_message_from_event,
)
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.provider import TTSProvider

AgentRunner = ToolLoopAgentRunner[AstrAgentContext]


def _should_stop_agent(astr_event) -> bool:
    return astr_event.is_stopped() or bool(
        astr_event.get_extra("agent_stop_requested")
    )


def _truncate_tool_result(text: str, limit: int = 70) -> str:
    if limit <= 0:
        return ""
    if len(text) <= limit:
        return text
    if limit <= 3:
        return text[:limit]
    return f"{text[: limit - 3]}..."


def _extract_chain_json_data(msg_chain: MessageChain) -> dict | None:
    if not msg_chain.chain:
        return None
    first_comp = msg_chain.chain[0]
    if isinstance(first_comp, Json) and isinstance(first_comp.data, dict):
        return first_comp.data
    return None


def _record_tool_call_name(
    tool_info: dict | None, tool_name_by_call_id: dict[str, str]
) -> None:
    if not isinstance(tool_info, dict):
        return
    tool_call_id = tool_info.get("id")
    tool_name = tool_info.get("name")
    if tool_call_id is None or tool_name is None:
        return
    tool_name_by_call_id[str(tool_call_id)] = str(tool_name)


def _build_tool_call_status_message(tool_info: dict | None) -> str:
    if tool_info:
        return f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
    return "🔨 调用工具..."


def _build_tool_result_status_message(
    msg_chain: MessageChain, tool_name_by_call_id: dict[str, str]
) -> str:
    tool_name = "unknown"
    tool_result = ""
    result_data = _extract_chain_json_data(msg_chain)
    if result_data:
        tool_call_id = result_data.get("id")
        if tool_call_id is not None:
            tool_name = tool_name_by_call_id.pop(str(tool_call_id), "unknown")
        tool_result = str(result_data.get("result", ""))
    if not tool_result:
        tool_result = msg_chain.get_plain_text(with_other_comps_mark=True)
    tool_result = _truncate_tool_result(tool_result, 70)
    status_msg = f"🔨 调用工具: {tool_name}"
    if tool_result:
        status_msg = f"{status_msg}\n📎 返回结果: {tool_result}"
    return status_msg
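

# Illustrative only: the helpers above assume tool-call chains whose first
# component is a Json payload shaped roughly like the dicts below. The exact
# field set comes from the runner; these literals are assumptions for the
# sketch, not part of the module:
#
#     tool_call_payload = {"id": "call_0", "name": "web_search", "args": {}}
#     tool_result_payload = {"id": "call_0", "result": "top hits: ..."}
#
# Feeding these through _record_tool_call_name and then
# _build_tool_result_status_message would produce a status message such as
# "🔨 调用工具: web_search\n📎 返回结果: top hits: ...", with the result text
# truncated to 70 characters by _truncate_tool_result.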


async def run_agent(
    agent_runner: AgentRunner,
    max_step: int = 30,
    show_tool_use: bool = True,
    show_tool_call_result: bool = False,
    stream_to_general: bool = False,
    show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
    step_idx = 0
    astr_event = agent_runner.run_context.context.event
    tool_name_by_call_id: dict[str, str] = {}
    while step_idx < max_step + 1:
        step_idx += 1
        if step_idx == max_step + 1:
            logger.warning(
                f"Agent reached max steps ({max_step}), forcing a final response."
            )
            if not agent_runner.done():
                # Unplug all tools.
                if agent_runner.req:
                    agent_runner.req.func_tool = None
                # Inject a wrap-up prompt.
                agent_runner.run_context.messages.append(
                    Message(
                        role="user",
                        content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
                    )
                )
        stop_watcher = asyncio.create_task(
            _watch_agent_stop_signal(agent_runner, astr_event),
        )
        try:
            async for resp in agent_runner.step():
                if _should_stop_agent(astr_event):
                    agent_runner.request_stop()
                if resp.type == "aborted":
                    if not stop_watcher.done():
                        stop_watcher.cancel()
                        try:
                            await stop_watcher
                        except asyncio.CancelledError:
                            pass
                    astr_event.set_extra("agent_user_aborted", True)
                    astr_event.set_extra("agent_stop_requested", False)
                    return
                if _should_stop_agent(astr_event):
                    continue
                if resp.type == "tool_call_result":
                    msg_chain = resp.data["chain"]
                    astr_event.trace.record(
                        "agent_tool_result",
                        tool_result=msg_chain.get_plain_text(
                            with_other_comps_mark=True
                        ),
                    )
                    if msg_chain.type == "tool_direct_result":
                        # tool_direct_result marks content that an LLM tool
                        # wants delivered to the user as-is.
                        await astr_event.send(msg_chain)
                        continue
                    if astr_event.get_platform_id() == "webchat":
                        await astr_event.send(msg_chain)
                    elif show_tool_use and show_tool_call_result:
                        status_msg = _build_tool_result_status_message(
                            msg_chain, tool_name_by_call_id
                        )
                        await astr_event.send(
                            MessageChain(type="tool_call").message(status_msg)
                        )
                    # Other cases are intentionally left unhandled for now.
                    continue
                elif resp.type == "tool_call":
                    if agent_runner.streaming:
                        # Mark a section break in the streaming response.
                        yield MessageChain(chain=[], type="break")
                    tool_info = _extract_chain_json_data(resp.data["chain"])
                    astr_event.trace.record(
                        "agent_tool_call",
                        tool_name=tool_info.get("name", "unknown")
                        if tool_info
                        else "unknown",
                    )
                    _record_tool_call_name(tool_info, tool_name_by_call_id)
                    if astr_event.get_platform_name() == "webchat":
                        await astr_event.send(resp.data["chain"])
                    elif show_tool_use:
                        if show_tool_call_result and isinstance(tool_info, dict):
                            # Delay the tool status notification until
                            # tool_call_result arrives.
                            continue
                        chain = MessageChain(type="tool_call").message(
                            _build_tool_call_status_message(tool_info)
                        )
                        await astr_event.send(chain)
                    continue
                if stream_to_general and resp.type == "streaming_delta":
                    continue
                if stream_to_general or not agent_runner.streaming:
                    content_typ = (
                        ResultContentType.LLM_RESULT
                        if resp.type == "llm_result"
                        else ResultContentType.GENERAL_RESULT
                    )
                    astr_event.set_result(
                        MessageEventResult(
                            chain=resp.data["chain"].chain,
                            result_content_type=content_typ,
                        ),
                    )
                    yield
                    astr_event.clear_result()
                elif resp.type == "streaming_delta":
                    chain = resp.data["chain"]
                    if chain.type == "reasoning" and not show_reasoning:
                        # Only surface reasoning content when configured to.
                        continue
                    yield resp.data["chain"]  # MessageChain
            if not stop_watcher.done():
                stop_watcher.cancel()
                try:
                    await stop_watcher
                except asyncio.CancelledError:
                    pass
            if agent_runner.done():
                # Send agent stats to webchat.
                if astr_event.get_platform_name() == "webchat":
                    await astr_event.send(
                        MessageChain(
                            type="agent_stats",
                            chain=[Json(data=agent_runner.stats.to_dict())],
                        )
                    )
                break
        except Exception as e:
            if "stop_watcher" in locals() and not stop_watcher.done():
                stop_watcher.cancel()
                try:
                    await stop_watcher
                except asyncio.CancelledError:
                    pass
            logger.error(traceback.format_exc())
            custom_error_message = extract_persona_custom_error_message_from_event(
                astr_event
            )
            if custom_error_message:
                err_msg = custom_error_message
            else:
                err_msg = (
                    f"Error occurred during AI execution.\n"
                    f"Error Type: {type(e).__name__}\n"
                    f"Error Message: {str(e)}"
                )
            error_llm_response = LLMResponse(
                role="err",
                completion_text=err_msg,
            )
            try:
                await agent_runner.agent_hooks.on_agent_done(
                    agent_runner.run_context, error_llm_response
                )
            except Exception:
                logger.exception("Error in on_agent_done hook")
            if agent_runner.streaming:
                yield MessageChain().message(err_msg)
            else:
                astr_event.set_result(MessageEventResult().message(err_msg))
            return
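

# Minimal consumption sketch (hypothetical caller, not part of the module):
# the pipeline that owns the event drives the generator, forwarding streaming
# chunks and treating a bare `yield` (None) as "a result was set on the event".
#
#     async def drive(runner: AgentRunner) -> None:
#         async for chunk in run_agent(runner, max_step=10, show_tool_use=True):
#             if chunk is None:
#                 ...  # non-streaming result was set via astr_event.set_result()
#             elif chunk.type == "break":
#                 ...  # streaming section break around a tool call
#             else:
#                 ...  # streaming delta MessageChain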


async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> None:
    while not agent_runner.done():
        if _should_stop_agent(astr_event):
            agent_runner.request_stop()
            return
        await asyncio.sleep(0.5)


async def run_live_agent(
    agent_runner: AgentRunner,
    tts_provider: TTSProvider | None = None,
    max_step: int = 30,
    show_tool_use: bool = True,
    show_tool_call_result: bool = False,
    show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
    """Agent runner for Live Mode, with streaming TTS support.

    Args:
        agent_runner: the agent runner
        tts_provider: the TTS provider instance
        max_step: maximum number of agent steps
        show_tool_use: whether to announce tool usage
        show_tool_call_result: whether to announce tool results
        show_reasoning: whether to surface reasoning content

    Yields:
        MessageChain: message chains carrying text or audio data
    """
    # Without a TTS provider, just forward the text output.
    if not tts_provider:
        async for chain in run_agent(
            agent_runner,
            max_step=max_step,
            show_tool_use=show_tool_use,
            show_tool_call_result=show_tool_call_result,
            stream_to_general=False,
            show_reasoning=show_reasoning,
        ):
            yield chain
        return

    support_stream = tts_provider.support_stream()
    if support_stream:
        logger.info("[Live Agent] Using streaming TTS (native get_audio_stream support).")
    else:
        logger.info(
            f"[Live Agent] Using TTS ({tts_provider.meta().type} only provides "
            "get_audio; audio will be generated sentence by sentence)."
        )

    # Stats bookkeeping.
    tts_start_time = time.time()
    tts_first_frame_time = 0.0
    first_chunk_received = False

    # Queues connecting the pipeline stages.
    text_queue: asyncio.Queue[str | None] = asyncio.Queue()
    # audio_queue stores bytes or (text, bytes).
    audio_queue: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue()

    # 1. Start the agent feeder task: it runs the agent and feeds
    #    sentence-sized text chunks into text_queue.
    feeder_task = asyncio.create_task(
        _run_agent_feeder(
            agent_runner,
            text_queue,
            max_step,
            show_tool_use,
            show_tool_call_result,
            show_reasoning,
        )
    )

    # 2. Start the TTS task: it reads text from text_queue and produces
    #    audio into audio_queue.
    if support_stream:
        tts_task = asyncio.create_task(
            _safe_tts_stream_wrapper(tts_provider, text_queue, audio_queue)
        )
    else:
        tts_task = asyncio.create_task(
            _simulated_stream_tts(tts_provider, text_queue, audio_queue)
        )

    # 3. Main loop: read audio from audio_queue and yield it.
    try:
        while True:
            queue_item = await audio_queue.get()
            if queue_item is None:
                break
            text = None
            if isinstance(queue_item, tuple):
                text, audio_data = queue_item
            else:
                audio_data = queue_item
            if not first_chunk_received:
                # Record first-frame latency (from start of processing
                # to the first audio chunk).
                tts_first_frame_time = time.time() - tts_start_time
                first_chunk_received = True
            # Wrap the audio data in a MessageChain.
            audio_b64 = base64.b64encode(audio_data).decode("utf-8")
            comps: list[BaseMessageComponent] = [Plain(audio_b64)]
            if text:
                comps.append(Json(data={"text": text}))
            chain = MessageChain(chain=comps, type="audio_chunk")
            yield chain
    except Exception as e:
        logger.error(f"[Live Agent] Runtime error: {e}", exc_info=True)
    finally:
        # Clean up the background tasks.
        if not feeder_task.done():
            feeder_task.cancel()
        if not tts_task.done():
            tts_task.cancel()
        # Make sure the queues get drained.
        pass

    tts_end_time = time.time()
    # Send the TTS stats.
    try:
        astr_event = agent_runner.run_context.context.event
        if astr_event.get_platform_name() == "webchat":
            tts_duration = tts_end_time - tts_start_time
            await astr_event.send(
                MessageChain(
                    type="tts_stats",
                    chain=[
                        Json(
                            data={
                                "tts_total_time": tts_duration,
                                "tts_first_frame_time": tts_first_frame_time,
                                "tts": tts_provider.meta().type,
                                "chat_model": agent_runner.provider.get_model(),
                            }
                        )
                    ],
                )
            )
    except Exception as e:
        logger.error(f"Failed to send TTS stats: {e}")
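

# Consumption sketch (hypothetical caller): each "audio_chunk" chain carries
# base64-encoded audio in a Plain component, optionally followed by a Json
# component holding the source text. Decoding is the receiver's job:
#
#     async for chain in run_live_agent(runner, tts_provider=tts):
#         if chain is not None and chain.type == "audio_chunk":
#             pcm = base64.b64decode(chain.chain[0].text)
#             ...  # hand `pcm` to the playback side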


async def _run_agent_feeder(
    agent_runner: AgentRunner,
    text_queue: asyncio.Queue,
    max_step: int,
    show_tool_use: bool,
    show_tool_call_result: bool,
    show_reasoning: bool,
) -> None:
    """Run the agent and feed its text output into the queue, sentence by sentence."""
    buffer = ""
    try:
        async for chain in run_agent(
            agent_runner,
            max_step=max_step,
            show_tool_use=show_tool_use,
            show_tool_call_result=show_tool_call_result,
            stream_to_general=False,
            show_reasoning=show_reasoning,
        ):
            if chain is None:
                continue
            # Extract the text.
            text = chain.get_plain_text()
            if text:
                buffer += text
                # Sentence splitting on punctuation.
                # r"([.。!!??\n]+)" keeps the delimiters in the result.
                parts = re.split(r"([.。!!??\n]+)", buffer)
                if len(parts) > 1:
                    # Handle the complete sentences. The step of 2 is because
                    # split yields [text, delim, text, delim, ...].
                    temp_buffer = ""
                    for i in range(0, len(parts) - 1, 2):
                        sentence = parts[i]
                        delim = parts[i + 1]
                        full_sentence = sentence + delim
                        temp_buffer += full_sentence
                        if len(temp_buffer) >= 10:
                            if temp_buffer.strip():
                                logger.info(
                                    f"[Live Agent Feeder] sentence chunk: {temp_buffer}"
                                )
                                await text_queue.put(temp_buffer)
                            temp_buffer = ""
                    # Keep the unfinished remainder in the buffer.
                    buffer = temp_buffer + parts[-1]
        # Flush whatever is left in the buffer.
        if buffer.strip():
            await text_queue.put(buffer)
    except Exception as e:
        logger.error(f"[Live Agent Feeder] Error: {e}", exc_info=True)
    finally:
        # Signal end of stream.
        await text_queue.put(None)
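

# Behavior note (illustrative): because the delimiter class is captured,
# re.split returns an odd-length list ending with the unterminated remainder.
#
#     >>> re.split(r"([.。!!??\n]+)", "你好。今天怎么样?还没说完")
#     ['你好', '。', '今天怎么样', '?', '还没说完']
#
# so parts[-1] ("还没说完") stays in the buffer until more text or EOF arrives.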


async def _safe_tts_stream_wrapper(
    tts_provider: TTSProvider,
    text_queue: asyncio.Queue[str | None],
    audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
) -> None:
    """Wrap native streaming TTS, guaranteeing exception handling and queue shutdown."""
    try:
        await tts_provider.get_audio_stream(text_queue, audio_queue)
    except Exception as e:
        logger.error(f"[Live TTS Stream] Error: {e}", exc_info=True)
    finally:
        await audio_queue.put(None)


async def _simulated_stream_tts(
    tts_provider: TTSProvider,
    text_queue: asyncio.Queue[str | None],
    audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
) -> None:
    """Simulate streaming TTS by synthesizing audio sentence by sentence."""
    try:
        while True:
            text = await text_queue.get()
            if text is None:
                break
            try:
                audio_path = await tts_provider.get_audio(text)
                if audio_path:
                    with open(audio_path, "rb") as f:
                        audio_data = f.read()
                    await audio_queue.put((text, audio_data))
            except Exception as e:
                logger.error(
                    f"[Live TTS Simulated] Error processing text '{text[:20]}...': {e}"
                )
                # Keep going with the next sentence.
    except Exception as e:
        logger.error(f"[Live TTS Simulated] Critical error: {e}", exc_info=True)
    finally:
        await audio_queue.put(None)
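

# Standalone sketch of the queue protocol these helpers rely on (hypothetical
# provider; `FakeTTS` is not part of AstrBot). Producers terminate text_queue
# with None, and the TTS side always terminates audio_queue with None, even on
# failure, so the consumer in run_live_agent cannot hang:
#
#     class FakeTTS:
#         async def get_audio(self, text: str) -> str:
#             ...  # synthesize `text` and return a path to the audio file
#
#     async def demo() -> None:
#         text_q: asyncio.Queue[str | None] = asyncio.Queue()
#         audio_q: asyncio.Queue[bytes | tuple[str, bytes] | None] = asyncio.Queue()
#         await text_q.put("你好。")
#         await text_q.put(None)
#         await _simulated_stream_tts(FakeTTS(), text_q, audio_q)  # type: ignore[arg-type]
#         while (item := await audio_q.get()) is not None:
#             print(item[0] if isinstance(item, tuple) else len(item))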