import time
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from zhipuai import ZhipuAI


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message into the dict format sent to ZhipuAI.

    Args:
        message: The LangChain message.

    Returns:
        A dict with ``content`` and ``role`` keys, plus ``name`` when the
        message (or its ``additional_kwargs``) carries one.

    Raises:
        TypeError: If the message is not a chat, human, AI or system message.
    """
    message_dict: Dict[str, Any] = {"content": message.content}
    if (name := message.name or message.additional_kwargs.get("name")) is not None:
        message_dict["name"] = name

    # Populate the role from the concrete message type.
    if isinstance(message, ChatMessage):
        message_dict["role"] = message.role
    elif isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
    elif isinstance(message, SystemMessage):
        message_dict["role"] = "system"
    else:
        raise TypeError(f"Got unknown type {message}")
    return message_dict


class ZhipuaiLLM(BaseChatModel):
    """Custom chat model that sends LangChain messages to the ZhipuAI API."""

    # Model identifier forwarded as ``model`` on every request.
    model_name: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    # Default stop sequences, used when a call does not pass its own ``stop``.
    stop: Optional[List[str]] = None
    # NOTE(review): declared but not forwarded to the client anywhere in this file.
    max_retries: int = 3
    # API key handed to the ZhipuAI client for every request.
    api_key: Optional[str] = None

    def _request_params(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
    ) -> Dict[str, Any]:
        """Build the keyword arguments shared by blocking and streaming calls."""
        return {
            "model": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "timeout": self.timeout,
            # A per-call ``stop`` wins; otherwise use the model-level default.
            "stop": stop if stop is not None else self.stop,
            "messages": [_convert_message_to_dict(message) for message in messages],
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the ZhipuAI API once and return the full response.

        Args:
            messages: The prompt, as a list of LangChain messages.
            stop: Stop sequences for this call; falls back to ``self.stop``
                when ``None``.
            run_manager: Callback manager for the run (unused here).
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            A ``ChatResult`` with a single generation whose message carries
            the elapsed time in ``response_metadata["time_in_seconds"]`` and
            the token counts in ``usage_metadata``.
        """
        params = self._request_params(messages, stop)

        # perf_counter is monotonic, so the duration cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()
        response = ZhipuAI(api_key=self.api_key).chat.completions.create(**params)
        time_in_seconds = time.perf_counter() - start_time

        message = AIMessage(
            content=response.choices[0].message.content,
            additional_kwargs={},
            response_metadata={
                "time_in_seconds": round(time_in_seconds, 3),
            },
            # Tokens consumed by this request.
            usage_metadata={
                "input_tokens": response.usage.prompt_tokens,
                "output_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call the ZhipuAI API in streaming mode and yield chunks.

        Args:
            messages: The prompt, as a list of LangChain messages.
            stop: Stop sequences for this call; falls back to ``self.stop``
                when ``None``.
            run_manager: Callback manager; notified of every yielded chunk.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Yields:
            One chunk per streamed delta, followed by a final empty chunk
            carrying ``response_metadata["time_in_sec"]`` and, when the API
            reported it, ``usage_metadata``.
        """
        params = self._request_params(messages, stop)

        # Start timing before the request is issued, as _generate does.
        start_time = time.perf_counter()
        # Pass the configured key, matching _generate.
        response = ZhipuAI(api_key=self.api_key).chat.completions.create(
            stream=True,  # returns an iterator of partial responses
            **params,
        )

        # Stays None if no streamed chunk reports usage, so the final chunk
        # can always be built.
        usage_metadata: Optional[UsageMetadata] = None
        for res in response:
            if res.usage:
                usage_metadata = UsageMetadata(
                    input_tokens=res.usage.prompt_tokens,
                    output_tokens=res.usage.completion_tokens,
                    total_tokens=res.usage.total_tokens,
                )
            if not res.choices:
                # Nothing to emit for a chunk without choices.
                continue
            # A delta may carry no text; AIMessageChunk needs a string.
            token = res.choices[0].delta.content or ""
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=chunk)
            yield chunk

        time_in_sec = time.perf_counter() - start_time
        # Final empty chunk carrying the run's metadata.
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(
                content="",
                response_metadata={"time_in_sec": round(time_in_sec, 3)},
                usage_metadata=usage_metadata,
            )
        )
        if run_manager:
            run_manager.on_llm_new_token("", chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Return the type of language model used by this chat model."""
        return self.model_name

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dict of identifying parameters.

        Used by the LangChain callback system for tracing purposes, which
        makes it possible to monitor LLMs.
        """
        return {
            "model_name": self.model_name,
        }