| |
| from __future__ import annotations |
|
|
| import json |
| from typing import Any, Dict, List, Optional, TypedDict |
|
|
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage |
| from langgraph.graph import END, START, StateGraph |
| from langgraph.prebuilt import ToolNode |
|
|
| from .runtime_ctx import set_df_payload, set_df_summary, set_sota_bundled |
| from .tools import ( |
| tool_describe_step, |
| tool_inspect_dataset, |
| tool_list_steps, |
| tool_list_versions, |
| tool_propose_plan, |
| tool_reset_to_version, |
| tool_run_step, |
| tool_sota_preprocessing, |
| ) |
|
|
|
|
| def _to_text(content: Any, limit: int = 4000) -> str: |
| """Coerce any message content to a string that Gemini will accept.""" |
| if content is None: |
| return "" |
| if isinstance(content, str): |
| return content |
| try: |
| s = json.dumps(content, default=str, ensure_ascii=False) |
| except Exception: |
| s = str(content) |
| |
| return (s[:limit] + " …") if len(s) > limit else s |
|
|
def _sanitize_messages(msgs: list[Any]) -> list[Any]:
    """Normalize a chat history to system/human/assistant messages with str content.

    Tool messages are folded into assistant-visible text ("[Tool result] …");
    empty tool results and messages with unknown roles are dropped.
    """
    sanitized: list[Any] = []
    for msg in msgs or []:
        role = getattr(msg, "type", None) or getattr(msg, "role", None) or ""
        body = _to_text(getattr(msg, "content", None))

        # Tool outputs are surfaced as assistant text so the provider never
        # sees a raw tool-role message.
        if isinstance(msg, ToolMessage) or role == "tool":
            if body:
                sanitized.append(AIMessage(content=f"[Tool result] {body}"))
            continue

        if isinstance(msg, SystemMessage):
            sanitized.append(SystemMessage(content=body))
            continue
        if isinstance(msg, HumanMessage):
            sanitized.append(HumanMessage(content=body))
            continue
        if isinstance(msg, AIMessage):
            sanitized.append(AIMessage(content=body))
            continue

        # Duck-typed fallback: map a declared role string to a message class;
        # anything unrecognized is silently discarded.
        factory = {
            "system": SystemMessage,
            "human": HumanMessage,
            "user": HumanMessage,
            "assistant": AIMessage,
            "ai": AIMessage,
            "aimessage": AIMessage,
        }.get(str(role).lower())
        if factory is not None:
            sanitized.append(factory(content=body))

    return sanitized
|
|
| TOOLS = [tool_inspect_dataset, tool_sota_preprocessing, tool_list_steps, tool_describe_step, tool_propose_plan, tool_run_step, tool_list_versions, tool_reset_to_version] |
|
|
| SYSTEM_PRIMER = ( |
| "You are a data-quality assistant.\n" |
| "\n" |
| "Workflow:\n" |
| "1) Call inspect_dataset() to summarize columns/dtypes and GUESS task/label.\n" |
| " • If you are NOT SURE about the task (or the label for supervised tasks), ASK the user to confirm and END THE TURN.\n" |
| " • Do NOT call sota_preprocessing until the user explicitly confirms the task (and label if supervised).\n" |
| " Acceptable confirmations include messages like: " |
| " 'task=classification label=HARDSHIP_INDEX', 'Task: regression', or 'Unsupervised'.\n" |
| "2) After the user confirms, call sota_preprocessing(task, modality, ...) and PRESENT a brief 'SOTA Evidence' section (3–6 bullets with titles and links from the tool).\n" |
| "3) Call list_steps() and map SOTA insights to the available tools. Produce a plan (no execution yet); cite up to 2 SOTA sources per step.\n" |
| "4) Ask: 'Which step should we execute first?' Do NOT call run_step until the user explicitly picks.\n" |
| "5) After the user picks, call describe_step(name) and list ONLY real parameters from the tool. Ask for missing/optional params and confirm them.\n" |
| "6) Execute with run_step(name, params_json). Version controls inside params_json when relevant:\n" |
| " • source: 'current' | 'prev' | 'base' | '@-1' | '@-2' | <int>\n" |
| " • dry_run: true|false (preview without mutating)\n" |
| " • new_version: true|false (create new snapshot vs replace current)\n" |
| " Avoid loops: if the same step+params just ran, ask to change parameters or source.\n" |
| "7) Summarize results; optionally call list_versions() and offer reset_to_version(spec). If helpful, research again before proposing next steps.\n" |
| "\n" |
| "Rules:\n" |
| "- Return exactly one tool call at a time.\n" |
| "- Never call sota_preprocessing before explicit task confirmation.\n" |
| "- Never call run_step without an explicit user choice.\n" |
| "- When users ask about parameters, use describe_step (or list_steps) and answer ONLY from tool output.\n" |
| "- Reject parameters that are not in the tool signature.\n" |
| ) |
|
|
|
|
class AgentState(TypedDict):
    """Mutable state threaded through the LangGraph agent ⇄ tools loop."""

    messages: List[Any]                    # full chat history (mixed message types)
    df_payload: Optional[Dict[str, Any]]   # serialized dataframe; None until a dataset is uploaded
    results: List[Dict[str, Any]]          # per-step {"name", "stats"} records from run_step
    steps_taken: int                       # executed run_step count (budget consumed)
    max_steps: int                         # hard cap enforced by should_continue
    confirmed_step: Optional[str]          # step name the user explicitly approved, if any
    confirmed_params: Dict[str, Any]       # parameters confirmed for that step
    last_task: Optional[str]               # task guess captured from the dataset summary
    plan: Optional[Dict[str, Any]]         # latest plan payload produced by propose_plan
