| from typing import Optional |
|
|
| from langchain_core.messages import HumanMessage |
| from langgraph.graph import START, StateGraph, END |
| from langgraph.graph.state import CompiledStateGraph |
| from langgraph.prebuilt import ToolNode |
| from langgraph.prebuilt import tools_condition |
|
|
| from core.messages import Attachment |
| from core.state import State |
| from nodes.nodes import assistant, optimize_memory, response_processing, pre_processor, agent_tools |
|
|
|
|
class GaiaAgent:
    """ReAct-style agent built on a LangGraph ``StateGraph``.

    Graph topology, as wired in ``_build_graph``::

        START -> pre_processor -> assistant
        assistant --(tools_condition -> "tools")--> tools -> optimize_memory -> assistant
        assistant --(tools_condition -> END)------> response_processing -> END
    """

    # Compiled LangGraph workflow executed by ``__call__`` and ``__streamed_call__``.
    react_graph: CompiledStateGraph
    # Maximum number of graph steps per run, passed as LangGraph's ``recursion_limit``.
    recursion_limit: int = 30

    def __init__(self, recursion_limit: int = 30):
        """Build and compile the agent graph.

        Args:
            recursion_limit: Step limit forwarded to LangGraph on every run
                (both ``__call__`` and ``__streamed_call__``). Defaults to 30,
                the value previously hard-coded in ``__call__``.
        """
        self.recursion_limit = recursion_limit
        self.react_graph = self._build_graph()

    @staticmethod
    def _build_graph() -> CompiledStateGraph:
        """Wire the agent's nodes and edges and return the compiled graph."""
        builder = StateGraph(State)

        builder.add_node("pre_processor", pre_processor)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(agent_tools))
        builder.add_node("optimize_memory", optimize_memory)
        builder.add_node("response_processing", response_processing)

        builder.add_edge(START, "pre_processor")
        builder.add_edge("pre_processor", "assistant")

        # tools_condition yields "tools" or END ("__end__"). The END branch is
        # redirected to response_processing so the assistant's final answer is
        # post-processed before the run actually terminates.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
            {"tools": "tools", END: "response_processing"},
        )

        # After tools run, optimize_memory executes before control returns to
        # the assistant.
        builder.add_edge("tools", "optimize_memory")
        builder.add_edge("optimize_memory", "assistant")
        builder.add_edge("response_processing", END)

        return builder.compile()

    @staticmethod
    def _initial_state(question: str, attachment: Optional[Attachment]) -> dict:
        """Build the graph input shared by both entry points.

        The question is stored both as the first human message and under the
        ``"question"`` key. When an attachment is given, its ``file_path`` is
        stored under ``"file_reference"``.
        """
        state = {"messages": [HumanMessage(content=question)], "question": question}
        if attachment:
            state["file_reference"] = attachment.file_path
        return state

    def _run_config(self) -> dict:
        """Return the per-run LangGraph config (currently just the step limit)."""
        return {"recursion_limit": self.recursion_limit}

    def __call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the agent to completion and return the last message's content.

        Args:
            question: The user question sent to the agent.
            attachment: Optional attachment whose ``file_path`` is exposed to
                the graph as ``file_reference``.

        Returns:
            ``content`` of the final message in the resulting state.
        """
        final_state = self.react_graph.invoke(
            self._initial_state(question, attachment),
            config=self._run_config(),
        )
        return final_state["messages"][-1].content

    def __streamed_call__(self, question: str, attachment: Optional[Attachment] = None) -> str:
        """Run the agent while printing the newest message of every state snapshot.

        Uses the same initial state and step limit as ``__call__``.

        Args:
            question: The user question sent to the agent.
            attachment: Optional attachment whose ``file_path`` is exposed to
                the graph as ``file_reference``.

        Returns:
            Content of the last message seen in the stream.

        Raises:
            RuntimeError: If the graph stream yields no state snapshots.
        """
        last_message = None
        for snapshot in self.react_graph.stream(
            self._initial_state(question, attachment),
            config=self._run_config(),
            stream_mode="values",
        ):
            last_message = snapshot["messages"][-1]
            if isinstance(last_message, tuple):
                print(last_message)
            else:
                last_message.pretty_print()

        # Previously this path hit an UnboundLocalError when nothing was streamed.
        if last_message is None:
            raise RuntimeError("Graph stream produced no state snapshots.")
        # Tuple messages have no .content attribute. This assumes LangChain's
        # message-like (role, content) tuple form, so the content is the last item.
        if isinstance(last_message, tuple):
            return last_message[-1]
        return last_message.content
|
|