| from typing import Literal |
|
|
| from langchain_core.messages import AIMessage, SystemMessage |
| from langchain_core.tools import StructuredTool |
| from langchain_groq import ChatGroq |
| from langgraph.checkpoint.memory import MemorySaver |
| from langgraph.graph import END, START, MessagesState, StateGraph |
| from langgraph.prebuilt import ToolNode |
| from pydantic import BaseModel, Field |
| from sqlalchemy.orm import Session |
|
|
| from app.config import get_settings |
| from app.models import User |
| from app.services.document_service import DocumentService |
| from app.services.vector_store import VectorStoreService |
| from app.services.web_search import build_web_search_tool |
|
|
|
|
class VectorSearchInput(BaseModel):
    """Argument schema for the ``vector_search`` tool (single free-text query)."""

    query: str = Field(..., description="The user question to answer from uploaded documents.")
|
|
|
|
# Module-level singleton checkpointer shared by every agent compiled in
# build_agent(), so checkpointed state survives across build_agent() calls
# within this process (in-memory only — lost on restart).
LANGGRAPH_CHECKPOINTER = MemorySaver()
|
|
|
|
def _route_tools(state: MessagesState) -> Literal["tools", "__end__"]:
    """Route to the tool node when the newest message requests tool calls.

    Returns ``"tools"`` if the final message is an ``AIMessage`` carrying a
    non-empty ``tool_calls`` list, otherwise ``"__end__"`` to stop the graph.
    """
    newest = state["messages"][-1]
    if not isinstance(newest, AIMessage):
        return "__end__"
    return "tools" if newest.tool_calls else "__end__"
|
|
|
|
def build_agent(*, db: Session, user: User):
    """Build and compile a LangGraph document-QA agent for one user.

    The graph is a standard ReAct-style loop: an ``agent`` node (Groq chat
    model with tools bound) and a ``tools`` node, cycling until the model
    stops emitting tool calls.

    Args:
        db: Active SQLAlchemy session; captured by the tool closure below.
        user: Owner of the documents the ``vector_search`` tool may read.

    Returns:
        A compiled LangGraph graph wired to the shared in-memory checkpointer.

    Raises:
        RuntimeError: If no Groq API key is configured in settings.
    """
    settings = get_settings()
    if not settings.groq_api_key:
        raise RuntimeError("GROQ_API_KEY is required for agent responses.")
    # temperature=0 for deterministic, evidence-grounded answers.
    llm = ChatGroq(api_key=settings.groq_api_key, model=settings.model_name, temperature=0)
    document_service = DocumentService()
    vector_store = VectorStoreService()
    # May be None when web search is not configured for this environment
    # (handled by the branch below that adjusts tools and prompt).
    web_search_tool = build_web_search_tool()


    def vector_search(query: str) -> str:
        """Tool body: search this user's uploaded documents for *query*.

        Closes over ``db`` and ``user``. Two-stage retrieval: first resolve
        which of the user's documents are relevant, then run chunk-level
        similarity search restricted to those document hashes. Returns a
        plain-text evidence listing the model can cite from.
        """
        resolved_hashes = document_service.resolve_relevant_document_hashes(db, user=user, query=query)
        if not resolved_hashes:
            return "No uploaded documents are available for this user."
        matches = vector_store.similarity_search(db=db, query=query, file_hashes=resolved_hashes, k=settings.retrieval_k)
        if not matches:
            return f"No vector matches found for hashes: {resolved_hashes}"
        lines = ["Vector evidence (cite document + page + excerpt in final answer):"]
        for index, match in enumerate(matches, start=1):
            page_number = match["metadata"].get("page_number")
            # Some chunks (e.g. non-paginated sources) carry no page number.
            page_label = str(page_number) if page_number is not None else "unknown"
            document_id = match["metadata"].get("document_id")
            score_parts = [f"distance={match['distance']:.4f}"]
            # rerank_score is only present when a reranker ran upstream.
            if "rerank_score" in match:
                score_parts.append(f"rerank_score={match['rerank_score']:.4f}")
            lines.append(f"{index}. document_id={document_id} | document={match['metadata']['filename']} | page={page_label} | {' | '.join(score_parts)}")
            # Cap excerpt at 900 chars and flatten newlines (chr(10)) so each
            # excerpt stays on one line; backslashes are not allowed inside
            # f-string expressions on older Python, hence chr(10).
            lines.append(f" excerpt: {match['content'][:900].replace(chr(10), ' ')}")
        return "\n\n".join(lines)


    vector_tool = StructuredTool.from_function(
        func=vector_search,
        name="vector_search",
        description=(
            "Searches the current user's uploaded documents. "
            "The tool automatically resolves the most relevant documents for the current user before chunk retrieval."
        ),
        args_schema=VectorSearchInput,
    )


    tools = [vector_tool]
    # System prompt forbids inline citations: the application layer attaches
    # citation metadata separately.
    prompt = (
        "You are a document QA agent. Prefer vector_search for questions about the user's uploaded documents. "
        "Do NOT include any 'Sources' section, citation list, footnotes, chunk ids, or hashes in the final answer text. "
        "Only provide the concise user-facing answer. "
        "Citation metadata is handled separately by the application. "
        "Do not claim evidence that is not present in tool outputs."
    )
    if web_search_tool is not None:
        tools.append(web_search_tool)
        prompt += " Use web search only when the answer depends on external or current information."
    else:
        prompt += " Web search is currently unavailable in this environment."


    llm_with_tools = llm.bind_tools(tools)
    tool_node = ToolNode(tools)
    system_prompt = SystemMessage(content=prompt)


    def agent_node(state: MessagesState):
        """Graph node: invoke the model with the system prompt prepended."""
        response = llm_with_tools.invoke([system_prompt, *state["messages"]])
        return {"messages": [response]}


    # Wiring: START -> agent; agent -> tools when the last AI message has
    # tool calls (see _route_tools), else END; tools always loop back to agent.
    graph = StateGraph(MessagesState)
    graph.add_node("agent", agent_node)
    graph.add_node("tools", tool_node)
    graph.add_edge(START, "agent")
    graph.add_conditional_edges("agent", _route_tools, {"tools": "tools", "__end__": END})
    graph.add_edge("tools", "agent")


    return graph.compile(checkpointer=LANGGRAPH_CHECKPOINTER)
|
|