|
|
def make_agent_node(llm):
    """Build the "agent" graph node: one LLM turn that may emit tool calls.

    History is sanitized before every call, the current dataframe shape is
    injected as a trailing system note, and the node ALWAYS appends an
    AIMessage so downstream routing sees a well-typed last message.
    """
    bound = llm.bind_tools(TOOLS)

    def _node(state: AgentState) -> AgentState:
        # Surface the current dataframe shape to the model as a system note.
        data = (state.get("df_payload") or {}).get("data", {})
        n_rows = len(data.get("data", []) or [])
        n_cols = len(data.get("columns", []) or [])
        shape_note = SystemMessage(content=f"Current dataset shape: {n_rows} rows × {n_cols} columns.")

        prompt = [SystemMessage(content=SYSTEM_PRIMER)]
        prompt.extend(_sanitize_messages(state.get("messages", [])))
        prompt.append(shape_note)

        reply = bound.invoke(prompt)
        if not isinstance(reply, AIMessage):
            # Defensive: coerce odd return types into a plain AIMessage.
            reply = AIMessage(content=_to_text(getattr(reply, "content", reply)))

        state["messages"] = [*state["messages"], reply]
        return state

    return _node
|
|
def tools_exec_node():
    """
    Execute tools only here, after injecting df_payload into runtime context.
    Also updates state with tool outputs (summary/SOTA/plan/step_result).
    """
    tool_node = ToolNode(TOOLS)

    def _node(state: AgentState) -> AgentState:
        # Publish the current dataframe to the tools' runtime context first,
        # so every tool invocation below sees it.
        set_df_payload(state.get("df_payload"))

        # No dataset yet: reply without executing any tool.
        # NOTE(review): type(messages[-1])(content=...) builds a message of the
        # last message's class (normally AIMessage) — assumes that class
        # accepts a bare `content` kwarg; confirm for every type that can
        # appear here.
        if state.get("df_payload") is None:
            state["messages"].append(type(state["messages"][-1])(content="I don't have a dataset yet. Please upload one."))
            return state

        # Guardrail: refuse run_step for any step the user has not explicitly
        # confirmed (confirmed_step is set elsewhere and cleared after a run).
        last = state["messages"][-1]
        tool_calls = getattr(last, "tool_calls", None) or []
        for c in tool_calls:
            if c.get("name") == "run_step":
                intended = (c.get("args") or {}).get("name")
                if intended and intended != state.get("confirmed_step"):
                    state["messages"].append(type(last)(content="I have a plan ready. Which step should we run first?"))
                    return state

        # Run the pending tool calls through LangGraph's ToolNode.
        out = tool_node.invoke({"messages": state["messages"]})
        # Prefer the ToolMessages in the result; if there are none, fall back
        # to whatever was appended beyond the original history (or everything).
        new_msgs = [m for m in out["messages"] if isinstance(m, ToolMessage)]
        if not new_msgs:
            if len(out["messages"]) > len(state["messages"]):
                new_msgs = out["messages"][len(state["messages"]):]
            else:
                new_msgs = out["messages"]

        state["messages"] = state["messages"] + new_msgs

        # Fold structured tool output back into graph state, keyed by the
        # payload's "type" tag.
        # NOTE(review): ToolNode often serializes ToolMessage.content to str,
        # in which case this isinstance(dict) branch never fires — verify the
        # tools actually deliver dict payloads end to end.
        payload = new_msgs[-1].content if new_msgs else None
        if isinstance(payload, dict):
            typ = payload.get("type")
            if typ == "dataset_summary":
                set_df_summary(payload)
                state["last_task"] = payload.get("task_guess")
            elif typ == "sota":
                set_sota_bundled(payload.get("bundled_results") or [])
            elif typ == "plan":
                state["plan"] = payload
            elif typ == "step_result":
                # A step mutated the dataframe: adopt the new payload, record
                # stats, advance the step budget, and clear the confirmation
                # latch so the next run_step needs a fresh user choice.
                state["df_payload"] = payload["df"]
                set_df_payload(state["df_payload"])
                state["results"].append({"name": payload["name"], "stats": payload["stats"]})
                state["steps_taken"] += 1
                state["confirmed_step"] = None
                state["confirmed_params"] = {}

        return state

    return _node
|
|
def should_continue(state: AgentState) -> str:
    """Routing predicate for the agent → tools conditional edge.

    Returns "continue" (route to the tools node) only when the newest message
    carries tool calls AND the step budget is not exhausted; "end" otherwise.
    The budget check wins over pending tool calls.
    """
    newest = state["messages"][-1]
    over_budget = state.get("steps_taken", 0) >= state.get("max_steps", 8)
    wants_tool = bool(getattr(newest, "tool_calls", None))
    return "end" if over_budget or not wants_tool else "continue"
|
|
def build_app(llm):
    """Compile the two-node LangGraph app: agent ⇄ tools with a budget-aware router."""
    graph = StateGraph(AgentState)
    graph.add_node("agent", make_agent_node(llm))
    graph.add_node("tools", tools_exec_node())

    # agent decides; should_continue either routes to tools or ends the run.
    graph.add_edge(START, "agent")
    graph.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "tools", "end": END},
    )
    graph.add_edge("tools", "agent")

    return graph.compile()
|
|