"""LangGraph-based document QA agent for DocsBot (app/services/agent.py)."""
from typing import Literal
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.tools import StructuredTool
from langchain_groq import ChatGroq
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.config import get_settings
from app.models import User
from app.services.document_service import DocumentService
from app.services.vector_store import VectorStoreService
from app.services.web_search import build_web_search_tool
class VectorSearchInput(BaseModel):
    """Arguments accepted by the `vector_search` tool."""

    # The user question; the Field description is surfaced to the LLM as the tool schema.
    query: str = Field(..., description="The user question to answer from uploaded documents.")
# Module-level in-memory checkpointer: shared by every graph compiled via
# build_agent in this process, so conversation state survives across calls
# (but not across process restarts).
LANGGRAPH_CHECKPOINTER = MemorySaver()
def _route_tools(state: MessagesState) -> Literal["tools", "__end__"]:
    """Route to the tool node when the latest AI message requests tool calls."""
    latest = state["messages"][-1]
    wants_tools = isinstance(latest, AIMessage) and bool(latest.tool_calls)
    return "tools" if wants_tools else "__end__"
def _format_vector_matches(matches: list[dict]) -> str:
    """Render similarity-search matches as an evidence string for the LLM.

    Each match contributes one metadata line (document id, filename, page,
    distance and — when present — rerank score) followed by a truncated,
    newline-flattened excerpt of the chunk text. Entries are separated by
    blank lines.
    """
    lines = ["Vector evidence (cite document + page + excerpt in final answer):"]
    for index, match in enumerate(matches, start=1):
        metadata = match["metadata"]
        page_number = metadata.get("page_number")
        page_label = str(page_number) if page_number is not None else "unknown"
        score_parts = [f"distance={match['distance']:.4f}"]
        # A rerank score only exists when a reranker ran; include it when available.
        if "rerank_score" in match:
            score_parts.append(f"rerank_score={match['rerank_score']:.4f}")
        lines.append(
            f"{index}. document_id={metadata.get('document_id')} | "
            f"document={metadata['filename']} | page={page_label} | {' | '.join(score_parts)}"
        )
        # Cap excerpt length and flatten newlines so each excerpt stays on one line.
        excerpt = match["content"][:900].replace("\n", " ")
        lines.append(f" excerpt: {excerpt}")
    return "\n\n".join(lines)


def build_agent(*, db: Session, user: User):
    """Build and compile the LangGraph document-QA agent for *user*.

    The agent always has a `vector_search` tool over the user's uploaded
    documents; a web-search tool is added only when the environment provides
    one. The compiled graph uses the process-wide in-memory checkpointer.

    Args:
        db: Active SQLAlchemy session used by the retrieval tools.
        user: Owner whose uploaded documents are searchable.

    Returns:
        A compiled LangGraph state graph ready to be invoked with messages.

    Raises:
        RuntimeError: If GROQ_API_KEY is not configured.
    """
    settings = get_settings()
    if not settings.groq_api_key:
        raise RuntimeError("GROQ_API_KEY is required for agent responses.")
    # temperature=0 keeps tool-routing and answers deterministic.
    llm = ChatGroq(api_key=settings.groq_api_key, model=settings.model_name, temperature=0)
    document_service = DocumentService()
    vector_store = VectorStoreService()
    web_search_tool = build_web_search_tool()

    def vector_search(query: str) -> str:
        """Search this user's documents and return formatted evidence text."""
        # First narrow to the documents most relevant to the query for this user,
        # then run chunk-level similarity search restricted to those documents.
        resolved_hashes = document_service.resolve_relevant_document_hashes(db, user=user, query=query)
        if not resolved_hashes:
            return "No uploaded documents are available for this user."
        matches = vector_store.similarity_search(db=db, query=query, file_hashes=resolved_hashes, k=settings.retrieval_k)
        if not matches:
            return f"No vector matches found for hashes: {resolved_hashes}"
        return _format_vector_matches(matches)

    vector_tool = StructuredTool.from_function(
        func=vector_search,
        name="vector_search",
        description=(
            "Searches the current user's uploaded documents. "
            "The tool automatically resolves the most relevant documents for the current user before chunk retrieval."
        ),
        args_schema=VectorSearchInput,
    )
    tools = [vector_tool]
    prompt = (
        "You are a document QA agent. Prefer vector_search for questions about the user's uploaded documents. "
        "Do NOT include any 'Sources' section, citation list, footnotes, chunk ids, or hashes in the final answer text. "
        "Only provide the concise user-facing answer. "
        "Citation metadata is handled separately by the application. "
        "Do not claim evidence that is not present in tool outputs."
    )
    # Web search is optional: the builder returns None when the environment
    # lacks the required configuration, and the prompt reflects availability.
    if web_search_tool is not None:
        tools.append(web_search_tool)
        prompt += " Use web search only when the answer depends on external or current information."
    else:
        prompt += " Web search is currently unavailable in this environment."

    llm_with_tools = llm.bind_tools(tools)
    tool_node = ToolNode(tools)
    system_prompt = SystemMessage(content=prompt)

    def agent_node(state: MessagesState):
        """Invoke the tool-bound LLM with the system prompt prepended."""
        response = llm_with_tools.invoke([system_prompt, *state["messages"]])
        return {"messages": [response]}

    # Standard ReAct-style loop: agent -> (tools -> agent)* -> END.
    graph = StateGraph(MessagesState)
    graph.add_node("agent", agent_node)
    graph.add_node("tools", tool_node)
    graph.add_edge(START, "agent")
    graph.add_conditional_edges("agent", _route_tools, {"tools": "tools", "__end__": END})
    graph.add_edge("tools", "agent")
    return graph.compile(checkpointer=LANGGRAPH_CHECKPOINTER)