| """ |
| Enhanced GAIA-Ready AI Agent with integrated memory and reasoning systems |
| |
| This is the main integration file that combines the agent, memory system, |
| and reasoning system into a complete solution for the Hugging Face Agents Course. |
| """ |
|
|
import os
import sys
import json
import traceback
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
|
|
# Project-local modules providing the memory and reasoning subsystems.
try:
    from memory_system import EnhancedMemoryManager
    from reasoning_system import ReasoningSystem
except ImportError:
    print("Error: Could not import the memory_system or reasoning_system modules.")
    print("Make sure memory_system.py and reasoning_system.py are in the same directory.")
    sys.exit(1)
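
# Note: the rest of this file relies only on a small slice of those modules
# (inferred from the calls made below, not a full description of their APIs):
#   EnhancedMemoryManager(use_semantic_search=...)  -> .get_memory_summary(), .clear_working_memory()
#   ReasoningSystem(base_agent, memory_manager)     -> .execute_reasoning_cycle(query, max_iterations)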
|
|
# smolagents supplies the model wrappers and agent/tool abstractions. If it is
# not installed, install it on the fly and retry the import.
try:
    from smolagents import Agent, InferenceClientModel, Tool, LiteLLMModel
except ImportError:
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "smolagents"])
    from smolagents import Agent, InferenceClientModel, Tool
    try:
        from smolagents import LiteLLMModel
    except ImportError:
        LiteLLMModel = None  # Local-model support is optional.
        print("Warning: LiteLLMModel not available, will use InferenceClientModel only.")
|
|
# Tool implementations shared with the base agent module (agent.py).
from agent import (
    web_search_function,
    web_page_content_function,
    calculator_function,
    python_executor_function,
    image_analyzer_function,
    text_processor_function,
    file_manager_function
)
|
|
|
|
class EnhancedGAIAAgent:
    """
    Enhanced AI Agent designed to perform well on the GAIA benchmark.
    Integrates memory and reasoning systems with the Think-Act-Observe workflow.
    """

    def __init__(self, api_key=None, use_local_model=False, use_semantic_memory=True):
        """
        Initialize the enhanced GAIA agent.

        Args:
            api_key: API key for the Hugging Face Inference API
            use_local_model: Whether to use a local model via Ollama
            use_semantic_memory: Whether to use semantic search for memory retrieval
        """
        # Memory subsystem: working memory plus optional semantic retrieval.
        self.memory_manager = EnhancedMemoryManager(use_semantic_search=use_semantic_memory)

        # Model backend: either a local Ollama model via LiteLLM or the
        # Hugging Face Inference API (default).
        if use_local_model:
            try:
                self.model = LiteLLMModel(
                    model_id="ollama_chat/qwen2:7b",
                    api_base="http://127.0.0.1:11434",
                    num_ctx=8192,
                )
                print("Using local Ollama model: qwen2:7b")
            except Exception as e:
                print(f"Error initializing local model: {str(e)}")
                print("Falling back to Hugging Face Inference API")
                self.model = InferenceClientModel(
                    model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                    api_key=api_key or os.environ.get("HF_API_KEY", "")
                )
                print("Using Hugging Face Inference API model: Mixtral-8x7B")
        else:
            self.model = InferenceClientModel(
                model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                api_key=api_key or os.environ.get("HF_API_KEY", "")
            )
            print("Using Hugging Face Inference API model: Mixtral-8x7B")

        self.tools = [
            Tool(
                name="web_search",
                description="Search the web for information",
                function=web_search_function
            ),
            Tool(
                name="web_page_content",
                description="Fetch and extract content from a web page",
                function=web_page_content_function
            ),
            Tool(
                name="calculator",
                description="Perform mathematical calculations",
                function=calculator_function
            ),
            Tool(
                name="image_analyzer",
                description="Analyze image content",
                function=image_analyzer_function
            ),
            Tool(
                name="python_executor",
                description="Execute Python code",
                function=python_executor_function
            ),
            Tool(
                name="text_processor",
                description="Process and analyze text",
                function=text_processor_function
            ),
            Tool(
                name="file_manager",
                description="Save and load data from files",
                function=file_manager_function
            )
        ]

        self.system_prompt = """
        You are an advanced AI assistant designed to solve complex tasks from the GAIA benchmark.
        You have access to various tools that can help you solve these tasks.

        Always follow the Think-Act-Observe workflow:
        1. Think: Carefully analyze the task and plan your approach
           - Break down complex tasks into smaller steps
           - Consider what information you need and how to get it
           - Plan your approach before taking action

        2. Act: Use appropriate tools to gather information or perform actions
           - web_search: Search the web for information
           - web_page_content: Extract content from specific web pages
           - calculator: Perform mathematical calculations
           - image_analyzer: Analyze image content
           - python_executor: Run Python code for complex operations
           - text_processor: Process and analyze text (summarize, analyze_sentiment, extract_keywords)
           - file_manager: Save and load data from files (save, load)

        3. Observe: Analyze the results of your actions and adjust your approach
           - Verify if the information answers the original question
           - Identify any gaps or inconsistencies
           - Determine if additional actions are needed

        For complex tasks:
        - Break them down into smaller, manageable steps
        - Keep track of your progress and intermediate results
        - Verify each step before moving to the next
        - Always double-check your final answer

        When reasoning:
        - Be thorough and methodical
        - Consider multiple perspectives
        - Explain your thought process clearly
        - Cite sources when providing factual information

        Remember that the GAIA benchmark tests your ability to:
        - Reason effectively about complex problems
        - Understand and process multimodal information
        - Navigate the web to find information
        - Use tools appropriately to solve tasks

        Always verify your answers before submitting them.
        """

        # Base agent that executes tool calls under the system prompt above.
        self.base_agent = Agent(
            model=self.model,
            tools=self.tools,
            system_prompt=self.system_prompt
        )

        # Reasoning layer that runs the Think-Act-Observe cycles over the base agent and memory.
        self.reasoning_system = ReasoningSystem(self.base_agent, self.memory_manager)

        # Error-handling bookkeeping.
        self.max_retries = 3
        self.error_log = []

    def solve(self, query: str, max_iterations: int = 5, verbose: bool = True) -> Dict[str, Any]:
        """
        Solve a task using the enhanced Think-Act-Observe workflow.

        Args:
            query: The user's query or task
            max_iterations: Maximum number of iterations
            verbose: Whether to print detailed progress

        Returns:
            Dictionary containing the final answer and metadata
        """
        start_time = datetime.now()

        if verbose:
            print(f"\n{'='*50}")
            print(f"Starting to solve: {query}")
            print(f"{'='*50}\n")

        try:
            # Run the full reasoning cycle on the query.
            final_answer = self.reasoning_system.execute_reasoning_cycle(query, max_iterations)

            execution_time = (datetime.now() - start_time).total_seconds()

            if verbose:
                print(f"\n{'='*50}")
                print(f"Task completed in {execution_time:.2f} seconds")
                print(f"{'='*50}\n")

            # Summary of what the memory manager recorded for this task.
            memory_summary = self.memory_manager.get_memory_summary()

            return {
                "query": query,
                "answer": final_answer,
                "execution_time": execution_time,
                "iterations": max_iterations,  # the configured cap, not the steps actually taken
                "memory_summary": memory_summary,
                "success": True,
                "error": None
            }
        except Exception as e:
            error_msg = f"Error solving task: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)

            # Keep a record of the failure for later inspection.
            self.error_log.append({
                "timestamp": datetime.now().isoformat(),
                "query": query,
                "error": str(e),
                "traceback": traceback.format_exc()
            })

            # Best-effort recovery: ask the base agent directly for a partial answer.
            try:
                recovery_prompt = f"""
                I encountered an error while trying to solve this task: {query}

                The error was: {str(e)}

                Based on what I know so far, please provide the best possible answer or explanation.
                If you can't provide a complete answer, explain what you do know and what information is missing.
                """
                recovery_answer = self.base_agent.chat(recovery_prompt)

                execution_time = (datetime.now() - start_time).total_seconds()

                if verbose:
                    print(f"\n{'='*50}")
                    print(f"Task completed with recovery in {execution_time:.2f} seconds")
                    print(f"{'='*50}\n")

                return {
                    "query": query,
                    "answer": recovery_answer,
                    "execution_time": execution_time,
                    "iterations": 0,
                    "success": False,
                    "error": str(e),
                    "recovery": True
                }
            except Exception as recovery_error:
                # Recovery also failed: return an explicit error result.
                return {
                    "query": query,
                    "answer": f"I'm sorry, I encountered an error while solving this task and couldn't recover: {str(e)}",
                    "execution_time": (datetime.now() - start_time).total_seconds(),
                    "iterations": 0,
                    "success": False,
                    "error": str(e),
                    "recovery_error": str(recovery_error),
                    "recovery": False
                }

    def batch_solve(self, queries: List[str], max_iterations: int = 5, verbose: bool = True) -> List[Dict[str, Any]]:
        """
        Solve multiple tasks in batch.

        Args:
            queries: List of user queries or tasks
            max_iterations: Maximum number of iterations per query
            verbose: Whether to print detailed progress

        Returns:
            List of results for each query
        """
        results = []

        for i, query in enumerate(queries):
            if verbose:
                print(f"\n{'='*50}")
                print(f"Processing task {i+1}/{len(queries)}: {query}")
                print(f"{'='*50}\n")

            result = self.solve(query, max_iterations, verbose)
            results.append(result)

            # Reset working memory so one task does not bleed into the next.
            self.memory_manager.clear_working_memory()

        return results

    def save_results(self, results: Union[Dict[str, Any], List[Dict[str, Any]]], filename: str = "gaia_results.json") -> None:
        """
        Save results to a file.

        Args:
            results: Results from solve() or batch_solve()
            filename: Name of the file to save results to
        """
        try:
            with open(filename, 'w') as f:
                json.dump(results, f, indent=2)

            print(f"Results saved to {filename}")
        except Exception as e:
            print(f"Error saving results: {str(e)}")

    def load_results(self, filename: str = "gaia_results.json") -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Load results from a file.

        Args:
            filename: Name of the file to load results from

        Returns:
            Loaded results
        """
        try:
            with open(filename, 'r') as f:
                results = json.load(f)

            print(f"Results loaded from {filename}")
            return results
        except Exception as e:
            print(f"Error loading results: {str(e)}")
            return []

    def evaluate_performance(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Evaluate performance metrics from batch results.

        Args:
            results: Results from batch_solve()

        Returns:
            Dictionary of performance metrics
        """
        if not results:
            return {"error": "No results to evaluate"}

        total_queries = len(results)
        successful_queries = sum(1 for r in results if r.get("success", False))
        recovery_queries = sum(1 for r in results if not r.get("success", False) and r.get("recovery", False))
        failed_queries = total_queries - successful_queries - recovery_queries

        avg_execution_time = sum(r.get("execution_time", 0) for r in results) / total_queries

        return {
            "total_queries": total_queries,
            "successful_queries": successful_queries,
            "recovery_queries": recovery_queries,
            "failed_queries": failed_queries,
            "success_rate": successful_queries / total_queries if total_queries > 0 else 0,
            "recovery_rate": recovery_queries / total_queries if total_queries > 0 else 0,
            "failure_rate": failed_queries / total_queries if total_queries > 0 else 0,
            "avg_execution_time": avg_execution_time
        }
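

# Optional convenience wrapper: a small sketch showing how the result dictionary
# returned by solve() is typically consumed. The name quick_answer and the keyword
# passthrough are arbitrary illustrative choices.
def quick_answer(query: str, **agent_kwargs) -> str:
    """Build a default EnhancedGAIAAgent, solve one query, and return only the answer text."""
    agent = EnhancedGAIAAgent(**agent_kwargs)
    result = agent.solve(query, verbose=False)
    return result["answer"]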
|
|
|
|
if __name__ == "__main__":
    # Build the agent using the Hugging Face Inference API backend.
    agent = EnhancedGAIAAgent(use_local_model=False, use_semantic_memory=True)

    # Sample GAIA-style queries for a quick smoke test.
    sample_queries = [
        "What is the capital of France and what is its population? Also, calculate 15% of this population.",
        "Who was the first person to walk on the moon? What year did this happen?",
        "Explain the concept of photosynthesis in simple terms."
    ]

    # Solve a single query and print the final answer.
    print("\nSolving single query...")
    result = agent.solve(sample_queries[0])
    print("\nFinal Answer:")
    print(result["answer"])
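
    # Optional batch demo (a sketch that uses only the methods defined above):
    # solve every sample query, aggregate metrics with evaluate_performance(),
    # and persist the raw results. It is gated behind an environment variable
    # (RUN_BATCH_DEMO is just a convention here) because it issues several model calls.
    if os.environ.get("RUN_BATCH_DEMO"):
        batch_results = agent.batch_solve(sample_queries)
        metrics = agent.evaluate_performance(batch_results)
        print("\nPerformance metrics:")
        print(json.dumps(metrics, indent=2))
        agent.save_results(batch_results, "gaia_results.json")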
|
|