| """Retrieval Agent - Handles information gathering and search tasks""" |
| import os |
| import requests |
| from typing import Dict, Any, List |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage |
| from langchain_core.tools import tool |
| from langchain_groq import ChatGroq |
| from langchain_community.tools.tavily_search import TavilySearchResults |
| from langchain_community.document_loaders import WikipediaLoader, ArxivLoader |
| from langchain.tools.retriever import create_retriever_tool |
| from src.memory import memory_manager |
| from src.tracing import get_langfuse_callback_handler |
|
|
|
|
| |
@tool
def wiki_search(input: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        input: The search query."""
    # NOTE: the docstring above doubles as the tool description sent to the
    # LLM, and `input` (despite shadowing the builtin) is part of the tool's
    # argument schema, so both are left exactly as they were.
    #
    # Best-effort tool: any failure is returned to the model as text instead
    # of being raised, so the agent loop can continue.
    try:
        docs = WikipediaLoader(query=input, load_max_docs=2).load()
        if not docs:
            return "No Wikipedia results found for the query."

        blocks = []
        for doc in docs:
            meta = doc.metadata
            # Wrap each page in a pseudo-XML tag carrying its source URL/page.
            blocks.append(
                f'<Document source="{meta.get("source", "Unknown")}" page="{meta.get("page", "")}"/>\n'
                f"{doc.page_content}\n</Document>"
            )
        return "\n\n---\n\n".join(blocks)
    except Exception as e:
        print(f"Error in wiki_search: {e}")
        return f"Error searching Wikipedia: {e}"
|
|
|
|
@tool
def web_search(input: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        input: The search query."""
    # NOTE: the docstring above is the tool description the LLM sees and
    # `input` is part of the tool's call schema; both are kept unchanged.
    #
    # Returns up to 3 results formatted as pseudo-XML <Document> blocks, a
    # "no results" message, or an "Error searching web: ..." message. Errors
    # are reported as text rather than raised so the agent loop continues.
    try:
        search_docs = TavilySearchResults(max_results=3).invoke(input)
        if not search_docs:
            return "No web search results found for the query."
        # Depending on the langchain-community version, TavilySearchResults
        # can report an API failure by returning the exception's repr as a
        # plain string instead of raising. Iterating that string would fail on
        # `.get` with an unrelated AttributeError and hide the real cause.
        if isinstance(search_docs, str):
            print(f"Error in web_search: {search_docs}")
            return f"Error searching web: {search_docs}"
        formatted_search_docs = "\n\n---\n\n".join(
            f'<Document source="{doc.get("url", "Unknown")}" />\n'
            f'{doc.get("content", "No content")}\n</Document>'
            for doc in search_docs
        )
        return formatted_search_docs
    except Exception as e:
        print(f"Error in web_search: {e}")
        return f"Error searching web: {e}"
|
|
|
|
@tool
def arvix_search(input: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        input: The search query."""
    # NOTE: the docstring above is the tool description the LLM sees, and the
    # (misspelled) function name plus the `input` parameter form the tool's
    # call schema, so all three are intentionally left unchanged.
    #
    # Returns up to 3 papers as pseudo-XML <Document> blocks (content cut to
    # 1000 characters each), a "no results" message, or an error message.
    try:
        search_docs = ArxivLoader(query=input, load_max_docs=3).load()
        if not search_docs:
            return "No Arxiv results found for the query."

        entries = []
        for doc in search_docs:
            meta = doc.metadata
            # ArxivLoader documents are not keyed by "source": per the
            # LangChain ArxivLoader docs their metadata holds fields such as
            # "Title" and "Published" (plus "entry_id" when all metadata is
            # loaded). Fall back to those instead of always citing "Unknown".
            source = (
                meta.get("source")
                or meta.get("entry_id")
                or meta.get("Title")
                or "Unknown"
            )
            # Keep only the first 1000 characters of each paper's text.
            entries.append(
                f'<Document source="{source}" page="{meta.get("page", "")}"/>\n'
                f"{doc.page_content[:1000]}\n</Document>"
            )
        return "\n\n---\n\n".join(entries)
    except Exception as e:
        print(f"Error in arvix_search: {e}")
        return f"Error searching Arxiv: {e}"
|
|
|
|
def load_retrieval_prompt() -> str:
    """Return the retrieval agent's system prompt.

    Reads ``./prompts/retrieval_prompt.txt`` (resolved against the current
    working directory) as UTF-8 and strips surrounding whitespace. If the file
    does not exist, a short built-in prompt is returned instead; any other I/O
    or decoding error propagates to the caller.
    """
    default_prompt = (
        "You are a specialized retrieval agent. Use available tools to search "
        "for information and provide comprehensive answers."
    )
    try:
        prompt_file = open("./prompts/retrieval_prompt.txt", "r", encoding="utf-8")
    except FileNotFoundError:
        return default_prompt
    with prompt_file:
        return prompt_file.read().strip()
|
|
|
|
def get_retrieval_tools() -> List:
    """Collect the tools the retrieval agent's LLM may call.

    Always includes the Wikipedia, Tavily web and Arxiv search tools. When
    ``memory_manager.vector_store`` is truthy, a ``question_search`` retriever
    tool over that store is appended as well. A failure while building it is
    printed and otherwise ignored, so the three search tools are still
    returned.
    """
    available = [wiki_search, web_search, arvix_search]

    # No vector store configured: only the three search tools are available.
    if not memory_manager.vector_store:
        return available

    try:
        question_tool = create_retriever_tool(
            retriever=memory_manager.vector_store.as_retriever(),
            name="question_search",
            description="A tool to retrieve similar questions from a vector store.",
        )
    except Exception as e:
        print(f"Could not create retrieval tool: {e}")
    else:
        available.append(question_tool)

    return available
|
|
|
|
def execute_tool_calls(tool_calls: list, tools: list) -> list:
    """Run each requested tool call and wrap its outcome in a ``ToolMessage``.

    Args:
        tool_calls: Tool-call entries (as found on an AI message's
            ``tool_calls``); each is read for its ``name``, ``args`` and
            ``id`` keys.
        tools: Tool objects to dispatch to, matched by their ``name``.

    Returns:
        One ``ToolMessage`` per entry of ``tool_calls``, in the same order,
        whose ``tool_call_id`` echoes the request's ``id``. The content is the
        stringified tool result, ``"Error executing <name>: <error>"`` if the
        tool raised, or ``"Unknown tool: <name>"`` if no tool has that name.
    """
    registry = {candidate.name: candidate for candidate in tools}
    results = []

    for call in tool_calls:
        name, args, call_id = call['name'], call['args'], call['id']
        handler = registry.get(name)

        if handler is None:
            # Unknown tools still get a reply, keeping one message per call.
            results.append(
                ToolMessage(content=f"Unknown tool: {name}", tool_call_id=call_id)
            )
            continue

        try:
            print(f"Retrieval Agent: Executing {name} with args: {args}")
            outcome = handler.invoke(args)
            results.append(ToolMessage(content=str(outcome), tool_call_id=call_id))
        except Exception as e:
            # A failing tool is reported to the model instead of aborting the batch.
            print(f"Error executing {name}: {e}")
            results.append(
                ToolMessage(content=f"Error executing {name}: {e}", tool_call_id=call_id)
            )

    return results
|
|
|
|
def _format_attachment(task_id: str, content: bytes, max_chars: int = 8000) -> str:
    """Render downloaded attachment bytes as a fenced text block for the prompt.

    Args:
        task_id: Task identifier, echoed in the header line.
        content: Raw bytes of the downloaded file.
        max_chars: Maximum number of decoded characters to keep; longer text
            is cut and marked with "… (truncated)".

    Returns:
        The header line followed by a fenced code block containing the
        UTF-8 text, or a placeholder if the bytes are not valid UTF-8.
    """
    try:
        # Strict decoding on purpose: with errors="replace" decoding could never
        # fail, so binary attachments (images, spreadsheets, audio, ...) ended
        # up in the prompt as replacement-character garbage and the placeholder
        # below was unreachable.
        file_text = content.decode("utf-8")
    except UnicodeDecodeError:
        file_text = "(binary or non-UTF8 file omitted)"
    if len(file_text) > max_chars:
        file_text = file_text[:max_chars] + "\n… (truncated)"
    # NOTE(review): the fence is always labelled "python" regardless of the
    # attachment's actual file type.
    return f"Attached file content for task {task_id}:\n```python\n{file_text}\n```"


def fetch_attachment_if_needed(query: str) -> str:
    """Return the attachment of the scoring-API question whose text equals ``query``.

    Downloads the question list from the course scoring API and looks for the
    first question whose text matches ``query`` (compared as strings, ignoring
    surrounding whitespace). For that question it downloads the task's file.

    Args:
        query: The user's question text.

    Returns:
        A formatted block with the (possibly truncated) UTF-8 file content, a
        placeholder block if the file is not valid UTF-8, or ``""`` when no
        question matches, the task has no file, or any error occurs. This is
        best-effort: errors are printed, never raised.
    """
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    target = str(query).strip()
    try:
        resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=30)
        resp.raise_for_status()
        questions = resp.json()

        for q in questions:
            if str(q.get("question")).strip() != target:
                continue
            task_id = str(q.get("task_id"))
            print(f"Retrieval Agent: Downloading attachment for task {task_id}")
            file_resp = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=60)
            if file_resp.status_code == 200 and file_resp.content:
                return _format_attachment(task_id, file_resp.content)
            print(f"No attachment for task {task_id}")
            return ""
        return ""
    except Exception as e:
        print(f"Error fetching attachment: {e}")
        return ""
|
|
|
|
def _build_retrieval_messages(retrieval_prompt: str, messages: list) -> list:
    """Assemble the prompt: system prompt, optional context, then the conversation.

    Args:
        retrieval_prompt: Text for the leading ``SystemMessage``.
        messages: Incoming conversation messages (each read via ``.type``).

    Returns:
        A new list: the system prompt, an optional similar-Q&A hint, an
        optional attachment message, then every non-system incoming message.
    """
    retrieval_messages = [SystemMessage(content=retrieval_prompt)]

    # The most recent human message is treated as the question to answer.
    user_query = None
    for msg in reversed(messages):
        if msg.type == "human":
            user_query = msg.content
            break

    if user_query:
        similar_qa = memory_manager.get_similar_qa(user_query)
        if similar_qa:
            retrieval_messages.append(
                HumanMessage(
                    content=f"Here is a similar question and answer for reference:\n\n{similar_qa}"
                )
            )

        attachment_content = fetch_attachment_if_needed(user_query)
        if attachment_content:
            retrieval_messages.append(HumanMessage(content=attachment_content))

    # Incoming system messages are dropped; only this agent's prompt is sent.
    retrieval_messages.extend(msg for msg in messages if msg.type != "system")
    return retrieval_messages


def retrieval_agent(state: Dict[str, Any]) -> Dict[str, Any]:
    """Answer the latest user question using search tools.

    Builds a prompt from the retrieval system prompt, optional context
    (a similar stored Q&A and a downloaded task attachment), and the
    non-system messages in ``state["messages"]``. It then lets the ChatGroq
    model call tools for at most 3 rounds. If the model still requests tools
    after that, the pending calls are answered with a "limit reached" notice
    and one more call is made with tool use disabled, so the returned answer
    is always a plain reply.

    Args:
        state: Graph state; ``state["messages"]`` (default ``[]``) is read.

    Returns:
        A copy of ``state`` with ``messages`` set to the full transcript used
        here (including the final response), ``agent_response`` set to the
        final model message, and ``current_step`` set to ``"verification"``.
        On any exception, ``messages`` is the original list plus an error
        ``AIMessage``, which is also returned as ``agent_response``.
    """
    print("Retrieval Agent: Processing information retrieval request")

    try:
        retrieval_prompt = load_retrieval_prompt()

        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
        tools = get_retrieval_tools()
        llm_with_tools = llm.bind_tools(tools)

        callback_handler = get_langfuse_callback_handler()
        callbacks = [callback_handler] if callback_handler else []
        config = {"callbacks": callbacks}

        retrieval_messages = _build_retrieval_messages(
            retrieval_prompt, state.get("messages", [])
        )

        response = llm_with_tools.invoke(retrieval_messages, config=config)

        max_tool_iterations = 3
        iteration = 0
        while response.tool_calls and iteration < max_tool_iterations:
            iteration += 1
            print(f"Retrieval Agent: LLM requested {len(response.tool_calls)} tool calls (iteration {iteration})")

            tool_messages = execute_tool_calls(response.tool_calls, tools)
            # The AI message requesting tools must precede its tool results.
            retrieval_messages.extend([response] + tool_messages)
            response = llm_with_tools.invoke(retrieval_messages, config=config)

        if response.tool_calls:
            # Budget exhausted but the model still wants tools. Previously this
            # tool-call message was returned as the "answer" (usually with empty
            # content) and left unanswered tool_calls in the transcript. Answer
            # each pending call with a notice instead of running it, then ask
            # for a final reply with tool use disabled.
            print("Retrieval Agent: Tool iteration limit reached; requesting a final answer")
            retrieval_messages.append(response)
            retrieval_messages.extend(
                ToolMessage(
                    content="Tool call limit reached. Answer using the information gathered so far.",
                    tool_call_id=call["id"],
                )
                for call in response.tool_calls
            )
            final_llm = llm.bind_tools(tools, tool_choice="none")
            response = final_llm.invoke(retrieval_messages, config=config)

        retrieval_messages.append(response)

        return {
            **state,
            "messages": retrieval_messages,
            "agent_response": response,
            "current_step": "verification"
        }

    except Exception as e:
        print(f"Retrieval Agent Error: {e}")
        error_response = AIMessage(content=f"I encountered an error while processing your request: {e}")
        return {
            **state,
            "messages": state.get("messages", []) + [error_response],
            "agent_response": error_response,
            "current_step": "verification"
        }