| import sys |
| import os |
|
|
| from langgraph.graph import START, END, StateGraph |
| from langchain_openai import OpenAIEmbeddings |
| from chains import simple_chain, llm_with_tools |
| from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage |
| from typing import TypedDict, Optional, Dict, List, Union, Annotated |
| from langchain_core.messages import AnyMessage |
| from langgraph.graph.message import add_messages |
| from langgraph.prebuilt import ToolNode, tools_condition |
| from langchain.agents import initialize_agent, Tool |
| from langchain.agents.agent_types import AgentType |
| from langgraph.checkpoint.memory import MemorySaver |
| import json |
| import langchain |
| from tools import json_to_table, goal_feasibility, rag_tool, save_data |
| import re |
|
|
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
# Conversation checkpointer: persists graph state between invocations,
# keyed by thread_id.
memory = MemorySaver()

# LangGraph expects runtime config in the form
# {"configurable": {"thread_id": ...}} (see the disabled invoke example at
# the bottom of this file); a bare {"thread_id": ...} is silently ignored,
# so the checkpointer would never resume the thread.
config = {"configurable": {"thread_id": "sample"}}

# Tools the LLM may request via tool calls.
tools = [json_to_table, rag_tool]

# Prebuilt executor nodes, one per tool, so each can be wired as its own
# graph node. NOTE(review): this `json_to_table_node` is shadowed by the
# function of the same name defined further down; the function is what
# actually gets wired into the graph.
json_to_table_node = ToolNode([json_to_table])
rag_tool_node = ToolNode([rag_tool])
class Graph(TypedDict):
    """Shared state flowing between the LangGraph nodes."""

    # Conversation so far; the add_messages reducer appends/merges messages
    # returned by nodes instead of overwriting the list.
    query: Annotated[list[AnyMessage], add_messages]

    user_data : Dict      # user profile payload — schema defined by the caller
    allocations : Dict    # investment allocation data passed to the tools

    # NOTE(review): `chat` stores `result.content` (a string) here, so the
    # Dict annotation looks wrong — confirm and tighten to `str` if so.
    output : Dict
    retrieved_context: str  # one-shot context injected by the RAG tool
|
|
def chat(state):
    """LLM node: run the chat chain over the current state and record its reply.

    Returns a partial state update. The model reply is appended to ``query``
    (via the ``add_messages`` reducer) so the downstream ``tools_condition``
    router can inspect it for tool calls, and its text is surfaced in
    ``output``.
    """
    inputs = {
        "query": state["query"],
        "user_data": state["user_data"],
        "allocations": state["allocations"],
        # The chain also receives the full message list as chat history.
        "chat_history": state["query"],
        "retrieved_context": state.get("retrieved_context", ""),
    }

    # Presumably returns an AIMessage (``.content`` is read below) —
    # TODO confirm against the chain definition in chains.py.
    result = simple_chain.invoke(inputs)

    return {
        # Append the reply. Previously the original ``state["query"]`` list
        # was echoed back unchanged, so the AI message never reached the
        # message list and tool-call routing could never trigger.
        "query": [result],
        "user_data": state["user_data"],
        "allocations": state["allocations"],
        # Clear the one-shot RAG context now that it has been consumed.
        "retrieved_context": "",
        "output": result.content,
    }
|
|
def json_to_table_node(state):
    """Render the current allocations as a table and append it to the chat.

    NOTE(review): this function shadows the ToolNode instance of the same
    name created above; this function is what the graph wires in.
    """
    # NOTE(review): if json_to_table is a @tool, calling it directly may need
    # ``json_to_table.invoke(...)`` — confirm against tools.py.
    tool_output = json_to_table(state["allocations"])
    # A LangGraph node must return a state-update dict, not a bare message;
    # the add_messages reducer on "query" appends the AIMessage for us.
    return {"query": [AIMessage(content=tool_output)]}
|
|
def tools_condition(state):
    """Route after the chat node based on the last message's tool calls.

    Returns the name of the tool node to execute, or the ``END`` sentinel
    when the last message is not an AI tool call. NOTE(review): shadows the
    prebuilt ``tools_condition`` imported from langgraph.prebuilt; this
    custom router is the one the graph uses.
    """
    last_message = state["query"][-1]
    if isinstance(last_message, AIMessage):
        tool_calls = getattr(last_message, "tool_calls", None)
        if tool_calls:
            tool_name = tool_calls[0].get('name', '')
            if tool_name == "json_to_table":
                return "show_allocation_table"
            if tool_name == "rag_tool":
                return "query_rag"
            # The original returned "tools" here, but no node named "tools"
            # is registered in the graph; fall through to END instead of
            # crashing on an unknown tool name.
            return END
    # Return the END sentinel, not the literal string "END" — without a
    # path_map the returned value is looked up as a node name, and no node
    # called "END" exists.
    return END
|
|
|
|
| |
# ---- Graph wiring -------------------------------------------------------
graph = StateGraph(Graph)

graph.add_node("chat", chat)
graph.add_node("show_allocation_table", json_to_table_node)
graph.add_node("query_rag", rag_tool_node)
# NOTE(review): this node is never wired by any edge and reads a
# "tool_output" key that is not part of the Graph schema; kept to preserve
# the original structure — confirm whether it can be removed.
graph.add_node("tool_output_to_message", lambda state: AIMessage(content=state["tool_output"]))

graph.add_edge(START, "chat")
# After chat, route to a tool node — or to END — via the custom
# tools_condition router defined above.
graph.add_conditional_edges("chat", tools_condition)

# Tool results loop back into the chat node for a follow-up model turn.
graph.add_edge("show_allocation_table", "chat")
graph.add_edge("query_rag", "chat")

# NOTE(review): the original also added an unconditional "chat" -> END edge
# on top of the conditional edges, which forced a transition to END on every
# turn. The router already returns END when there is no tool call, so the
# fixed edge is dropped.

# Compile with the in-memory checkpointer so conversations persist per
# thread_id across invocations.
app = graph.compile(checkpointer=memory)
|
|
# Disabled manual smoke test: loads sample allocation data and invokes the
# compiled app once. Kept inside a triple-quoted string so it never runs on
# import. NOTE(review): it loads the same file for both `data` and `allocs`
# — presumably a placeholder; verify before re-enabling.
'''
with open('/home/pavan/Desktop/FOLDERS/RUBIC/RAG_without_profiler/RAG_rubik/sample_data/sample_alloc.json', 'r') as f:
    data = json.load(f)
with open('/home/pavan/Desktop/FOLDERS/RUBIC/RAG_without_profiler/RAG_rubik/sample_data/sample_alloc.json', 'r') as f:
    allocs = json.load(f)
inputs = {
    "query":"display my investments.",
    "user_data":data,
    "allocations":allocs,
    "data":"",
    "chat_history": [],

}

langchain.debug = True
print(app.invoke(inputs, config={"configurable": {"thread_id": "sample"}}).get('output'))
#print(json_to_table.args_schema.model_json_schema())
'''